@@ -162,7 +162,7 @@ def encode_token(
162162 return jwt .encode (payload = payload , key = cfg .key , algorithm = cfg .algo )
163163
164164
165- def decode_token (token : str ) -> claim_type :
165+ def decode_token (token : str ) -> tuple [ claim_type , "AuthJWTConfig" ] :
166166 try :
167167 # First decode the token without verification
168168 header = jwt .get_unverified_header (token )
@@ -200,6 +200,7 @@ def decode_token(token: str) -> claim_type:
200200 try :
201201 # copy the payload and make sure the type is claim_type
202202 decoded_payload : MutableMapping [str , str | Sequence [str ]] = {}
203+ unsupported = []
203204 for k , v in jwt .decode (token , key , audience = cfg .audience , algorithms = [cfg .algo ]).items ():
204205 match v :
205206 case str ():
@@ -213,19 +214,23 @@ def decode_token(token: str) -> claim_type:
213214 )
214215 decoded_payload [k ] = v
215216 case _:
216- logging .getLogger (__name__ ).info (
217- "Only claims of type string or list of strings are supported. %s is filtered out." , k
218- )
217+ unsupported .append (k )
218+
219+ if unsupported :
220+ logging .getLogger (__name__ ).debug (
221+ "Only claims of type string or list of strings are supported. %s are filtered out." , ", " .join (unsupported )
222+ )
219223
220224 ct_key = const .INMANTA_URN + "ct"
221- decoded_payload [ct_key ] = [x .strip () for x in str (payload [ct_key ]).split ("," )]
225+ ct_value = str (payload .get (ct_key , "api" ))
226+ decoded_payload [ct_key ] = [x .strip () for x in ct_value .split ("," )]
222227 except Exception as e :
223228 raise exceptions .Forbidden (* e .args )
224229
225230 if not check_custom_claims (claims = decoded_payload , claim_constraints = cfg .claims ):
226231 raise exceptions .Forbidden ("The configured claims constraints did not match. See logs for details." )
227232
228- return decoded_payload
233+ return decoded_payload , cfg
229234
230235
231236#############################
@@ -325,6 +330,13 @@ def __init__(self, name: str, section: str, config: configparser.SectionProxy) -
325330 self .keys : dict [str , bytes ] = {}
326331 self ._config : configparser .SectionProxy = config
327332 self .claims : list [ClaimMatch ] = []
333+
334+ self .jwt_username_claim : str = "sub"
335+ self .expire : int = 0
336+ self .sign : bool = False
337+ self .issuer : str = "https://localhost:8888/"
338+ self .audience : str
339+
328340 if "algorithm" not in config :
329341 raise ValueError ("algorithm is required in %s section" % self .section )
330342
@@ -344,8 +356,6 @@ def validate_generic(self) -> None:
344356 """
345357 if "sign" in self ._config :
346358 self .sign = config .is_bool (self ._config ["sign" ])
347- else :
348- self .sign = False
349359
350360 if "client_types" not in self ._config :
351361 raise ValueError ("client_types is a required options for %s" % self .section )
@@ -357,13 +367,9 @@ def validate_generic(self) -> None:
357367
358368 if "expire" in self ._config :
359369 self .expire = config .is_int (self ._config ["expire" ])
360- else :
361- self .expire = 0
362370
363371 if "issuer" in self ._config :
364372 self .issuer = config .is_str (self ._config ["issuer" ])
365- else :
366- self .issuer = "https://localhost:8888/"
367373
368374 if "audience" in self ._config :
369375 self .audience = config .is_str (self ._config ["audience" ])
@@ -373,6 +379,11 @@ def validate_generic(self) -> None:
373379 if "claims" in self ._config :
374380 self .parse_claim_matching (self ._config ["claims" ])
375381
382+ if "jwt-username-claim" in self ._config :
383+ if self .sign :
384+ raise ValueError (f"auth config { self .section } used for signing cannot use a custom claim." )
385+ self .jwt_username_claim = self ._config .get ("jwt-username-claim" )
386+
376387 def parse_claim_matching (self , claim_conf : str ) -> None :
377388 """Parse claim matching expressions"""
378389 items = re .findall (AUTH_JWT_CLAIM_RE , claim_conf , re .MULTILINE )
0 commit comments