@@ -80,6 +80,40 @@ class JWTParameters(BaseModel):
80
80
jwt_signing_key : str | None = Field (default = None , description = "Private key for JWT signing." )
81
81
jwt_lifetime_seconds : int = Field (default = 300 , description = "Lifetime of generated JWT in seconds." )
82
82
83
+ def to_assertion (self , with_audience_fallback : str | None = None ) -> str :
84
+ if self .assertion is not None :
85
+ # Prebuilt JWT (e.g. acquired out-of-band)
86
+ assertion = self .assertion
87
+ else :
88
+ if not self .jwt_signing_key :
89
+ raise OAuthFlowError ("Missing signing key for JWT bearer grant" )
90
+ if not self .issuer :
91
+ raise OAuthFlowError ("Missing issuer for JWT bearer grant" )
92
+ if not self .subject :
93
+ raise OAuthFlowError ("Missing subject for JWT bearer grant" )
94
+
95
+ audience = self .audience if self .audience else with_audience_fallback
96
+ if not audience :
97
+ raise OAuthFlowError ("Missing audience for JWT bearer grant" )
98
+
99
+ now = int (time .time ())
100
+ claims : dict [str , Any ] = {
101
+ "iss" : self .issuer ,
102
+ "sub" : self .subject ,
103
+ "aud" : audience ,
104
+ "exp" : now + self .jwt_lifetime_seconds ,
105
+ "iat" : now ,
106
+ "jti" : str (uuid4 ()),
107
+ }
108
+ claims .update (self .claims or {})
109
+
110
+ assertion = jwt .encode (
111
+ claims ,
112
+ self .jwt_signing_key ,
113
+ algorithm = self .jwt_signing_algorithm or "RS256" ,
114
+ )
115
+ return assertion
116
+
83
117
84
118
class TokenStorage (Protocol ):
85
119
"""Protocol for token storage implementations."""
@@ -111,7 +145,6 @@ class OAuthContext:
111
145
redirect_handler : Callable [[str ], Awaitable [None ]] | None
112
146
callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None
113
147
timeout : float = 300.0
114
- jwt_parameters : JWTParameters | None = None
115
148
116
149
# Discovered metadata
117
150
protected_resource_metadata : ProtectedResourceMetadata | None = None
@@ -213,7 +246,6 @@ def __init__(
213
246
redirect_handler : Callable [[str ], Awaitable [None ]] | None = None ,
214
247
callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None = None ,
215
248
timeout : float = 300.0 ,
216
- jwt_parameters : JWTParameters | None = None ,
217
249
):
218
250
"""Initialize OAuth2 authentication."""
219
251
self .context = OAuthContext (
@@ -223,7 +255,6 @@ def __init__(
223
255
redirect_handler = redirect_handler ,
224
256
callback_handler = callback_handler ,
225
257
timeout = timeout ,
226
- jwt_parameters = jwt_parameters ,
227
258
)
228
259
self ._initialized = False
229
260
@@ -334,16 +365,9 @@ async def _handle_registration_response(self, response: httpx.Response) -> None:
334
365
335
366
async def _perform_authorization (self ) -> httpx .Request :
336
367
"""Perform the authorization flow."""
337
- if "client_credentials" in self .context .client_metadata .grant_types :
338
- token_request = await self ._exchange_token_client_credentials ()
339
- return token_request
340
- elif "urn:ietf:params:oauth:grant-type:jwt-bearer" in self .context .client_metadata .grant_types :
341
- token_request = await self ._exchange_token_jwt_bearer ()
342
- return token_request
343
- else :
344
- auth_code , code_verifier = await self ._perform_authorization_code_grant ()
345
- token_request = await self ._exchange_token_authorization_code (auth_code , code_verifier )
346
- return token_request
368
+ auth_code , code_verifier = await self ._perform_authorization_code_grant ()
369
+ token_request = await self ._exchange_token_authorization_code (auth_code , code_verifier )
370
+ return token_request
347
371
348
372
async def _perform_authorization_code_grant (self ) -> tuple [str , str ]:
349
373
"""Perform the authorization redirect and get auth code."""
@@ -406,21 +430,25 @@ def _get_token_endpoint(self) -> str:
406
430
token_url = urljoin (auth_base_url , "/token" )
407
431
return token_url
408
432
409
- async def _exchange_token_authorization_code (self , auth_code : str , code_verifier : str ) -> httpx .Request :
433
+ async def _exchange_token_authorization_code (
434
+ self , auth_code : str , code_verifier : str , * , token_data : dict [str , Any ] = {}
435
+ ) -> httpx .Request :
410
436
"""Build token exchange request for authorization_code flow."""
411
437
if self .context .client_metadata .redirect_uris is None :
412
438
raise OAuthFlowError ("No redirect URIs provided for authorization code grant" )
413
439
if not self .context .client_info :
414
440
raise OAuthFlowError ("Missing client info" )
415
441
416
442
token_url = self ._get_token_endpoint ()
417
- token_data = {
418
- "grant_type" : "authorization_code" ,
419
- "code" : auth_code ,
420
- "redirect_uri" : str (self .context .client_metadata .redirect_uris [0 ]),
421
- "client_id" : self .context .client_info .client_id ,
422
- "code_verifier" : code_verifier ,
423
- }
443
+ token_data .update (
444
+ {
445
+ "grant_type" : "authorization_code" ,
446
+ "code" : auth_code ,
447
+ "redirect_uri" : str (self .context .client_metadata .redirect_uris [0 ]),
448
+ "client_id" : self .context .client_info .client_id ,
449
+ "code_verifier" : code_verifier ,
450
+ }
451
+ )
424
452
425
453
# Only include resource param if conditions are met
426
454
if self .context .should_include_resource_param (self .context .protocol_version ):
@@ -433,131 +461,6 @@ async def _exchange_token_authorization_code(self, auth_code: str, code_verifier
433
461
"POST" , token_url , data = token_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
434
462
)
435
463
436
- async def _exchange_token_client_credentials (self ) -> httpx .Request :
437
- """Build token exchange request for client_credentials flow."""
438
- if not self .context .client_info :
439
- raise OAuthFlowError ("Missing client info" )
440
-
441
- token_url = self ._get_token_endpoint ()
442
- token_data = {
443
- "grant_type" : "client_credentials" ,
444
- }
445
-
446
- headers = {"Content-Type" : "application/x-www-form-urlencoded" }
447
-
448
- # Only include resource param if conditions are met
449
- if self .context .should_include_resource_param (self .context .protocol_version ):
450
- token_data ["resource" ] = self .context .get_resource_url () # RFC 8707
451
-
452
- if self .context .client_metadata .scope :
453
- token_data ["scope" ] = self .context .client_metadata .scope
454
-
455
- if self .context .client_metadata .token_endpoint_auth_method == "client_secret_post" :
456
- # Include in request body
457
- if self .context .client_info .client_id :
458
- token_data ["client_id" ] = self .context .client_info .client_id
459
- if self .context .client_info .client_secret :
460
- token_data ["client_secret" ] = self .context .client_info .client_secret
461
- elif self .context .client_metadata .token_endpoint_auth_method == "client_secret_basic" :
462
- # Include as Basic auth header
463
- if not self .context .client_info .client_id :
464
- raise OAuthTokenError ("Missing client_id in Basic auth flow" )
465
- if not self .context .client_info .client_secret :
466
- raise OAuthTokenError ("Missing client_secret in Basic auth flow" )
467
- raw_auth = f"{ self .context .client_info .client_id } :{ self .context .client_info .client_secret } "
468
- headers ["Authorization" ] = f"Basic { base64 .b64encode (raw_auth .encode ()).decode ()} "
469
- elif self .context .client_metadata .token_endpoint_auth_method == "private_key_jwt" :
470
- # Use JWT assertion for client authentication
471
- if not self .context .jwt_parameters :
472
- raise OAuthTokenError ("Missing JWT parameters for private_key_jwt flow" )
473
-
474
- if self .context .jwt_parameters .assertion is not None :
475
- # Prebuilt JWT (e.g. acquired out-of-band)
476
- assertion = self .context .jwt_parameters .assertion
477
- else :
478
- if not self .context .jwt_parameters .jwt_signing_key :
479
- raise OAuthTokenError ("Missing JWT signing key for private_key_jwt flow" )
480
- if not self .context .jwt_parameters .jwt_signing_algorithm :
481
- raise OAuthTokenError ("Missing JWT signing algorithm for private_key_jwt flow" )
482
-
483
- now = int (time .time ())
484
- claims = {
485
- "iss" : self .context .jwt_parameters .issuer ,
486
- "sub" : self .context .jwt_parameters .subject ,
487
- "aud" : self .context .jwt_parameters .audience if self .context .jwt_parameters .audience else token_url ,
488
- "exp" : now + self .context .jwt_parameters .jwt_lifetime_seconds ,
489
- "iat" : now ,
490
- "jti" : str (uuid4 ()),
491
- }
492
- claims .update (self .context .jwt_parameters .claims or {})
493
-
494
- assertion = jwt .encode (
495
- claims ,
496
- self .context .jwt_parameters .jwt_signing_key ,
497
- algorithm = self .context .jwt_parameters .jwt_signing_algorithm or "RS256" ,
498
- )
499
-
500
- # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
501
- token_data ["client_assertion" ] = assertion
502
- token_data ["client_assertion_type" ] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
503
- # We need to set the audience to the token endpoint, the audience is difference from the one in claims
504
- # it represents the resource server that will validate the token
505
- token_data ["audience" ] = self .context .get_resource_url ()
506
-
507
- return httpx .Request ("POST" , token_url , data = token_data , headers = headers )
508
-
509
- async def _exchange_token_jwt_bearer (self ) -> httpx .Request :
510
- """Build token exchange request for JWT bearer grant."""
511
- if not self .context .client_info :
512
- raise OAuthFlowError ("Missing client info" )
513
- if not self .context .jwt_parameters :
514
- raise OAuthFlowError ("Missing JWT parameters" )
515
-
516
- token_url = self ._get_token_endpoint ()
517
-
518
- if self .context .jwt_parameters .assertion is not None :
519
- # Prebuilt JWT (e.g. acquired out-of-band)
520
- assertion = self .context .jwt_parameters .assertion
521
- else :
522
- if not self .context .jwt_parameters .jwt_signing_key :
523
- raise OAuthFlowError ("Missing signing key for JWT bearer grant" )
524
- if not self .context .jwt_parameters .issuer :
525
- raise OAuthFlowError ("Missing issuer for JWT bearer grant" )
526
- if not self .context .jwt_parameters .subject :
527
- raise OAuthFlowError ("Missing subject for JWT bearer grant" )
528
-
529
- now = int (time .time ())
530
- claims = {
531
- "iss" : self .context .jwt_parameters .issuer ,
532
- "sub" : self .context .jwt_parameters .subject ,
533
- "aud" : token_url ,
534
- "exp" : now + self .context .jwt_parameters .jwt_lifetime_seconds ,
535
- "iat" : now ,
536
- "jti" : str (uuid4 ()),
537
- }
538
- claims .update (self .context .jwt_parameters .claims or {})
539
-
540
- assertion = jwt .encode (
541
- claims ,
542
- self .context .jwt_parameters .jwt_signing_key ,
543
- algorithm = self .context .jwt_parameters .jwt_signing_algorithm or "RS256" ,
544
- )
545
-
546
- token_data = {
547
- "grant_type" : "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
548
- "assertion" : assertion ,
549
- }
550
-
551
- if self .context .should_include_resource_param (self .context .protocol_version ):
552
- token_data ["resource" ] = self .context .get_resource_url ()
553
-
554
- if self .context .client_metadata .scope :
555
- token_data ["scope" ] = self .context .client_metadata .scope
556
-
557
- return httpx .Request (
558
- "POST" , token_url , data = token_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
559
- )
560
-
561
464
async def _handle_token_response (self , response : httpx .Response ) -> None :
562
465
"""Handle token exchange response."""
563
466
if response .status_code != 200 :
@@ -720,3 +623,78 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
720
623
# Retry with new tokens
721
624
self ._add_auth_header (request )
722
625
yield request
626
+
627
+
628
+ class RFC7523OAuthClientProvider (OAuthClientProvider ):
629
+ """OAuth client provider for RFC7532 clients."""
630
+
631
+ jwt_parameters : JWTParameters | None = None
632
+
633
+ def __init__ (
634
+ self ,
635
+ server_url : str ,
636
+ client_metadata : OAuthClientMetadata ,
637
+ storage : TokenStorage ,
638
+ redirect_handler : Callable [[str ], Awaitable [None ]] | None = None ,
639
+ callback_handler : Callable [[], Awaitable [tuple [str , str | None ]]] | None = None ,
640
+ timeout : float = 300.0 ,
641
+ jwt_parameters : JWTParameters | None = None ,
642
+ ) -> None :
643
+ super ().__init__ (server_url , client_metadata , storage , redirect_handler , callback_handler , timeout )
644
+ self .jwt_parameters = jwt_parameters
645
+
646
+ async def _exchange_token_authorization_code (
647
+ self , auth_code : str , code_verifier : str , * , token_data : dict [str , Any ] = {}
648
+ ) -> httpx .Request :
649
+ """Build token exchange request for authorization_code flow."""
650
+ if self .context .client_metadata .token_endpoint_auth_method == "private_key_jwt" :
651
+ self ._add_client_authentication_jwt (token_data = token_data )
652
+ return await super ()._exchange_token_authorization_code (auth_code , code_verifier , token_data = token_data )
653
+
654
+ async def _perform_authorization (self ) -> httpx .Request :
655
+ """Perform the authorization flow."""
656
+ if "urn:ietf:params:oauth:grant-type:jwt-bearer" in self .context .client_metadata .grant_types :
657
+ token_request = await self ._exchange_token_jwt_bearer ()
658
+ return token_request
659
+ else :
660
+ return await super ()._perform_authorization ()
661
+
662
+ def _add_client_authentication_jwt (self , * , token_data : dict [str , Any ]):
663
+ """Add JWT assertion for client authentication to token endpoint parameters."""
664
+ if not self .jwt_parameters :
665
+ raise OAuthTokenError ("Missing JWT parameters for private_key_jwt flow" )
666
+
667
+ token_url = self ._get_token_endpoint ()
668
+ assertion = self .jwt_parameters .to_assertion (with_audience_fallback = token_url )
669
+
670
+ # When using private_key_jwt, in a client_credentials flow, we use RFC 7523 Section 2.2
671
+ token_data ["client_assertion" ] = assertion
672
+ token_data ["client_assertion_type" ] = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer"
673
+ # We need to set the audience to the token endpoint, the audience is difference from the one in claims
674
+ # it represents the resource server that will validate the token
675
+ token_data ["audience" ] = self .context .get_resource_url ()
676
+
677
+ async def _exchange_token_jwt_bearer (self ) -> httpx .Request :
678
+ """Build token exchange request for JWT bearer grant."""
679
+ if not self .context .client_info :
680
+ raise OAuthFlowError ("Missing client info" )
681
+ if not self .jwt_parameters :
682
+ raise OAuthFlowError ("Missing JWT parameters" )
683
+
684
+ token_url = self ._get_token_endpoint ()
685
+ assertion = self .jwt_parameters .to_assertion (with_audience_fallback = token_url )
686
+
687
+ token_data = {
688
+ "grant_type" : "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
689
+ "assertion" : assertion ,
690
+ }
691
+
692
+ if self .context .should_include_resource_param (self .context .protocol_version ):
693
+ token_data ["resource" ] = self .context .get_resource_url ()
694
+
695
+ if self .context .client_metadata .scope :
696
+ token_data ["scope" ] = self .context .client_metadata .scope
697
+
698
+ return httpx .Request (
699
+ "POST" , token_url , data = token_data , headers = {"Content-Type" : "application/x-www-form-urlencoded" }
700
+ )
0 commit comments