Skip to content

Commit 2c251ea

Browse files
committed
Passing token credential as an argument rather than 2 strings
Signed-off-by: Ryan Lettieri <[email protected]>
1 parent ba1ac4f commit 2c251ea

File tree

5 files changed

+32
-52
lines changed

5 files changed

+32
-52
lines changed

durabletask-azuremanaged/durabletask/azuremanaged/client.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from durabletask.client import TaskHubGrpcClient, OrchestrationStatus
66
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
77
from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl
8-
from azure.identity import DefaultAzureCredential
8+
from azure.core.credentials import TokenCredential
99

1010
# Client class used for Durable Task Scheduler (DTS)
1111
class DurableTaskSchedulerClient(TaskHubGrpcClient):
@@ -14,8 +14,7 @@ def __init__(self, *,
1414
taskhub: str,
1515
secure_channel: Optional[bool] = True,
1616
metadata: Optional[list[tuple[str, str]]] = None,
17-
use_managed_identity: Optional[bool] = False,
18-
client_id: Optional[str] = None):
17+
token_credential: Optional[TokenCredential] = None):
1918

2019
if taskhub == None:
2120
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
@@ -27,14 +26,7 @@ def __init__(self, *,
2726
# Append DurableTask-specific metadata
2827
self._metadata.append(("taskhub", taskhub))
2928
self._metadata.append(("dts", "True"))
30-
self._metadata.append(("use_managed_identity", str(use_managed_identity)))
31-
self._metadata.append(("client_id", str(client_id or "None")))
32-
33-
self._access_token_manager = AccessTokenManager(use_managed_identity=use_managed_identity,
34-
client_id=client_id)
35-
token = self._access_token_manager.get_access_token()
36-
self._metadata.append(("authorization", token))
37-
29+
self._metadata.append(("token_credential", token_credential))
3830
self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)]
3931

4032
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class

durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,18 @@ class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
1313

1414
def __init__(self, metadata: list[tuple[str, str]]):
1515
super().__init__(metadata)
16-
17-
use_managed_identity = False
18-
client_id = None
16+
17+
self._token_credential = None
1918

2019
# Check what authentication we are using
2120
if metadata:
2221
for key, value in metadata:
23-
if key.lower() == "use_managed_identity":
24-
self.use_managed_identity = value.strip().lower() == "true" # Convert to boolean
25-
elif key.lower() == "client_id":
26-
self.client_id = value
22+
if key.lower() == "token_credential":
23+
self._token_credential = value
2724

28-
self._token_manager = AccessTokenManager(use_managed_identity=use_managed_identity,
29-
client_id=client_id)
25+
self._token_manager = AccessTokenManager(token_credential=self._token_credential)
26+
token = self._token_manager.get_access_token()
27+
self._metadata.append(("authorization", token))
3028

3129
def _intercept_call(
3230
self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails:

durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,47 +4,43 @@
44
from datetime import datetime, timedelta, timezone
55
from typing import Optional
66
import durabletask.internal.shared as shared
7+
from azure.core.credentials import TokenCredential
78

89
# By default, when there's 10minutes left before the token expires, refresh the token
910
class AccessTokenManager:
10-
def __init__(self, refresh_interval_seconds: int = 600, use_managed_identity: bool = False, client_id: str = None):
11-
self.scope = "https://durabletask.io/.default"
12-
self.refresh_interval_seconds = refresh_interval_seconds
13-
self._use_managed_identity = use_managed_identity
14-
self._client_id = client_id
11+
def __init__(self, refresh_interval_seconds: int = 600, token_credential: TokenCredential = None):
12+
self._scope = "https://durabletask.io/.default"
13+
self._refresh_interval_seconds = refresh_interval_seconds
1514
self._logger = shared.get_logger("token_manager")
1615

17-
# Choose the appropriate credential based on use_managed_identity
18-
if self._use_managed_identity:
19-
if not self._client_id:
20-
self._logger.debug("Using System Assigned Managed Identity for authentication.")
21-
self.credential = ManagedIdentityCredential()
22-
else:
23-
self._logger.debug("Using User Assigned Managed Identity for authentication.")
24-
self.credential = ManagedIdentityCredential(client_id=self._client_id)
16+
# Choose the appropriate credential.
17+
# Both TokenCredential and DefaultAzureCredential get_token methods return an AccessToken
18+
if token_credential:
19+
self._logger.debug("Using user provided token credentials.")
20+
self._credential = token_credential
2521
else:
26-
self.credential = DefaultAzureCredential()
22+
self._credential = DefaultAzureCredential()
2723
self._logger.debug("Using Default Azure Credentials for authentication.")
2824

29-
self.token = None
25+
self._token = self._credential.get_token(self._scope)
3026
self.expiry_time = None
3127

3228
def get_access_token(self) -> str:
33-
if self.token is None or self.is_token_expired():
29+
if self._token is None or self.is_token_expired():
3430
self.refresh_token()
35-
return self.token
31+
return self._token
3632

3733
# Checks if the token is expired, or if it will expire in the next "refresh_interval_seconds" seconds.
3834
# For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes,
3935
# We will grab a new token when there're 30minutes left on the lifespan of the token
4036
def is_token_expired(self) -> bool:
4137
if self.expiry_time is None:
4238
return True
43-
return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_interval_seconds))
39+
return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self._refresh_interval_seconds))
4440

4541
def refresh_token(self):
46-
new_token = self.credential.get_token(self.scope)
47-
self.token = f"Bearer {new_token.token}"
42+
new_token = self._credential.get_token(self._scope)
43+
self._token = f"Bearer {new_token.token}"
4844

4945
# Convert UNIX timestamp to timezone-aware datetime
5046
self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc)

durabletask-azuremanaged/durabletask/azuremanaged/worker.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55
from durabletask.worker import TaskHubGrpcWorker
66
from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager
77
from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl
8+
from azure.core.credentials import TokenCredential
89

910
# Worker class used for Durable Task Scheduler (DTS)
1011
class DurableTaskSchedulerWorker(TaskHubGrpcWorker):
1112
def __init__(self, *,
1213
host_address: str,
1314
taskhub: str,
14-
secure_channel: bool,
15+
secure_channel: Optional[bool] = True,
1516
metadata: Optional[list[tuple[str, str]]] = None,
16-
use_managed_identity: Optional[bool] = False,
17-
client_id: Optional[str] = None):
17+
token_credential: Optional[TokenCredential] = None):
1818

1919
if taskhub == None:
2020
raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub")
@@ -24,15 +24,9 @@ def __init__(self, *,
2424
self._metadata = metadata.copy() # Copy to prevent modifying input
2525

2626
# Append DurableTask-specific metadata
27-
self._metadata.append(("taskhub", taskhub or "default-taskhub"))
27+
self._metadata.append(("taskhub", taskhub))
2828
self._metadata.append(("dts", "True"))
29-
self._metadata.append(("use_managed_identity", str(use_managed_identity)))
30-
self._metadata.append(("client_id", str(client_id or "None")))
31-
32-
self._access_token_manager = AccessTokenManager(use_managed_identity=use_managed_identity,
33-
client_id=client_id)
34-
token = self._access_token_manager.get_access_token()
35-
self._metadata.append(("authorization", token))
29+
self._metadata.append(("token_credential", token_credential))
3630
interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)]
3731

3832
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class

examples/dts/dts_activity_sequence.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def sequence(ctx: task.OrchestrationContext, _):
4747

4848

4949
# configure and start the worker
50-
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, use_managed_identity=False, client_id="", taskhub=taskhub_name) as w:
50+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) as w:
5151
w.add_orchestrator(sequence)
5252
w.add_activity(hello)
5353
w.start()

0 commit comments

Comments
 (0)