Skip to content

Commit f731c0d

Browse files
committed
Adding accessTokenManager class for refreshing credential token
Signed-off-by: Ryan Lettieri <[email protected]>
1 parent 6df1064 commit f731c0d

File tree

4 files changed

+65
-9
lines changed

4 files changed

+65
-9
lines changed

durabletask/accessTokenManager.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from azure.identity import DefaultAzureCredential
5+
from datetime import datetime, timedelta
6+
7+
class AccessTokenManager:
8+
def __init__(self, scope: str, refresh_buffer: int = 60):
9+
self.scope = scope
10+
self.refresh_buffer = refresh_buffer
11+
self.credential = DefaultAzureCredential()
12+
self.token = None
13+
self.expiry_time = None
14+
15+
def get_access_token(self) -> str:
16+
if self.token is None or self.is_token_expired():
17+
self.refresh_token()
18+
return self.token
19+
20+
def is_token_expired(self) -> bool:
21+
if self.expiry_time is None:
22+
return True
23+
return datetime.utcnow() >= (self.expiry_time - timedelta(seconds=self.refresh_buffer))
24+
25+
def refresh_token(self):
26+
new_token = self.credential.get_token(self.scope)
27+
self.token = f"Bearer {new_token.token}"
28+
self.expiry_time = datetime.utcnow() + timedelta(seconds=new_token.expires_on - int(datetime.utcnow().timestamp()))
29+
print(f"Token refreshed. Expires at: {self.expiry_time}")

durabletask/worker.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import durabletask.internal.orchestrator_service_pb2 as pb
1717
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
1818
import durabletask.internal.shared as shared
19+
from durabletask.accessTokenManager import AccessTokenManager
1920
from durabletask import task
2021

2122
TInput = TypeVar('TInput')
@@ -88,14 +89,36 @@ def __init__(self, *,
8889
metadata: Optional[list[tuple[str, str]]] = None,
8990
log_handler=None,
9091
log_formatter: Optional[logging.Formatter] = None,
91-
secure_channel: bool = False):
92+
secure_channel: bool = False,
93+
access_token_manager: AccessTokenManager = None):
9294
self._registry = _Registry()
9395
self._host_address = host_address if host_address else shared.get_default_host_address()
9496
self._metadata = metadata
9597
self._logger = shared.get_logger("worker", log_handler, log_formatter)
9698
self._shutdown = Event()
9799
self._is_running = False
98100
self._secure_channel = secure_channel
101+
self._access_token_manager = access_token_manager
102+
self.__update_metadata_with_token()
103+
104+
def __update_metadata_with_token(self):
105+
"""
106+
Add or update the `authorization` key in the metadata with the current access token.
107+
"""
108+
if self._access_token_manager is not None:
109+
token = self._access_token_manager.get_access_token()
110+
111+
# Check if "authorization" already exists in the metadata
112+
updated = False
113+
for i, (key, _) in enumerate(self._metadata):
114+
if key == "authorization":
115+
self._metadata[i] = ("authorization", token)
116+
updated = True
117+
break
118+
119+
# If not updated, add a new entry
120+
if not updated:
121+
self._metadata.append(("authorization", token))
99122

100123
def __enter__(self):
101124
return self
@@ -130,6 +153,7 @@ def run_loop():
130153
with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor:
131154
while not self._shutdown.is_set():
132155
try:
156+
self.__update_metadata_with_token()
133157
# send a "Hello" message to the sidecar to ensure that it's listening
134158
stub.Hello(empty_pb2.Empty())
135159

examples/dts/README.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ export TASKHUB=<taskhubname>
2626
export ENDPOINT=<taskhubEndpoint>
2727
```
2828

29+
5. Since the samples rely on azure identity, ensure the package is installed and up-to-date
30+
31+
```sh
32+
python3 -m pip install azure-identity
33+
```
34+
2935
## Running the examples
3036

3137
With one of the sidecars running, you can simply execute any of the examples in this directory using `python3`:

examples/dts/dts_activity_sequence.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""End-to-end sample that demonstrates how to configure an orchestrator
55
that calls an activity function in a sequence and prints the outputs."""
66
from durabletask import client, task, worker
7-
7+
from durabletask.accessTokenManager import AccessTokenManager
88

99
def hello(ctx: task.ActivityContext, name: str) -> str:
1010
"""Activity function that returns a greeting"""
@@ -47,19 +47,16 @@ def sequence(ctx: task.OrchestrationContext, _):
4747
exit()
4848

4949

50-
default_credential = DefaultAzureCredential()
5150
# Define the scope for Azure Resource Manager (ARM)
5251
arm_scope = "https://durabletask.io/.default"
53-
54-
# Retrieve the access token. Note that this approach doesn't support token refresh and can't be used in production.
55-
access_token = "Bearer " + default_credential.get_token(arm_scope).token
52+
token_manager = AccessTokenManager(scope = arm_scope)
5653

5754
metaData: list[tuple[str, str]] = [
58-
("taskhub", taskhub_name),
59-
("authorization", access_token)
55+
("taskhub", taskhub_name)
6056
]
57+
6158
# configure and start the worker
62-
with worker.TaskHubGrpcWorker(host_address=endpoint, metadata=metaData, secure_channel=True) as w:
59+
with worker.TaskHubGrpcWorker(host_address=endpoint, metadata=metaData, secure_channel=True, access_token_manager=token_manager) as w:
6360
w.add_orchestrator(sequence)
6461
w.add_activity(hello)
6562
w.start()

0 commit comments

Comments
 (0)