Skip to content

Commit 83a4fcd

Browse files
bartvinmantaci
authored andcommitted
Add JWT based authentication support (PR #7570)
Pull request opened by the merge tool on behalf of #7570
1 parent ff1a913 commit 83a4fcd

File tree

14 files changed

+329
-140
lines changed

14 files changed

+329
-140
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
description: Add JWT based authentication support
2+
change-type: minor
3+
destination-branches: [master, iso7]
4+
sections:
5+
minor-improvement: Update auth to be able to authenticate against a provided JWT

src/inmanta/data/model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -729,6 +729,12 @@ class User(BaseModel):
729729
auth_method: AuthMethod
730730

731731

732+
class CurrentUser(BaseModel):
733+
"""Information about the current logged in user"""
734+
735+
username: str
736+
737+
732738
class LoginReturn(BaseModel):
733739
"""
734740
Login information

src/inmanta/protocol/auth.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def encode_token(
162162
return jwt.encode(payload=payload, key=cfg.key, algorithm=cfg.algo)
163163

164164

165-
def decode_token(token: str) -> claim_type:
165+
def decode_token(token: str) -> tuple[claim_type, "AuthJWTConfig"]:
166166
try:
167167
# First decode the token without verification
168168
header = jwt.get_unverified_header(token)
@@ -200,6 +200,7 @@ def decode_token(token: str) -> claim_type:
200200
try:
201201
# copy the payload and make sure the type is claim_type
202202
decoded_payload: MutableMapping[str, str | Sequence[str]] = {}
203+
unsupported = []
203204
for k, v in jwt.decode(token, key, audience=cfg.audience, algorithms=[cfg.algo]).items():
204205
match v:
205206
case str():
@@ -213,19 +214,23 @@ def decode_token(token: str) -> claim_type:
213214
)
214215
decoded_payload[k] = v
215216
case _:
216-
logging.getLogger(__name__).info(
217-
"Only claims of type string or list of strings are supported. %s is filtered out.", k
218-
)
217+
unsupported.append(k)
218+
219+
if unsupported:
220+
logging.getLogger(__name__).debug(
221+
"Only claims of type string or list of strings are supported. %s are filtered out.", ", ".join(unsupported)
222+
)
219223

220224
ct_key = const.INMANTA_URN + "ct"
221-
decoded_payload[ct_key] = [x.strip() for x in str(payload[ct_key]).split(",")]
225+
ct_value = str(payload.get(ct_key, "api"))
226+
decoded_payload[ct_key] = [x.strip() for x in ct_value.split(",")]
222227
except Exception as e:
223228
raise exceptions.Forbidden(*e.args)
224229

225230
if not check_custom_claims(claims=decoded_payload, claim_constraints=cfg.claims):
226231
raise exceptions.Forbidden("The configured claims constraints did not match. See logs for details.")
227232

228-
return decoded_payload
233+
return decoded_payload, cfg
229234

230235

231236
#############################
@@ -325,6 +330,13 @@ def __init__(self, name: str, section: str, config: configparser.SectionProxy) -
325330
self.keys: dict[str, bytes] = {}
326331
self._config: configparser.SectionProxy = config
327332
self.claims: list[ClaimMatch] = []
333+
334+
self.jwt_username_claim: str = "sub"
335+
self.expire: int = 0
336+
self.sign: bool = False
337+
self.issuer: str = "https://localhost:8888/"
338+
self.audience: str
339+
328340
if "algorithm" not in config:
329341
raise ValueError("algorithm is required in %s section" % self.section)
330342

@@ -344,8 +356,6 @@ def validate_generic(self) -> None:
344356
"""
345357
if "sign" in self._config:
346358
self.sign = config.is_bool(self._config["sign"])
347-
else:
348-
self.sign = False
349359

350360
if "client_types" not in self._config:
351361
raise ValueError("client_types is a required options for %s" % self.section)
@@ -357,13 +367,9 @@ def validate_generic(self) -> None:
357367

358368
if "expire" in self._config:
359369
self.expire = config.is_int(self._config["expire"])
360-
else:
361-
self.expire = 0
362370

363371
if "issuer" in self._config:
364372
self.issuer = config.is_str(self._config["issuer"])
365-
else:
366-
self.issuer = "https://localhost:8888/"
367373

368374
if "audience" in self._config:
369375
self.audience = config.is_str(self._config["audience"])
@@ -373,6 +379,11 @@ def validate_generic(self) -> None:
373379
if "claims" in self._config:
374380
self.parse_claim_matching(self._config["claims"])
375381

382+
if "jwt-username-claim" in self._config:
383+
if self.sign:
384+
raise ValueError(f"auth config {self.section} used for signing cannot use a custom claim.")
385+
self.jwt_username_claim = self._config.get("jwt-username-claim")
386+
376387
def parse_claim_matching(self, claim_conf: str) -> None:
377388
"""Parse claim matching expressions"""
378389
items = re.findall(AUTH_JWT_CLAIM_RE, claim_conf, re.MULTILINE)

src/inmanta/protocol/common.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
from inmanta import const, execute, types, util
4646
from inmanta.data.model import BaseModel, DateTimeNormalizerModel
47+
from inmanta.protocol import auth
4748
from inmanta.protocol.exceptions import BadRequest, BaseHttpException
4849
from inmanta.protocol.openapi import model as openapi_model
4950
from inmanta.stable_api import stable_api
@@ -66,6 +67,21 @@
6667
HTML_CONTENT_WITH_UTF8_CHARSET = f"{HTML_CONTENT}; {UTF8_CHARSET}"
6768

6869

70+
class CallContext:
71+
"""A context variable that provides more information about the current call context"""
72+
73+
request_headers: dict[str, str]
74+
auth_token: Optional[auth.claim_type]
75+
auth_username: Optional[str]
76+
77+
def __init__(
78+
self, request_headers: dict[str, str], auth_token: Optional[auth.claim_type], auth_username: Optional[str]
79+
) -> None:
80+
self.request_headers = request_headers
81+
self.auth_token = auth_token
82+
self.auth_username = auth_username
83+
84+
6985
class ArgOption:
7086
"""
7187
Argument options to transform arguments before dispatch
@@ -574,7 +590,6 @@ def _validate_type_arg(
574590
:param allow_none_type: If true, allow `None` as the type for this argument
575591
:param in_url: This argument is passed in the URL
576592
"""
577-
578593
if typing_inspect.is_new_type(arg_type):
579594
return self._validate_type_arg(
580595
arg,
@@ -656,6 +671,8 @@ def _validate_type_arg(
656671
elif allow_none_type and types.issubclass(arg_type, type(None)):
657672
# A check for optional arguments
658673
pass
674+
elif issubclass(arg_type, CallContext):
675+
raise InvalidMethodDefinition("CallContext should only be defined in the handler, not the method.")
659676
else:
660677
valid_types = ", ".join([x.__name__ for x in VALID_SIMPLE_ARG_TYPES])
661678
raise InvalidMethodDefinition(

src/inmanta/protocol/endpoints.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,9 @@ async def dispatch_method(self, transport: client.RESTClient, method_call: commo
288288
else:
289289
body[key] = [v.decode("latin-1") for v in value]
290290

291-
response: common.Response = await transport._execute_call(kwargs, method_call.method, config, body, method_call.headers)
291+
body.update(kwargs)
292+
293+
response: common.Response = await transport._execute_call(config, body, method_call.headers)
292294

293295
if response.status_code == 500:
294296
msg = ""

src/inmanta/protocol/methods_v2.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,14 @@ def list_users() -> list[model.User]:
14891489
:return: A list of all users"""
14901490

14911491

1492+
@typedmethod(path="/current_user", operation="GET", client_types=[ClientType.api], api_version=2)
1493+
def get_current_user() -> model.CurrentUser:
1494+
"""Get the current logged in user (based on the provided JWT) and server auth settings
1495+
1496+
:raises NotFound: Raised when server authentication is not enabled
1497+
"""
1498+
1499+
14921500
@typedmethod(path="/user/<username>", operation="DELETE", client_types=[ClientType.api], api_version=2)
14931501
def delete_user(username: str) -> None:
14941502
"""Delete a user from the system with given username.

0 commit comments

Comments
 (0)