Skip to content

Commit ea837d0

Browse files
committed
Updating dts client to refresh token
Signed-off-by: Ryan Lettieri <[email protected]>
1 parent f4f98ee commit ea837d0

File tree

5 files changed

+167
-10
lines changed

5 files changed

+167
-10
lines changed

durabletask/client.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(self, *,
9797
log_handler: Optional[logging.Handler] = None,
9898
log_formatter: Optional[logging.Formatter] = None,
9999
secure_channel: bool = False):
100+
self._metadata = metadata
100101
channel = shared.get_grpc_channel(host_address, metadata, secure_channel=secure_channel)
101102
self._stub = stubs.TaskHubSidecarServiceStub(channel)
102103
self._logger = shared.get_logger("client", log_handler, log_formatter)

examples/dts/dts_activity_sequence.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
import os
2-
from azure.identity import DefaultAzureCredential
3-
41
"""End-to-end sample that demonstrates how to configure an orchestrator
52
that calls an activity function in a sequence and prints the outputs."""
3+
import os
64
from durabletask import client, task
75
from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker
6+
from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient
87
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
98

109
def hello(ctx: task.ActivityContext, name: str) -> str:
@@ -52,18 +51,18 @@ def sequence(ctx: task.OrchestrationContext, _):
5251
arm_scope = "https://durabletask.io/.default"
5352
token_manager = AccessTokenManager(scope = arm_scope)
5453

55-
metaData: list[tuple[str, str]] = [
54+
meta_data: list[tuple[str, str]] = [
5655
("taskhub", taskhub_name)
5756
]
5857

5958
# configure and start the worker
60-
with DurableTaskSchedulerWorker(host_address=endpoint, metadata=metaData, secure_channel=True, access_token_manager=token_manager) as w:
59+
with DurableTaskSchedulerWorker(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) as w:
6160
w.add_orchestrator(sequence)
6261
w.add_activity(hello)
6362
w.start()
6463

6564
# Construct the client and run the orchestrations
66-
c = client.TaskHubGrpcClient(host_address=endpoint, metadata=metaData, secure_channel=True)
65+
c = DurableTaskSchedulerClient(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager)
6766
instance_id = c.schedule_new_orchestration(sequence)
6867
state = c.wait_for_orchestration_completion(instance_id, timeout=45)
6968
if state and state.runtime_status == client.OrchestrationStatus.COMPLETED:

examples/dts/dts_fanout_fanin.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""End-to-end sample that demonstrates how to configure an orchestrator
2+
that a dynamic number activity functions in parallel, waits for them all
3+
to complete, and prints an aggregate summary of the outputs."""
4+
import random
5+
import time
6+
import os
7+
from durabletask import client, task
8+
from durabletask import client, task
9+
from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker
10+
from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient
11+
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
12+
13+
14+
def get_work_items(ctx: task.ActivityContext, _) -> list[str]:
15+
"""Activity function that returns a list of work items"""
16+
# return a random number of work items
17+
count = random.randint(2, 10)
18+
print(f'generating {count} work items...')
19+
return [f'work item {i}' for i in range(count)]
20+
21+
22+
def process_work_item(ctx: task.ActivityContext, item: str) -> int:
23+
"""Activity function that returns a result for a given work item"""
24+
print(f'processing work item: {item}')
25+
26+
# simulate some work that takes a variable amount of time
27+
time.sleep(random.random() * 5)
28+
29+
# return a result for the given work item, which is also a random number in this case
30+
return random.randint(0, 10)
31+
32+
33+
def orchestrator(ctx: task.OrchestrationContext, _):
34+
"""Orchestrator function that calls the 'get_work_items' and 'process_work_item'
35+
activity functions in parallel, waits for them all to complete, and prints
36+
an aggregate summary of the outputs"""
37+
38+
work_items: list[str] = yield ctx.call_activity(get_work_items)
39+
40+
# execute the work-items in parallel and wait for them all to return
41+
tasks = [ctx.call_activity(process_work_item, input=item) for item in work_items]
42+
results: list[int] = yield task.when_all(tasks)
43+
44+
# return an aggregate summary of the results
45+
return {
46+
'work_items': work_items,
47+
'results': results,
48+
'total': sum(results),
49+
}
50+
51+
52+
# Read the environment variable
53+
taskhub_name = os.getenv("TASKHUB")
54+
55+
# Check if the variable exists
56+
if taskhub_name:
57+
print(f"The value of TASKHUB is: {taskhub_name}")
58+
else:
59+
print("TASKHUB is not set. Please set the TASKHUB environment variable to the name of the taskhub you wish to use")
60+
print("If you are using windows powershell, run the following: $env:TASKHUB=\"<taskhubname>\"")
61+
print("If you are using bash, run the following: export TASKHUB=\"<taskhubname>\"")
62+
exit()
63+
64+
# Read the environment variable
65+
endpoint = os.getenv("ENDPOINT")
66+
67+
# Check if the variable exists
68+
if endpoint:
69+
print(f"The value of ENDPOINT is: {endpoint}")
70+
else:
71+
print("ENDPOINT is not set. Please set the ENDPOINT environment variable to the endpoint of the scheduler")
72+
print("If you are using windows powershell, run the following: $env:ENDPOINT=\"<schedulerEndpoint>\"")
73+
print("If you are using bash, run the following: export ENDPOINT=\"<schedulerEndpoint>\"")
74+
exit()
75+
76+
# Define the scope for Azure Resource Manager (ARM)
77+
arm_scope = "https://durabletask.io/.default"
78+
token_manager = AccessTokenManager(scope = arm_scope)
79+
80+
meta_data: list[tuple[str, str]] = [
81+
("taskhub", taskhub_name)
82+
]
83+
84+
85+
# configure and start the worker
86+
with DurableTaskSchedulerWorker(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) as w:
87+
w.add_orchestrator(orchestrator)
88+
w.add_activity(process_work_item)
89+
w.add_activity(get_work_items)
90+
w.start()
91+
92+
# create a client, start an orchestration, and wait for it to finish
93+
c = DurableTaskSchedulerClient(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager)
94+
instance_id = c.schedule_new_orchestration(orchestrator)
95+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
96+
if state and state.runtime_status == client.OrchestrationStatus.COMPLETED:
97+
print(f'Orchestration completed! Result: {state.serialized_output}')
98+
elif state:
99+
print(f'Orchestration failed: {state.failure_details}')
Lines changed: 61 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,64 @@
1-
from durabletask import TaskHubGrpcClient
1+
from durabletask.client import TaskHubGrpcClient
2+
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
23

34
class DurableTaskSchedulerClient(TaskHubGrpcClient):
4-
def __init__(self, *args, **kwargs):
5+
def __init__(self, *args, access_token_manager: AccessTokenManager, **kwargs):
56
# Initialize the base class
6-
super().__init__(*args, **kwargs)
7+
super().__init__(*args, **kwargs)
8+
self._access_token_manager = access_token_manager
9+
self.__update_metadata_with_token()
10+
11+
def __update_metadata_with_token(self):
12+
"""
13+
Add or update the `authorization` key in the metadata with the current access token.
14+
"""
15+
if self._access_token_manager is not None:
16+
token = self._access_token_manager.get_access_token()
17+
18+
# Check if "authorization" already exists in the metadata
19+
updated = False
20+
for i, (key, _) in enumerate(self._metadata):
21+
if key == "authorization":
22+
self._metadata[i] = ("authorization", token)
23+
updated = True
24+
break
25+
26+
# If not updated, add a new entry
27+
if not updated:
28+
self._metadata.append(("authorization", token))
29+
30+
def schedule_new_orchestration(self, *args, **kwargs) -> str:
31+
self.__update_metadata_with_token()
32+
return super().schedule_new_orchestration(*args, **kwargs)
33+
34+
def get_orchestration_state(self, *args, **kwargs):
35+
self.__update_metadata_with_token()
36+
super().get_orchestration_state(*args, **kwargs)
37+
38+
def wait_for_orchestration_start(self, *args, **kwargs):
39+
self.__update_metadata_with_token()
40+
super().wait_for_orchestration_start(*args, **kwargs)
41+
42+
def wait_for_orchestration_completion(self, *args, **kwargs):
43+
self.__update_metadata_with_token()
44+
super().wait_for_orchestration_completion(*args, **kwargs)
45+
46+
def raise_orchestration_event(self, *args, **kwargs):
47+
self.__update_metadata_with_token()
48+
super().raise_orchestration_event(*args, **kwargs)
49+
50+
def terminate_orchestration(self, *args, **kwargs):
51+
self.__update_metadata_with_token()
52+
super().terminate_orchestration(*args, **kwargs)
53+
54+
def suspend_orchestration(self, *args, **kwargs):
55+
self.__update_metadata_with_token()
56+
super().suspend_orchestration(*args, **kwargs)
57+
58+
def resume_orchestration(self, *args, **kwargs):
59+
self.__update_metadata_with_token()
60+
super().resume_orchestration(*args, **kwargs)
61+
62+
def purge_orchestration(self, *args, **kwargs):
63+
self.__update_metadata_with_token()
64+
super().purge_orchestration(*args, **kwargs)

externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
1111

1212
class DurableTaskSchedulerWorker(TaskHubGrpcWorker):
13-
def __init__(self, *args, access_token_manager: AccessTokenManager = None, **kwargs):
13+
def __init__(self, *args, access_token_manager: AccessTokenManager, **kwargs):
1414
# Initialize the base class
1515
super().__init__(*args, **kwargs)
1616
self._access_token_manager = access_token_manager

0 commit comments

Comments
 (0)