Skip to content

Commit f8d79d3

Browse files
committed
Cleaning up construction of dts objects and improving examples
Signed-off-by: Ryan Lettieri <[email protected]>
1 parent ea837d0 commit f8d79d3

File tree

5 files changed

+170
-150
lines changed

5 files changed

+170
-150
lines changed

examples/dts/dts_activity_sequence.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,14 @@ def sequence(ctx: task.OrchestrationContext, _):
4747
exit()
4848

4949

50-
# Define the scope for Azure Resource Manager (ARM)
51-
arm_scope = "https://durabletask.io/.default"
52-
token_manager = AccessTokenManager(scope = arm_scope)
53-
54-
meta_data: list[tuple[str, str]] = [
55-
("taskhub", taskhub_name)
56-
]
57-
5850
# configure and start the worker
59-
with DurableTaskSchedulerWorker(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) as w:
51+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, client_id="", taskhub=taskhub_name) as w:
6052
w.add_orchestrator(sequence)
6153
w.add_activity(hello)
6254
w.start()
6355

6456
# Construct the client and run the orchestrations
65-
c = DurableTaskSchedulerClient(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager)
57+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, taskhub=taskhub_name)
6658
instance_id = c.schedule_new_orchestration(sequence)
6759
state = c.wait_for_orchestration_completion(instance_id, timeout=45)
6860
if state and state.runtime_status == client.OrchestrationStatus.COMPLETED:

examples/dts/dts_fanout_fanin.py

Lines changed: 90 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,90 @@
1-
"""End-to-end sample that demonstrates how to configure an orchestrator
2-
that a dynamic number activity functions in parallel, waits for them all
3-
to complete, and prints an aggregate summary of the outputs."""
4-
import random
5-
import time
6-
import os
7-
from durabletask import client, task
8-
from durabletask import client, task
9-
from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker
10-
from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient
11-
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
12-
13-
14-
def get_work_items(ctx: task.ActivityContext, _) -> list[str]:
15-
"""Activity function that returns a list of work items"""
16-
# return a random number of work items
17-
count = random.randint(2, 10)
18-
print(f'generating {count} work items...')
19-
return [f'work item {i}' for i in range(count)]
20-
21-
22-
def process_work_item(ctx: task.ActivityContext, item: str) -> int:
23-
"""Activity function that returns a result for a given work item"""
24-
print(f'processing work item: {item}')
25-
26-
# simulate some work that takes a variable amount of time
27-
time.sleep(random.random() * 5)
28-
29-
# return a result for the given work item, which is also a random number in this case
30-
return random.randint(0, 10)
31-
32-
33-
def orchestrator(ctx: task.OrchestrationContext, _):
34-
"""Orchestrator function that calls the 'get_work_items' and 'process_work_item'
35-
activity functions in parallel, waits for them all to complete, and prints
36-
an aggregate summary of the outputs"""
37-
38-
work_items: list[str] = yield ctx.call_activity(get_work_items)
39-
40-
# execute the work-items in parallel and wait for them all to return
41-
tasks = [ctx.call_activity(process_work_item, input=item) for item in work_items]
42-
results: list[int] = yield task.when_all(tasks)
43-
44-
# return an aggregate summary of the results
45-
return {
46-
'work_items': work_items,
47-
'results': results,
48-
'total': sum(results),
49-
}
50-
51-
52-
# Read the environment variable
53-
taskhub_name = os.getenv("TASKHUB")
54-
55-
# Check if the variable exists
56-
if taskhub_name:
57-
print(f"The value of TASKHUB is: {taskhub_name}")
58-
else:
59-
print("TASKHUB is not set. Please set the TASKHUB environment variable to the name of the taskhub you wish to use")
60-
print("If you are using windows powershell, run the following: $env:TASKHUB=\"<taskhubname>\"")
61-
print("If you are using bash, run the following: export TASKHUB=\"<taskhubname>\"")
62-
exit()
63-
64-
# Read the environment variable
65-
endpoint = os.getenv("ENDPOINT")
66-
67-
# Check if the variable exists
68-
if endpoint:
69-
print(f"The value of ENDPOINT is: {endpoint}")
70-
else:
71-
print("ENDPOINT is not set. Please set the ENDPOINT environment variable to the endpoint of the scheduler")
72-
print("If you are using windows powershell, run the following: $env:ENDPOINT=\"<schedulerEndpoint>\"")
73-
print("If you are using bash, run the following: export ENDPOINT=\"<schedulerEndpoint>\"")
74-
exit()
75-
76-
# Define the scope for Azure Resource Manager (ARM)
77-
arm_scope = "https://durabletask.io/.default"
78-
token_manager = AccessTokenManager(scope = arm_scope)
79-
80-
meta_data: list[tuple[str, str]] = [
81-
("taskhub", taskhub_name)
82-
]
83-
84-
85-
# configure and start the worker
86-
with DurableTaskSchedulerWorker(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) as w:
87-
w.add_orchestrator(orchestrator)
88-
w.add_activity(process_work_item)
89-
w.add_activity(get_work_items)
90-
w.start()
91-
92-
# create a client, start an orchestration, and wait for it to finish
93-
c = DurableTaskSchedulerClient(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager)
94-
instance_id = c.schedule_new_orchestration(orchestrator)
95-
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
96-
if state and state.runtime_status == client.OrchestrationStatus.COMPLETED:
97-
print(f'Orchestration completed! Result: {state.serialized_output}')
98-
elif state:
99-
print(f'Orchestration failed: {state.failure_details}')
1+
"""End-to-end sample that demonstrates how to configure an orchestrator
2+
that a dynamic number activity functions in parallel, waits for them all
3+
to complete, and prints an aggregate summary of the outputs."""
4+
import random
5+
import time
6+
import os
7+
from durabletask import client, task
8+
from durabletask import client, task
9+
from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker
10+
from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient
11+
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
12+
13+
14+
def get_work_items(ctx: task.ActivityContext, _) -> list[str]:
15+
"""Activity function that returns a list of work items"""
16+
# return a random number of work items
17+
count = random.randint(2, 10)
18+
print(f'generating {count} work items...')
19+
return [f'work item {i}' for i in range(count)]
20+
21+
22+
def process_work_item(ctx: task.ActivityContext, item: str) -> int:
23+
"""Activity function that returns a result for a given work item"""
24+
print(f'processing work item: {item}')
25+
26+
# simulate some work that takes a variable amount of time
27+
time.sleep(random.random() * 5)
28+
29+
# return a result for the given work item, which is also a random number in this case
30+
return random.randint(0, 10)
31+
32+
33+
def orchestrator(ctx: task.OrchestrationContext, _):
34+
"""Orchestrator function that calls the 'get_work_items' and 'process_work_item'
35+
activity functions in parallel, waits for them all to complete, and prints
36+
an aggregate summary of the outputs"""
37+
38+
work_items: list[str] = yield ctx.call_activity(get_work_items)
39+
40+
# execute the work-items in parallel and wait for them all to return
41+
tasks = [ctx.call_activity(process_work_item, input=item) for item in work_items]
42+
results: list[int] = yield task.when_all(tasks)
43+
44+
# return an aggregate summary of the results
45+
return {
46+
'work_items': work_items,
47+
'results': results,
48+
'total': sum(results),
49+
}
50+
51+
52+
# Read the environment variable
53+
taskhub_name = os.getenv("TASKHUB")
54+
55+
# Check if the variable exists
56+
if taskhub_name:
57+
print(f"The value of TASKHUB is: {taskhub_name}")
58+
else:
59+
print("TASKHUB is not set. Please set the TASKHUB environment variable to the name of the taskhub you wish to use")
60+
print("If you are using windows powershell, run the following: $env:TASKHUB=\"<taskhubname>\"")
61+
print("If you are using bash, run the following: export TASKHUB=\"<taskhubname>\"")
62+
exit()
63+
64+
# Read the environment variable
65+
endpoint = os.getenv("ENDPOINT")
66+
67+
# Check if the variable exists
68+
if endpoint:
69+
print(f"The value of ENDPOINT is: {endpoint}")
70+
else:
71+
print("ENDPOINT is not set. Please set the ENDPOINT environment variable to the endpoint of the scheduler")
72+
print("If you are using windows powershell, run the following: $env:ENDPOINT=\"<schedulerEndpoint>\"")
73+
print("If you are using bash, run the following: export ENDPOINT=\"<schedulerEndpoint>\"")
74+
exit()
75+
76+
# configure and start the worker
77+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, client_id="", taskhub=taskhub_name) as w:
78+
w.add_orchestrator(orchestrator)
79+
w.add_activity(process_work_item)
80+
w.add_activity(get_work_items)
81+
w.start()
82+
83+
# create a client, start an orchestration, and wait for it to finish
84+
c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, taskhub=taskhub_name)
85+
instance_id = c.schedule_new_orchestration(orchestrator)
86+
state = c.wait_for_orchestration_completion(instance_id, timeout=30)
87+
if state and state.runtime_status == client.OrchestrationStatus.COMPLETED:
88+
print(f'Orchestration completed! Result: {state.serialized_output}')
89+
elif state:
90+
print(f'Orchestration failed: {state.failure_details}')

externalpackages/durabletaskscheduler/access_token_manager.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,26 @@
11
# Copyright (c) Microsoft Corporation.
22
# Licensed under the MIT License.
3-
4-
from azure.identity import DefaultAzureCredential
3+
from azure.identity import DefaultAzureCredential, ManagedIdentityCredential
54
from datetime import datetime, timedelta
5+
from typing import Optional
66

77
class AccessTokenManager:
8-
def __init__(self, scope: str, refresh_buffer: int = 60):
9-
self.scope = scope
8+
def __init__(self, refresh_buffer: int = 60, use_managed_identity: bool = False, client_id: Optional[str] = None):
9+
self.scope = "https://durabletask.io/.default"
1010
self.refresh_buffer = refresh_buffer
11-
self.credential = DefaultAzureCredential()
11+
12+
# Choose the appropriate credential based on use_managed_identity
13+
if use_managed_identity:
14+
if not client_id:
15+
print("Using System Assigned Managed Identity for authentication.")
16+
self.credential = ManagedIdentityCredential()
17+
else:
18+
print("Using User Assigned Managed Identity for authentication.")
19+
self.credential = ManagedIdentityCredential(client_id)
20+
else:
21+
self.credential = DefaultAzureCredential()
22+
print("Using Default Azure Credentials for authentication.")
23+
1224
self.token = None
1325
self.expiry_time = None
1426

externalpackages/durabletaskscheduler/durabletask_scheduler_client.py

Lines changed: 30 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,43 @@
1+
from typing import Optional
12
from durabletask.client import TaskHubGrpcClient
23
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
34

45
class DurableTaskSchedulerClient(TaskHubGrpcClient):
5-
def __init__(self, *args, access_token_manager: AccessTokenManager, **kwargs):
6-
# Initialize the base class
7-
super().__init__(*args, **kwargs)
8-
self._access_token_manager = access_token_manager
6+
def __init__(self, *args,
7+
metadata: Optional[list[tuple[str, str]]] = None,
8+
client_id: Optional[str] = None,
9+
taskhub: str,
10+
**kwargs):
11+
if metadata is None:
12+
metadata = [] # Ensure metadata is initialized
13+
self._metadata = metadata
14+
self._client_id = client_id
15+
self._metadata.append(("taskhub", taskhub))
16+
self._access_token_manager = AccessTokenManager(client_id=self._client_id)
917
self.__update_metadata_with_token()
18+
super().__init__(*args, metadata=self._metadata, **kwargs)
1019

1120
def __update_metadata_with_token(self):
1221
"""
1322
Add or update the `authorization` key in the metadata with the current access token.
1423
"""
15-
if self._access_token_manager is not None:
16-
token = self._access_token_manager.get_access_token()
17-
18-
# Check if "authorization" already exists in the metadata
19-
updated = False
20-
for i, (key, _) in enumerate(self._metadata):
21-
if key == "authorization":
22-
self._metadata[i] = ("authorization", token)
23-
updated = True
24-
break
25-
26-
# If not updated, add a new entry
27-
if not updated:
28-
self._metadata.append(("authorization", token))
24+
token = self._access_token_manager.get_access_token()
25+
26+
# Ensure that self._metadata is initialized
27+
if self._metadata is None:
28+
self._metadata = [] # Initialize it if it's still None
29+
30+
# Check if "authorization" already exists in the metadata
31+
updated = False
32+
for i, (key, _) in enumerate(self._metadata):
33+
if key == "authorization":
34+
self._metadata[i] = ("authorization", token)
35+
updated = True
36+
break
37+
38+
# If not updated, add a new entry
39+
if not updated:
40+
self._metadata.append(("authorization", token))
2941

3042
def schedule_new_orchestration(self, *args, **kwargs) -> str:
3143
self.__update_metadata_with_token()

externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,35 +5,48 @@
55
import durabletask.internal.orchestrator_service_pb2 as pb
66
import durabletask.internal.orchestrator_service_pb2_grpc as stubs
77
import durabletask.internal.shared as shared
8+
from typing import Optional
89

910
from durabletask.worker import TaskHubGrpcWorker
1011
from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager
1112

1213
class DurableTaskSchedulerWorker(TaskHubGrpcWorker):
13-
def __init__(self, *args, access_token_manager: AccessTokenManager, **kwargs):
14-
# Initialize the base class
15-
super().__init__(*args, **kwargs)
16-
self._access_token_manager = access_token_manager
14+
def __init__(self, *args,
15+
metadata: Optional[list[tuple[str, str]]] = None,
16+
client_id: Optional[str] = None,
17+
taskhub: str,
18+
**kwargs):
19+
if metadata is None:
20+
metadata = [] # Ensure metadata is initialized
21+
self._metadata = metadata
22+
self._client_id = client_id
23+
self._metadata.append(("taskhub", taskhub))
24+
self._access_token_manager = AccessTokenManager(client_id=self._client_id)
1725
self.__update_metadata_with_token()
26+
super().__init__(*args, metadata=self._metadata, **kwargs)
27+
1828

1929
def __update_metadata_with_token(self):
2030
"""
2131
Add or update the `authorization` key in the metadata with the current access token.
2232
"""
23-
if self._access_token_manager is not None:
24-
token = self._access_token_manager.get_access_token()
25-
26-
# Check if "authorization" already exists in the metadata
27-
updated = False
28-
for i, (key, _) in enumerate(self._metadata):
29-
if key == "authorization":
30-
self._metadata[i] = ("authorization", token)
31-
updated = True
32-
break
33-
34-
# If not updated, add a new entry
35-
if not updated:
36-
self._metadata.append(("authorization", token))
33+
token = self._access_token_manager.get_access_token()
34+
35+
# Ensure that self._metadata is initialized
36+
if self._metadata is None:
37+
self._metadata = [] # Initialize it if it's still None
38+
39+
# Check if "authorization" already exists in the metadata
40+
updated = False
41+
for i, (key, _) in enumerate(self._metadata):
42+
if key == "authorization":
43+
self._metadata[i] = ("authorization", token)
44+
updated = True
45+
break
46+
47+
# If not updated, add a new entry
48+
if not updated:
49+
self._metadata.append(("authorization", token))
3750

3851
def start(self):
3952
"""Starts the worker on a background thread and begins listening for work items."""

0 commit comments

Comments
 (0)