Skip to content

Commit 2e19f12

Browse files
committed
Adds oidc login
Ref inventree/InvenTree#9333
1 parent 2c30c02 commit 2e19f12

File tree

4 files changed

+129
-5
lines changed

4 files changed

+129
-5
lines changed

inventree/api.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import requests
1414
from requests.auth import HTTPBasicAuth
1515
from requests.exceptions import Timeout
16+
from . import oAuthClient as oauth
1617

1718
logger = logging.getLogger('inventree')
1819

@@ -45,6 +46,9 @@ def __init__(self, host=None, **kwargs):
4546
token - Authentication token (if provided, username/password are ignored)
4647
token-name - Name of the token to use (default = 'inventree-python-client')
4748
use_token_auth - Use token authentication? (default = True)
49+
use_oidc_auth - Use OIDC authentication? (default = False)
50+
oidc_client_id - OIDC client ID (defaults to InvenTree public client)
51+
oidc_scopes - OIDC scopes (default = ['openid', 'g:read'])
4852
verbose - Print extra debug messages (default = False)
4953
strict - Enforce strict HTTPS certificate checking (default = True)
5054
timeout - Set timeout to use (in seconds). Default: 10
@@ -56,6 +60,9 @@ def __init__(self, host=None, **kwargs):
5660
INVENTREE_API_PASSWORD - Password
5761
INVENTREE_API_TOKEN - User access token
5862
INVENTREE_API_TIMEOUT - Timeout value, in seconds
63+
INVENTREE_API_OIDC - Use OIDC
64+
INVENTREE_API_OIDC_CLIENT_ID - OIDC client ID
65+
INVENTREE_API_OIDC_SCOPES - OIDC scopes
5966
"""
6067

6168
self.setHostName(host or os.environ.get('INVENTREE_API_HOST', None))
@@ -68,8 +75,11 @@ def __init__(self, host=None, **kwargs):
6875
self.timeout = kwargs.get('timeout', os.environ.get('INVENTREE_API_TIMEOUT', 10))
6976
self.proxies = kwargs.get('proxies', dict())
7077
self.strict = bool(kwargs.get('strict', True))
78+
self.oidc_client_id = kwargs.get('oidc_client_id', os.environ.get('INVENTREE_API_OIDC_CLIENT_ID', 'zDFnsiRheJIOKNx6aCQ0quBxECg1QBHtVFDPloJ6'))
79+
self.oidc_scopes = kwargs.get('oidc_scopes', os.environ.get('INVENTREE_API_OIDC_SCOPES', ['openid', 'g:read']))
7180

7281
self.use_token_auth = kwargs.get('use_token_auth', True)
82+
self.use_oidc_auth = kwargs.get('use_oidc_auth', os.environ.get('INVENTREE_API_OIDC', None))
7383
self.verbose = kwargs.get('verbose', False)
7484

7585
self.auth = None
@@ -132,9 +142,10 @@ def connect(self):
132142
if not self.testAuth():
133143
raise ConnectionError("Authentication at InvenTree server failed")
134144

135-
if self.use_token_auth:
136-
if not self.token:
137-
self.requestToken()
145+
if self.use_token_auth and not self.token:
146+
self.requestToken()
147+
elif self.use_oidc_auth and not self.token:
148+
self.requestOidcToken()
138149

139150
def constructApiUrl(self, endpoint_url):
140151
"""Construct an API endpoint URL based on the provided API URL.
@@ -273,6 +284,14 @@ def requestToken(self):
273284

274285
return self.token
275286

287+
def requestOidcToken(self):
288+
"""Return authentication token from the server using OIDC."""
289+
client = oauth.OAuthClient(self.base_url, self.oidc_client_id, self.oidc_scopes)
290+
self.token = client._access_token
291+
292+
return self.token
293+
294+
276295
def request(self, api_url, **kwargs):
277296
""" Perform a URL request to the Inventree API """
278297

@@ -316,7 +335,7 @@ def request(self, api_url, **kwargs):
316335
'timeout': kwargs.get('timeout', self.timeout),
317336
}
318337

319-
if self.use_token_auth and self.token:
338+
if (self.use_token_auth or self.use_oidc_auth) and self.token:
320339
headers['AUTHORIZATION'] = f'Token {self.token}'
321340
auth = None
322341
else:

inventree/oAuthClient.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import os
2+
from http.server import BaseHTTPRequestHandler, HTTPServer
3+
from requests_oauthlib import OAuth2Session
4+
import webbrowser
5+
import urllib.parse as urlparse
6+
7+
# Environment setup
8+
os.environ["OAUTHLIB_INSECURE_TRANSPORT"] = "1"
9+
USABLE_PORT_RANGE = (29170, 292180)
10+
11+
12+
class OAuthClient:
13+
def __init__(self, server_url: str = "http://localhost:8000", client_id: str ='', scopes: list[str] = None) -> None:
14+
self.server_url = server_url
15+
self.client_id = client_id
16+
self.scopes = scopes if scopes is not None else []
17+
18+
self._handler_wrapper = RequestHandlerWrapper(self)
19+
self._setup_callback()
20+
self._poll_user()
21+
22+
def get_url(self, path: str) -> str:
23+
"""Get the authorization URL."""
24+
return urlparse.urljoin(self.server_url, path)
25+
26+
def _setup_callback(self):
27+
for port in range(*USABLE_PORT_RANGE):
28+
try:
29+
self.server = HTTPServer(("127.0.0.1", port), self._handler_wrapper.request_handler)
30+
self._port = port
31+
break
32+
except OSError:
33+
continue
34+
else:
35+
raise Exception("No port found.")
36+
37+
def _poll_user(self):
38+
self._session = OAuth2Session(
39+
self.client_id, scope=self.scopes, redirect_uri=f"http://localhost:{self._port}", pkce="S256"
40+
)
41+
auth_url, state = self._session.authorization_url(self.get_url('/o/authorize/'), access_type="offline")
42+
self._state = state
43+
webbrowser.open_new_tab(auth_url)
44+
45+
while not self._handler_wrapper.done:
46+
self.server.handle_request()
47+
if self._handler_wrapper.error:
48+
raise Exception(self._handler_wrapper.error)
49+
50+
def callback(self, callback_url: str):
51+
self._session.fetch_token(self.get_url("/o/token/"), authorization_response=callback_url, include_client_id=True)
52+
self._access_token = self._session.access_token
53+
54+
55+
class RequestHandlerWrapper:
56+
"""Provides callback for OIDC endpint."""
57+
def __init__(self, oauth_client) -> None:
58+
self.done = False
59+
self.error = None
60+
self.client: OAuthClient = oauth_client
61+
62+
@property
63+
def request_handler(self):
64+
wrapper = self
65+
66+
class RequestHandler(BaseHTTPRequestHandler):
67+
def do_GET(self):
68+
parsed_url = urlparse.urlparse(self.path)
69+
if parsed_url.path == "/":
70+
error = urlparse.parse_qs(parsed_url.query).get("error", [None])[0]
71+
if error:
72+
wrapper.error = error
73+
self.send(200)
74+
else:
75+
try:
76+
wrapper.client.callback(self.path)
77+
except OAuthError as e:
78+
wrapper.error = e.message
79+
self.send(400)
80+
else:
81+
self.send(200, 'Success! You can close this window.')
82+
wrapper.done = True
83+
else:
84+
self.send(404)
85+
86+
def send(self, status_code, content=None):
87+
self.send_response(status_code)
88+
if content:
89+
self.wfile.write(content.encode("utf-8"))
90+
else:
91+
self.wfile.write(b"")
92+
self.send_header("Content-type", "text/html")
93+
self.end_headers()
94+
95+
def log_message(self, *args):
96+
pass # Suppress logging
97+
98+
return RequestHandler
99+
100+
class OAuthError(Exception):
101+
"""Exception raised during the OAuth process."""
102+
def __init__(self, message: str) -> None:
103+
self.message = message

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@ invoke>=1.4.0
55
coverage>=6.4.1 # Run tests, measure coverage
66
coveralls>=3.3.1
77
Pillow>=9.1.1
8+
requests-oauthlib # Modern auth experience

setup.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,8 @@
3838
),
3939

4040
install_requires=[
41-
"requests>=2.27.0"
41+
"requests>=2.27.0",
42+
"requests-oauthlib"
4243
],
4344

4445
setup_requires=[

0 commit comments

Comments
 (0)