1- # Copyright (c) Microsoft Corporation.
2- # Licensed under the MIT License.
3- from azure .identity import DefaultAzureCredential , ManagedIdentityCredential
4- from datetime import datetime , timedelta , timezone
5- from typing import Optional
6- import durabletask .internal .shared as shared
7-
8- # By default, when there's 10minutes left before the token expires, refresh the token
9- class AccessTokenManager :
10- def __init__ (self , refresh_buffer : int = 600 , metadata : Optional [list [tuple [str , str ]]] = None ):
11- self .scope = "https://durabletask.io/.default"
12- self .refresh_buffer = refresh_buffer
13- self ._use_managed_identity = False
14- self ._metadata = metadata
15- self ._client_id = None
16- self ._logger = shared .get_logger ("token_manager" )
17-
18- if metadata : # Ensure metadata is not None
19- for key , value in metadata :
20- if key == "use_managed_identity" :
21- self ._use_managed_identity = value .lower () == "true" # Properly convert string to bool
22- elif key == "client_id" :
23- self ._client_id = value # Directly assign string
24-
25- # Choose the appropriate credential based on use_managed_identity
26- if self ._use_managed_identity :
27- if not self ._client_id :
28- self ._logger .debug ("Using System Assigned Managed Identity for authentication." )
29- self .credential = ManagedIdentityCredential ()
30- else :
31- self ._logger .debug ("Using User Assigned Managed Identity for authentication." )
32- self .credential = ManagedIdentityCredential (client_id = self ._client_id )
33- else :
34- self .credential = DefaultAzureCredential ()
35- self ._logger .debug ("Using Default Azure Credentials for authentication." )
36-
37- self .token = None
38- self .expiry_time = None
39-
40- def get_access_token (self ) -> str :
41- if self .token is None or self .is_token_expired ():
42- self .refresh_token ()
43- return self .token
44-
45- # Checks if the token is expired, or if it will expire in the next "refresh_buffer" seconds.
46- # For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes,
47- # We will grab a new token when there're 30minutes left on the lifespan of the token
48- def is_token_expired (self ) -> bool :
49- if self .expiry_time is None :
50- return True
51- return datetime .now (timezone .utc ) >= (self .expiry_time - timedelta (seconds = self .refresh_buffer ))
52-
53- def refresh_token (self ):
54- new_token = self .credential .get_token (self .scope )
55- self .token = f"Bearer { new_token .token } "
56-
57- # Convert UNIX timestamp to timezone-aware datetime
58- self .expiry_time = datetime .fromtimestamp (new_token .expires_on , tz = timezone .utc )
1+ # Copyright (c) Microsoft Corporation.
2+ # Licensed under the MIT License.
3+ from azure .identity import DefaultAzureCredential , ManagedIdentityCredential
4+ from datetime import datetime , timedelta , timezone
5+ from typing import Optional
6+ import durabletask .internal .shared as shared
7+
8+ # By default, when there's 10minutes left before the token expires, refresh the token
9+ class AccessTokenManager :
10+ def __init__ (self , refresh_buffer : int = 600 , metadata : Optional [list [tuple [str , str ]]] = None ):
11+ self .scope = "https://durabletask.io/.default"
12+ self .refresh_buffer = refresh_buffer
13+ self ._use_managed_identity = False
14+ self ._metadata = metadata
15+ self ._client_id = None
16+ self ._logger = shared .get_logger ("token_manager" )
17+
18+ if metadata : # Ensure metadata is not None
19+ for key , value in metadata :
20+ if key == "use_managed_identity" :
21+ self ._use_managed_identity = value .lower () == "true" # Properly convert string to bool
22+ elif key == "client_id" :
23+ self ._client_id = value # Directly assign string
24+
25+ # Choose the appropriate credential based on use_managed_identity
26+ if self ._use_managed_identity :
27+ if not self ._client_id :
28+ self ._logger .debug ("Using System Assigned Managed Identity for authentication." )
29+ self .credential = ManagedIdentityCredential ()
30+ else :
31+ self ._logger .debug ("Using User Assigned Managed Identity for authentication." )
32+ self .credential = ManagedIdentityCredential (client_id = self ._client_id )
33+ else :
34+ self .credential = DefaultAzureCredential ()
35+ self ._logger .debug ("Using Default Azure Credentials for authentication." )
36+
37+ self .token = None
38+ self .expiry_time = None
39+
40+ def get_access_token (self ) -> str :
41+ if self .token is None or self .is_token_expired ():
42+ self .refresh_token ()
43+ return self .token
44+
45+ # Checks if the token is expired, or if it will expire in the next "refresh_buffer" seconds.
46+ # For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes,
47+ # We will grab a new token when there're 30minutes left on the lifespan of the token
48+ def is_token_expired (self ) -> bool :
49+ if self .expiry_time is None :
50+ return True
51+ return datetime .now (timezone .utc ) >= (self .expiry_time - timedelta (seconds = self .refresh_buffer ))
52+
53+ def refresh_token (self ):
54+ new_token = self .credential .get_token (self .scope )
55+ self .token = f"Bearer { new_token .token } "
56+
57+ # Convert UNIX timestamp to timezone-aware datetime
58+ self .expiry_time = datetime .fromtimestamp (new_token .expires_on , tz = timezone .utc )
5959 self ._logger .debug (f"Token refreshed. Expires at: { self .expiry_time } " )
0 commit comments