@@ -111,6 +111,8 @@ class OAuthContext:
111
111
# Discovery state for fallback support
112
112
discovery_base_url : str | None = None
113
113
discovery_pathname : str | None = None
114
+ # Optional expected issuer for access tokens (JWT iss claim)
115
+ expected_issuer : str | None = None
114
116
115
117
def get_authorization_base_url (self , server_url : str ) -> str :
116
118
"""Extract base URL by removing path component."""
@@ -126,12 +128,64 @@ def update_token_expiry(self, token: OAuthToken) -> None:
126
128
127
129
def is_token_valid (self ) -> bool :
128
130
"""Check if current token is valid."""
129
- return bool (
131
+ # Basic existence and expiry checks
132
+ basic_valid = bool (
130
133
self .current_tokens
131
134
and self .current_tokens .access_token
132
135
and (not self .token_expiry_time or time .time () <= self .token_expiry_time )
133
136
)
134
137
138
+ if not basic_valid :
139
+ return False
140
+
141
+ # If no expected issuer is configured, behave as before
142
+ if not getattr (self , "expected_issuer" , None ):
143
+ return True
144
+
145
+ # If expected_issuer is set, ensure token issuer matches
146
+ try :
147
+ return self ._token_issuer_matches (self .current_tokens .access_token )
148
+ except Exception :
149
+ # On any parsing issue, treat token as invalid
150
+ logger .exception ("Failed to validate token issuer" )
151
+ return False
152
+
153
+ def _token_issuer_matches (self , token : str ) -> bool :
154
+ """Decode a JWT access token (no signature verification) and compare its 'iss' claim.
155
+
156
+ This performs a safe, minimal check: split the token, base64-decode the payload,
157
+ parse JSON, and compare the 'iss' field to self.expected_issuer. Returns False
158
+ if the token is malformed or the claim is missing/mismatched.
159
+ """
160
+ # JWTs are in the form header.payload.signature
161
+ parts = token .split ("." )
162
+ if len (parts ) < 2 :
163
+ return False
164
+
165
+ payload_b64 = parts [1 ]
166
+
167
+ # Add padding for base64 if necessary
168
+ padding = "=" * (- len (payload_b64 ) % 4 )
169
+ payload_b64 += padding
170
+
171
+ try :
172
+ payload_bytes = base64 .urlsafe_b64decode (payload_b64 .encode ())
173
+ except Exception :
174
+ return False
175
+
176
+ try :
177
+ import json
178
+
179
+ payload = json .loads (payload_bytes )
180
+ except Exception :
181
+ return False
182
+
183
+ iss = payload .get ("iss" )
184
+ if not iss :
185
+ return False
186
+
187
+ return iss == self .expected_issuer
188
+
135
189
def can_refresh_token (self ) -> bool :
136
190
"""Check if token can be refreshed."""
137
191
return bool (self .current_tokens and self .current_tokens .refresh_token and self .client_info )
0 commit comments