15
15
"""MONGODB-OIDC Authentication helpers."""
16
16
from __future__ import annotations
17
17
18
- import threading
18
+ import asyncio
19
19
import time
20
20
from dataclasses import dataclass , field
21
21
from typing import TYPE_CHECKING , Any , Mapping , MutableMapping , Optional , Union
36
36
)
37
37
from pymongo .errors import ConfigurationError , OperationFailure
38
38
from pymongo .helpers_shared import _AUTHENTICATION_FAILURE_CODE
39
+ from pymongo .lock import Lock , _async_create_lock
39
40
40
41
if TYPE_CHECKING :
41
42
from pymongo .asynchronous .pool import AsyncConnection
@@ -81,7 +82,7 @@ class _OIDCAuthenticator:
81
82
access_token : Optional [str ] = field (default = None )
82
83
idp_info : Optional [OIDCIdPInfo ] = field (default = None )
83
84
token_gen_id : int = field (default = 0 )
84
- lock : threading . Lock = field (default_factory = threading . Lock )
85
+ lock : Lock = field (default_factory = _async_create_lock )
85
86
last_call_time : float = field (default = 0 )
86
87
87
88
async def reauthenticate (self , conn : AsyncConnection ) -> Optional [Mapping [str , Any ]]:
@@ -164,7 +165,7 @@ async def _authenticate_human(self, conn: AsyncConnection) -> Optional[Mapping[s
164
165
# Attempt to authenticate with a JwtStepRequest.
165
166
return await self ._sasl_continue_jwt (conn , start_resp )
166
167
167
- def _get_access_token (self ) -> Optional [str ]:
168
+ async def _get_access_token (self ) -> Optional [str ]:
168
169
properties = self .properties
169
170
cb : Union [None , OIDCCallback ]
170
171
resp : OIDCCallbackResult
@@ -186,7 +187,7 @@ def _get_access_token(self) -> Optional[str]:
186
187
return None
187
188
188
189
if not prev_token and cb is not None :
189
- with self .lock :
190
+ async with self .lock :
190
191
# See if the token was changed while we were waiting for the
191
192
# lock.
192
193
new_token = self .access_token
@@ -196,7 +197,7 @@ def _get_access_token(self) -> Optional[str]:
196
197
# Ensure that we are waiting a min time between callback invocations.
197
198
delta = time .time () - self .last_call_time
198
199
if delta < TIME_BETWEEN_CALLS_SECONDS :
199
- time .sleep (TIME_BETWEEN_CALLS_SECONDS - delta )
200
+ await asyncio .sleep (TIME_BETWEEN_CALLS_SECONDS - delta )
200
201
self .last_call_time = time .time ()
201
202
202
203
if is_human :
@@ -211,7 +212,10 @@ def _get_access_token(self) -> Optional[str]:
211
212
idp_info = self .idp_info ,
212
213
username = self .properties .username ,
213
214
)
214
- resp = cb .fetch (context )
215
+ if not _IS_SYNC :
216
+ resp = await asyncio .get_running_loop ().run_in_executor (None , cb .fetch , context )
217
+ else :
218
+ resp = cb .fetch (context )
215
219
if not isinstance (resp , OIDCCallbackResult ):
216
220
raise ValueError (
217
221
f"Callback result must be of type OIDCCallbackResult, not { type (resp )} "
@@ -253,13 +257,13 @@ async def _sasl_continue_jwt(
253
257
start_payload : dict = bson .decode (start_resp ["payload" ])
254
258
if "issuer" in start_payload :
255
259
self .idp_info = OIDCIdPInfo (** start_payload )
256
- access_token = self ._get_access_token ()
260
+ access_token = await self ._get_access_token ()
257
261
conn .oidc_token_gen_id = self .token_gen_id
258
262
cmd = self ._get_continue_command ({"jwt" : access_token }, start_resp )
259
263
return await self ._run_command (conn , cmd )
260
264
261
265
async def _sasl_start_jwt (self , conn : AsyncConnection ) -> Mapping [str , Any ]:
262
- access_token = self ._get_access_token ()
266
+ access_token = await self ._get_access_token ()
263
267
conn .oidc_token_gen_id = self .token_gen_id
264
268
cmd = self ._get_start_command ({"jwt" : access_token })
265
269
return await self ._run_command (conn , cmd )
0 commit comments