From 026c5722fb15e6f1136d730dc206abd2238d911a Mon Sep 17 00:00:00 2001 From: Ryan Lettieri <67934986+RyanLettieri@users.noreply.github.com> Date: Tue, 21 Jan 2025 17:42:24 -0700 Subject: [PATCH 01/31] Creation of DTS example and passing of completionToken Signed-off-by: Ryan Lettieri <67934986+RyanLettieri@users.noreply.github.com> --- durabletask/worker.py | 21 +++++---- examples/README.md | 31 +++++++++++-- examples/dts_activity_sequence.py | 73 +++++++++++++++++++++++++++++++ 3 files changed, 114 insertions(+), 11 deletions(-) create mode 100644 examples/dts_activity_sequence.py diff --git a/durabletask/worker.py b/durabletask/worker.py index 75e2e37..51a62fd 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -143,9 +143,11 @@ def run_loop(): request_type = work_item.WhichOneof('request') self._logger.debug(f'Received "{request_type}" work item') if work_item.HasField('orchestratorRequest'): - executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub) + executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken) elif work_item.HasField('activityRequest'): - executor.submit(self._execute_activity, work_item.activityRequest, stub) + executor.submit(self._execute_activity, work_item.activityRequest, stub, work_item.completionToken) + elif work_item.HasField('healthPing'): + pass # no-op else: self._logger.warning(f'Unexpected work item type: {request_type}') @@ -184,26 +186,27 @@ def stop(self): self._logger.info("Worker shutdown completed") self._is_running = False - def _execute_orchestrator(self, req: pb.OrchestratorRequest, stub: stubs.TaskHubSidecarServiceStub): + def _execute_orchestrator(self, req: pb.OrchestratorRequest, stub: stubs.TaskHubSidecarServiceStub, completionToken): try: executor = _OrchestrationExecutor(self._registry, self._logger) result = executor.execute(req.instanceId, req.pastEvents, req.newEvents) res = pb.OrchestratorResponse( instanceId=req.instanceId, 
actions=result.actions, - customStatus=pbh.get_string_value(result.encoded_custom_status)) + customStatus=pbh.get_string_value(result.encoded_custom_status), + completionToken=completionToken) except Exception as ex: self._logger.exception(f"An error occurred while trying to execute instance '{req.instanceId}': {ex}") failure_details = pbh.new_failure_details(ex) actions = [pbh.new_complete_orchestration_action(-1, pb.ORCHESTRATION_STATUS_FAILED, "", failure_details)] - res = pb.OrchestratorResponse(instanceId=req.instanceId, actions=actions) + res = pb.OrchestratorResponse(instanceId=req.instanceId, actions=actions, completionToken=completionToken) try: stub.CompleteOrchestratorTask(res) except Exception as ex: self._logger.exception(f"Failed to deliver orchestrator response for '{req.instanceId}' to sidecar: {ex}") - def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarServiceStub): + def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarServiceStub, completionToken): instance_id = req.orchestrationInstance.instanceId try: executor = _ActivityExecutor(self._registry, self._logger) @@ -211,12 +214,14 @@ def _execute_activity(self, req: pb.ActivityRequest, stub: stubs.TaskHubSidecarS res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, - result=pbh.get_string_value(result)) + result=pbh.get_string_value(result), + completionToken=completionToken) except Exception as ex: res = pb.ActivityResponse( instanceId=instance_id, taskId=req.taskId, - failureDetails=pbh.new_failure_details(ex)) + failureDetails=pbh.new_failure_details(ex), + completionToken=completionToken) try: stub.CompleteActivityTask(res) diff --git a/examples/README.md b/examples/README.md index ec9088f..4b5fee0 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,14 +1,39 @@ # Examples -This directory contains examples of how to author durable orchestrations using the Durable Task Python SDK. 
+This directory contains examples of how to author durable orchestrations using the Durable Task Python SDK. There are two backends that are compatible with the Durable Task Python SDK: The Dapr sidecar, and the Durable Task Scheduler (DTS) -## Prerequisites +## Prerequisites for using Dapr All the examples assume that you have a Durable Task-compatible sidecar running locally. There are two options for this: 1. Install the latest version of the [Dapr CLI](https://docs.dapr.io/getting-started/install-dapr-cli/), which contains and exposes an embedded version of the Durable Task engine. The setup process (which requires Docker) will configure the workflow engine to store state in a local Redis container. -1. Clone and run the [Durable Task Sidecar](https://github.com/microsoft/durabletask-go) project locally (requires Go 1.18 or higher). Orchestration state will be stored in a local sqlite database. +2. Clone and run the [Durable Task Sidecar](https://github.com/microsoft/durabletask-go) project locally (requires Go 1.18 or higher). Orchestration state will be stored in a local sqlite database. + + +## Prerequisites for using DTS + +All the examples assume that you have a Durable Task Scheduler taskhub created. + +The simplest way to create a taskhub is by using the az cli commands: + +1. Create a scheduler: + az durabletask scheduler create --resource-group --name --location --ip-allowlist "[0.0.0.0/0]" --sku-capacity 1, --sku-name "Dedicated" --tags "{}" + +2. Create your taskhub + az durabletask taskhub create --resource-group --scheduler-name --name + +3. Retrieve the endpoint for the taskhub. This can be done by locating the taskhub in the portal. + +4. 
Set the appropriate environment variables for the TASKHUB and ENDPOINT + +```sh +export TASKHUB= +``` + +```sh +export ENDPOINT= +``` ## Running the examples diff --git a/examples/dts_activity_sequence.py b/examples/dts_activity_sequence.py new file mode 100644 index 0000000..3cf4429 --- /dev/null +++ b/examples/dts_activity_sequence.py @@ -0,0 +1,73 @@ +import os +from azure.identity import DefaultAzureCredential + +"""End-to-end sample that demonstrates how to configure an orchestrator +that calls an activity function in a sequence and prints the outputs.""" +from durabletask import client, task, worker + + +def hello(ctx: task.ActivityContext, name: str) -> str: + """Activity function that returns a greeting""" + return f'Hello {name}!' + + +def sequence(ctx: task.OrchestrationContext, _): + """Orchestrator function that calls the 'hello' activity function in a sequence""" + # call "hello" activity function in a sequence + result1 = yield ctx.call_activity(hello, input='Tokyo') + result2 = yield ctx.call_activity(hello, input='Seattle') + result3 = yield ctx.call_activity(hello, input='London') + + # return an array of results + return [result1, result2, result3] + + +# Read the environment variable +taskhub_name = os.getenv("TASKHUB") + +# Check if the variable exists +if taskhub_name: + print(f"The value of TASKHUB is: {taskhub_name}") +else: + print("TASKHUB is not set. Please set the TASKHUB environment variable to the name of the taskhub you wish to use") + print("If you are using windows powershell, run the following: $env:TASKHUB=\"\"") + print("If you are using bash, run the following: export TASKHUB=\"\"") + exit() + +# Read the environment variable +endpoint = os.getenv("ENDPOINT") + +# Check if the variable exists +if endpoint: + print(f"The value of ENDPOINT is: {endpoint}") +else: + print("ENDPOINT is not set. 
Please set the ENDPOINT environment variable to the endpoint of the taskhub") + print("If you are using windows powershell, run the following: $env:ENDPOINT=\"\"") + print("If you are using bash, run the following: export ENDPOINT=\"\"") + exit() + + +default_credential = DefaultAzureCredential() +# Define the scope for Azure Resource Manager (ARM) +arm_scope = "https://durabletask.io/.default" + +# Retrieve the access token +access_token = "Bearer " + default_credential.get_token(arm_scope).token +# create a client, start an orchestration, and wait for it to finish +metaData: list[tuple[str, str]] = [ + ("taskhub", taskhub_name), # Hardcode for now, just the taskhub name + ("authorization", access_token) # use azure identity sdk for python +] +# configure and start the worker +with worker.TaskHubGrpcWorker(host_address=endpoint, metadata=metaData, secure_channel=True) as w: + w.add_orchestrator(sequence) + w.add_activity(hello) + w.start() + + c = client.TaskHubGrpcClient(host_address=endpoint, metadata=metaData, secure_channel=True) + instance_id = c.schedule_new_orchestration(sequence) + state = c.wait_for_orchestration_completion(instance_id, timeout=45) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! 
Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') From 136a3d0a885fc50a0d273e0747e3840cc5fb4638 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Wed, 22 Jan 2025 12:09:11 -0700 Subject: [PATCH 02/31] Adressing review feedback Signed-off-by: Ryan Lettieri --- examples/README.md | 2 +- examples/dts/README.md | 35 +++++++++++++++++++++ examples/{ => dts}/dts_activity_sequence.py | 14 ++++----- 3 files changed, 43 insertions(+), 8 deletions(-) create mode 100644 examples/dts/README.md rename examples/{ => dts}/dts_activity_sequence.py (88%) diff --git a/examples/README.md b/examples/README.md index 4b5fee0..ae76979 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,6 +1,6 @@ # Examples -This directory contains examples of how to author durable orchestrations using the Durable Task Python SDK. There are two backends that are compatible with the Durable Task Python SDK: The Dapr sidecar, and the Durable Task Scheduler (DTS) +This directory contains examples of how to author durable orchestrations using the Durable Task Python SDK. There are two backends that are compatible with the Durable Task Python SDK: The Dapr sidecar, and the Durable Task Scheduler (DTS). ## Prerequisites for using Dapr diff --git a/examples/dts/README.md b/examples/dts/README.md new file mode 100644 index 0000000..866eed1 --- /dev/null +++ b/examples/dts/README.md @@ -0,0 +1,35 @@ +# Examples + +This directory contains examples of how to author durable orchestrations using the Durable Task Python SDK in conjunction with the Durable Task Scheduler (DTS). + +## Prerequisites + +All the examples assume that you have a Durable Task Scheduler taskhub created. + +The simplest way to create a taskhub is by using the az cli commands: + +1. Create a scheduler: + az durabletask scheduler create --resource-group --name --location --ip-allowlist "[0.0.0.0/0]" --sku-capacity 1 --sku-name "Dedicated" --tags "{}" + +2. 
Create your taskhub + az durabletask taskhub create --resource-group --scheduler-name --name + +3. Retrieve the endpoint for the scheduler. This can be done by locating the taskhub in the portal. + +4. Set the appropriate environment variables for the TASKHUB and ENDPOINT + +```sh +export TASKHUB= +``` + +```sh +export ENDPOINT= +``` + +## Running the examples + +With one of the sidecars running, you can simply execute any of the examples in this directory using `python3`: + +```sh +python3 dts_activity_sequence.py +``` diff --git a/examples/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py similarity index 88% rename from examples/dts_activity_sequence.py rename to examples/dts/dts_activity_sequence.py index 3cf4429..e7af40e 100644 --- a/examples/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -41,9 +41,9 @@ def sequence(ctx: task.OrchestrationContext, _): if endpoint: print(f"The value of ENDPOINT is: {endpoint}") else: - print("ENDPOINT is not set. Please set the ENDPOINT environment variable to the endpoint of the taskhub") - print("If you are using windows powershell, run the following: $env:ENDPOINT=\"\"") - print("If you are using bash, run the following: export ENDPOINT=\"\"") + print("ENDPOINT is not set. Please set the ENDPOINT environment variable to the endpoint of the scheduler") + print("If you are using windows powershell, run the following: $env:ENDPOINT=\"\"") + print("If you are using bash, run the following: export ENDPOINT=\"\"") exit() @@ -51,12 +51,12 @@ def sequence(ctx: task.OrchestrationContext, _): # Define the scope for Azure Resource Manager (ARM) arm_scope = "https://durabletask.io/.default" -# Retrieve the access token +# Retrieve the access token. Note that this approach doesn't support token refresh and can't be used in production. 
access_token = "Bearer " + default_credential.get_token(arm_scope).token -# create a client, start an orchestration, and wait for it to finish + metaData: list[tuple[str, str]] = [ - ("taskhub", taskhub_name), # Hardcode for now, just the taskhub name - ("authorization", access_token) # use azure identity sdk for python + ("taskhub", taskhub_name), + ("authorization", access_token) ] # configure and start the worker with worker.TaskHubGrpcWorker(host_address=endpoint, metadata=metaData, secure_channel=True) as w: From 6df1064bea4cba59c741e1a0e13dc3ac5d4d22e0 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Wed, 22 Jan 2025 12:11:18 -0700 Subject: [PATCH 03/31] Reverting dapr readme Signed-off-by: Ryan Lettieri --- examples/README.md | 29 ++--------------------------- 1 file changed, 2 insertions(+), 27 deletions(-) diff --git a/examples/README.md b/examples/README.md index ae76979..7cfbc7a 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,8 +1,8 @@ # Examples -This directory contains examples of how to author durable orchestrations using the Durable Task Python SDK. There are two backends that are compatible with the Durable Task Python SDK: The Dapr sidecar, and the Durable Task Scheduler (DTS). +This directory contains examples of how to author durable orchestrations using the Durable Task Python SDK. -## Prerequisites for using Dapr +## Prerequisites All the examples assume that you have a Durable Task-compatible sidecar running locally. There are two options for this: @@ -10,31 +10,6 @@ All the examples assume that you have a Durable Task-compatible sidecar running 2. Clone and run the [Durable Task Sidecar](https://github.com/microsoft/durabletask-go) project locally (requires Go 1.18 or higher). Orchestration state will be stored in a local sqlite database. - -## Prerequisites for using DTS - -All the examples assume that you have a Durable Task Scheduler taskhub created. 
- -The simplest way to create a taskhub is by using the az cli commands: - -1. Create a scheduler: - az durabletask scheduler create --resource-group --name --location --ip-allowlist "[0.0.0.0/0]" --sku-capacity 1, --sku-name "Dedicated" --tags "{}" - -2. Create your taskhub - az durabletask taskhub create --resource-group --scheduler-name --name - -3. Retrieve the endpoint for the taskhub. This can be done by locating the taskhub in the portal. - -4. Set the appropriate environment variables for the TASKHUB and ENDPOINT - -```sh -export TASKHUB= -``` - -```sh -export ENDPOINT= -``` - ## Running the examples With one of the sidecars running, you can simply execute any of the examples in this directory using `python3`: From f731c0d5cc2358c99f2d09f98afd168b979ff6be Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Fri, 24 Jan 2025 12:04:18 -0700 Subject: [PATCH 04/31] Adding accessTokenManager class for refreshing credential token Signed-off-by: Ryan Lettieri --- durabletask/accessTokenManager.py | 29 +++++++++++++++++++++++++++ durabletask/worker.py | 26 +++++++++++++++++++++++- examples/dts/README.md | 6 ++++++ examples/dts/dts_activity_sequence.py | 13 +++++------- 4 files changed, 65 insertions(+), 9 deletions(-) create mode 100644 durabletask/accessTokenManager.py diff --git a/durabletask/accessTokenManager.py b/durabletask/accessTokenManager.py new file mode 100644 index 0000000..8e0dd9a --- /dev/null +++ b/durabletask/accessTokenManager.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from azure.identity import DefaultAzureCredential +from datetime import datetime, timedelta + +class AccessTokenManager: + def __init__(self, scope: str, refresh_buffer: int = 60): + self.scope = scope + self.refresh_buffer = refresh_buffer + self.credential = DefaultAzureCredential() + self.token = None + self.expiry_time = None + + def get_access_token(self) -> str: + if self.token is None or self.is_token_expired(): + self.refresh_token() + return self.token + + def is_token_expired(self) -> bool: + if self.expiry_time is None: + return True + return datetime.utcnow() >= (self.expiry_time - timedelta(seconds=self.refresh_buffer)) + + def refresh_token(self): + new_token = self.credential.get_token(self.scope) + self.token = f"Bearer {new_token.token}" + self.expiry_time = datetime.utcnow() + timedelta(seconds=new_token.expires_on - int(datetime.utcnow().timestamp())) + print(f"Token refreshed. Expires at: {self.expiry_time}") \ No newline at end of file diff --git a/durabletask/worker.py b/durabletask/worker.py index 51a62fd..8722267 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -16,6 +16,7 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared +from durabletask.accessTokenManager import AccessTokenManager from durabletask import task TInput = TypeVar('TInput') @@ -88,7 +89,8 @@ def __init__(self, *, metadata: Optional[list[tuple[str, str]]] = None, log_handler=None, log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False): + secure_channel: bool = False, + access_token_manager: AccessTokenManager = None): self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() self._metadata = metadata @@ -96,6 +98,27 @@ def __init__(self, *, self._shutdown = Event() self._is_running = False self._secure_channel = secure_channel + 
self._access_token_manager = access_token_manager + self.__update_metadata_with_token() + + def __update_metadata_with_token(self): + """ + Add or update the `authorization` key in the metadata with the current access token. + """ + if self._access_token_manager is not None: + token = self._access_token_manager.get_access_token() + + # Check if "authorization" already exists in the metadata + updated = False + for i, (key, _) in enumerate(self._metadata): + if key == "authorization": + self._metadata[i] = ("authorization", token) + updated = True + break + + # If not updated, add a new entry + if not updated: + self._metadata.append(("authorization", token)) def __enter__(self): return self @@ -130,6 +153,7 @@ def run_loop(): with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: while not self._shutdown.is_set(): try: + self.__update_metadata_with_token() # send a "Hello" message to the sidecar to ensure that it's listening stub.Hello(empty_pb2.Empty()) diff --git a/examples/dts/README.md b/examples/dts/README.md index 866eed1..f349cf2 100644 --- a/examples/dts/README.md +++ b/examples/dts/README.md @@ -26,6 +26,12 @@ export TASKHUB= export ENDPOINT= ``` +5. 
Since the samples rely on azure identity, ensure the package is installed and up-to-date + +```sh +python3 -m pip install azure-identity +``` + ## Running the examples With one of the sidecars running, you can simply execute any of the examples in this directory using `python3`: diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index e7af40e..f10563a 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -4,7 +4,7 @@ """End-to-end sample that demonstrates how to configure an orchestrator that calls an activity function in a sequence and prints the outputs.""" from durabletask import client, task, worker - +from durabletask.accessTokenManager import AccessTokenManager def hello(ctx: task.ActivityContext, name: str) -> str: """Activity function that returns a greeting""" @@ -47,19 +47,16 @@ def sequence(ctx: task.OrchestrationContext, _): exit() -default_credential = DefaultAzureCredential() # Define the scope for Azure Resource Manager (ARM) arm_scope = "https://durabletask.io/.default" - -# Retrieve the access token. Note that this approach doesn't support token refresh and can't be used in production. 
-access_token = "Bearer " + default_credential.get_token(arm_scope).token +token_manager = AccessTokenManager(scope = arm_scope) metaData: list[tuple[str, str]] = [ - ("taskhub", taskhub_name), - ("authorization", access_token) + ("taskhub", taskhub_name) ] + # configure and start the worker -with worker.TaskHubGrpcWorker(host_address=endpoint, metadata=metaData, secure_channel=True) as w: +with worker.TaskHubGrpcWorker(host_address=endpoint, metadata=metaData, secure_channel=True, access_token_manager=token_manager) as w: w.add_orchestrator(sequence) w.add_activity(hello) w.start() From eb984164c1c7cc1e21cc5b8a7f5ab5a32803be76 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Fri, 24 Jan 2025 12:08:18 -0700 Subject: [PATCH 05/31] Adding comments to the example Signed-off-by: Ryan Lettieri --- durabletask/accessTokenManager.py | 56 +++++++++++++-------------- examples/dts/dts_activity_sequence.py | 1 + 2 files changed, 29 insertions(+), 28 deletions(-) diff --git a/durabletask/accessTokenManager.py b/durabletask/accessTokenManager.py index 8e0dd9a..7628355 100644 --- a/durabletask/accessTokenManager.py +++ b/durabletask/accessTokenManager.py @@ -1,29 +1,29 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -from azure.identity import DefaultAzureCredential -from datetime import datetime, timedelta - -class AccessTokenManager: - def __init__(self, scope: str, refresh_buffer: int = 60): - self.scope = scope - self.refresh_buffer = refresh_buffer - self.credential = DefaultAzureCredential() - self.token = None - self.expiry_time = None - - def get_access_token(self) -> str: - if self.token is None or self.is_token_expired(): - self.refresh_token() - return self.token - - def is_token_expired(self) -> bool: - if self.expiry_time is None: - return True - return datetime.utcnow() >= (self.expiry_time - timedelta(seconds=self.refresh_buffer)) - - def refresh_token(self): - new_token = self.credential.get_token(self.scope) - self.token = f"Bearer {new_token.token}" - self.expiry_time = datetime.utcnow() + timedelta(seconds=new_token.expires_on - int(datetime.utcnow().timestamp())) +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from azure.identity import DefaultAzureCredential +from datetime import datetime, timedelta + +class AccessTokenManager: + def __init__(self, scope: str, refresh_buffer: int = 60): + self.scope = scope + self.refresh_buffer = refresh_buffer + self.credential = DefaultAzureCredential() + self.token = None + self.expiry_time = None + + def get_access_token(self) -> str: + if self.token is None or self.is_token_expired(): + self.refresh_token() + return self.token + + def is_token_expired(self) -> bool: + if self.expiry_time is None: + return True + return datetime.utcnow() >= (self.expiry_time - timedelta(seconds=self.refresh_buffer)) + + def refresh_token(self): + new_token = self.credential.get_token(self.scope) + self.token = f"Bearer {new_token.token}" + self.expiry_time = datetime.utcnow() + timedelta(seconds=new_token.expires_on - int(datetime.utcnow().timestamp())) print(f"Token refreshed. 
Expires at: {self.expiry_time}") \ No newline at end of file diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index f10563a..54a2376 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -61,6 +61,7 @@ def sequence(ctx: task.OrchestrationContext, _): w.add_activity(hello) w.start() + # Construct the client and run the orchestrations c = client.TaskHubGrpcClient(host_address=endpoint, metadata=metaData, secure_channel=True) instance_id = c.schedule_new_orchestration(sequence) state = c.wait_for_orchestration_completion(instance_id, timeout=45) From 0de338d6b168e6bf98fe0fdff3e081747dd765eb Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Fri, 24 Jan 2025 12:11:05 -0700 Subject: [PATCH 06/31] Adding in requirement for azure-identity Signed-off-by: Ryan Lettieri --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index a31419b..49896d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,4 @@ grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newe protobuf pytest pytest-cov +azure-identity \ No newline at end of file From 6050771604fb89fc7fef1b83a0f05b1b55019e74 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Tue, 28 Jan 2025 11:00:30 -0700 Subject: [PATCH 07/31] Moving dts logic into its own module Signed-off-by: Ryan Lettieri --- durabletask/worker.py | 28 +----- examples/dts/dts_activity_sequence.py | 7 +- .../durabletaskscheduler/__init__.py | 7 ++ .../access_token_manager.py | 0 .../durabletask_scheduler_client.py | 6 ++ .../durabletask_scheduler_worker.py | 94 +++++++++++++++++++ 6 files changed, 114 insertions(+), 28 deletions(-) create mode 100644 externalpackages/durabletaskscheduler/__init__.py rename durabletask/accessTokenManager.py => externalpackages/durabletaskscheduler/access_token_manager.py (100%) create mode 100644 
externalpackages/durabletaskscheduler/durabletask_scheduler_client.py create mode 100644 externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py diff --git a/durabletask/worker.py b/durabletask/worker.py index 8722267..d6d65b1 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -16,7 +16,7 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared -from durabletask.accessTokenManager import AccessTokenManager + from durabletask import task TInput = TypeVar('TInput') @@ -89,8 +89,7 @@ def __init__(self, *, metadata: Optional[list[tuple[str, str]]] = None, log_handler=None, log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False, - access_token_manager: AccessTokenManager = None): + secure_channel: bool = False): self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() self._metadata = metadata @@ -98,27 +97,7 @@ def __init__(self, *, self._shutdown = Event() self._is_running = False self._secure_channel = secure_channel - self._access_token_manager = access_token_manager - self.__update_metadata_with_token() - - def __update_metadata_with_token(self): - """ - Add or update the `authorization` key in the metadata with the current access token. 
- """ - if self._access_token_manager is not None: - token = self._access_token_manager.get_access_token() - - # Check if "authorization" already exists in the metadata - updated = False - for i, (key, _) in enumerate(self._metadata): - if key == "authorization": - self._metadata[i] = ("authorization", token) - updated = True - break - - # If not updated, add a new entry - if not updated: - self._metadata.append(("authorization", token)) + def __enter__(self): return self @@ -153,7 +132,6 @@ def run_loop(): with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: while not self._shutdown.is_set(): try: - self.__update_metadata_with_token() # send a "Hello" message to the sidecar to ensure that it's listening stub.Hello(empty_pb2.Empty()) diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index 54a2376..588f8bf 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -3,8 +3,9 @@ """End-to-end sample that demonstrates how to configure an orchestrator that calls an activity function in a sequence and prints the outputs.""" -from durabletask import client, task, worker -from durabletask.accessTokenManager import AccessTokenManager +from durabletask import client, task +from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker +from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager def hello(ctx: task.ActivityContext, name: str) -> str: """Activity function that returns a greeting""" @@ -56,7 +57,7 @@ def sequence(ctx: task.OrchestrationContext, _): ] # configure and start the worker -with worker.TaskHubGrpcWorker(host_address=endpoint, metadata=metaData, secure_channel=True, access_token_manager=token_manager) as w: +with DurableTaskSchedulerWorker(host_address=endpoint, metadata=metaData, secure_channel=True, access_token_manager=token_manager) as w: w.add_orchestrator(sequence) 
w.add_activity(hello) w.start() diff --git a/externalpackages/durabletaskscheduler/__init__.py b/externalpackages/durabletaskscheduler/__init__.py new file mode 100644 index 0000000..e3941ba --- /dev/null +++ b/externalpackages/durabletaskscheduler/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Durable Task SDK for Python""" + + +PACKAGE_NAME = "durabletaskscheduler" diff --git a/durabletask/accessTokenManager.py b/externalpackages/durabletaskscheduler/access_token_manager.py similarity index 100% rename from durabletask/accessTokenManager.py rename to externalpackages/durabletaskscheduler/access_token_manager.py diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py new file mode 100644 index 0000000..66094da --- /dev/null +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py @@ -0,0 +1,6 @@ +from durabletask import TaskHubGrpcClient + +class DurableTaskSchedulerClient(TaskHubGrpcClient): + def __init__(self, *args, **kwargs): + # Initialize the base class + super().__init__(*args, **kwargs) \ No newline at end of file diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py new file mode 100644 index 0000000..f283661 --- /dev/null +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py @@ -0,0 +1,94 @@ +import concurrent.futures +from threading import Thread +from google.protobuf import empty_pb2 +import grpc +import durabletask.internal.orchestrator_service_pb2 as pb +import durabletask.internal.orchestrator_service_pb2_grpc as stubs +import durabletask.internal.shared as shared + +from durabletask.worker import TaskHubGrpcWorker +from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager + +class 
DurableTaskSchedulerWorker(TaskHubGrpcWorker): + def __init__(self, *args, access_token_manager: AccessTokenManager = None, **kwargs): + # Initialize the base class + super().__init__(*args, **kwargs) + self._access_token_manager = access_token_manager + self.__update_metadata_with_token() + + def __update_metadata_with_token(self): + """ + Add or update the `authorization` key in the metadata with the current access token. + """ + if self._access_token_manager is not None: + token = self._access_token_manager.get_access_token() + + # Check if "authorization" already exists in the metadata + updated = False + for i, (key, _) in enumerate(self._metadata): + if key == "authorization": + self._metadata[i] = ("authorization", token) + updated = True + break + + # If not updated, add a new entry + if not updated: + self._metadata.append(("authorization", token)) + + def start(self): + """Starts the worker on a background thread and begins listening for work items.""" + channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel) + stub = stubs.TaskHubSidecarServiceStub(channel) + + if self._is_running: + raise RuntimeError('The worker is already running.') + + def run_loop(): + # TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity + # functions. We'd need to know ahead of time whether a function is async or not. + # TODO: Max concurrency configuration settings + with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: + while not self._shutdown.is_set(): + try: + self.__update_metadata_with_token() + # send a "Hello" message to the sidecar to ensure that it's listening + stub.Hello(empty_pb2.Empty()) + + # stream work items + self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest()) + self._logger.info(f'Successfully connected to {self._host_address}. 
Waiting for work items...') + + # The stream blocks until either a work item is received or the stream is canceled + # by another thread (see the stop() method). + for work_item in self._response_stream: # type: ignore + request_type = work_item.WhichOneof('request') + self._logger.debug(f'Received "{request_type}" work item') + if work_item.HasField('orchestratorRequest'): + executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken) + elif work_item.HasField('activityRequest'): + executor.submit(self._execute_activity, work_item.activityRequest, stub, work_item.completionToken) + elif work_item.HasField('healthPing'): + pass # no-op + else: + self._logger.warning(f'Unexpected work item type: {request_type}') + + except grpc.RpcError as rpc_error: + if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore + self._logger.info(f'Disconnected from {self._host_address}') + elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore + self._logger.warning( + f'The sidecar at address {self._host_address} is unavailable - will continue retrying') + else: + self._logger.warning(f'Unexpected error: {rpc_error}') + except Exception as ex: + self._logger.warning(f'Unexpected error: {ex}') + + # CONSIDER: exponential backoff + self._shutdown.wait(5) + self._logger.info("No longer listening for work items") + return + + self._logger.info(f"Starting gRPC worker that connects to {self._host_address}") + self._runLoop = Thread(target=run_loop) + self._runLoop.start() + self._is_running = True From f4f98ee9b51c1fefd2d9a66853832074bcbdb817 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Tue, 28 Jan 2025 11:02:31 -0700 Subject: [PATCH 08/31] Fixing whitesapce Signed-off-by: Ryan Lettieri --- durabletask/worker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/durabletask/worker.py b/durabletask/worker.py index d6d65b1..51a62fd 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -16,7 +16,6 @@ 
import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared - from durabletask import task TInput = TypeVar('TInput') @@ -98,7 +97,6 @@ def __init__(self, *, self._is_running = False self._secure_channel = secure_channel - def __enter__(self): return self From ea837d077bbd4112e97efbbc1d9e3d88887dfc1d Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Wed, 29 Jan 2025 09:18:03 -0700 Subject: [PATCH 09/31] Updating dts client to refresh token Signed-off-by: Ryan Lettieri --- durabletask/client.py | 1 + examples/dts/dts_activity_sequence.py | 11 +-- examples/dts/dts_fanout_fanin.py | 99 +++++++++++++++++++ .../durabletask_scheduler_client.py | 64 +++++++++++- .../durabletask_scheduler_worker.py | 2 +- 5 files changed, 167 insertions(+), 10 deletions(-) create mode 100644 examples/dts/dts_fanout_fanin.py diff --git a/durabletask/client.py b/durabletask/client.py index 31953ae..74e51f5 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -97,6 +97,7 @@ def __init__(self, *, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False): + self._metadata = metadata channel = shared.get_grpc_channel(host_address, metadata, secure_channel=secure_channel) self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, log_formatter) diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index 588f8bf..e8d8c55 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -1,10 +1,9 @@ -import os -from azure.identity import DefaultAzureCredential - """End-to-end sample that demonstrates how to configure an orchestrator that calls an activity function in a sequence and prints the outputs.""" +import os from durabletask import client, task from 
externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker +from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager def hello(ctx: task.ActivityContext, name: str) -> str: @@ -52,18 +51,18 @@ def sequence(ctx: task.OrchestrationContext, _): arm_scope = "https://durabletask.io/.default" token_manager = AccessTokenManager(scope = arm_scope) -metaData: list[tuple[str, str]] = [ +meta_data: list[tuple[str, str]] = [ ("taskhub", taskhub_name) ] # configure and start the worker -with DurableTaskSchedulerWorker(host_address=endpoint, metadata=metaData, secure_channel=True, access_token_manager=token_manager) as w: +with DurableTaskSchedulerWorker(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) as w: w.add_orchestrator(sequence) w.add_activity(hello) w.start() # Construct the client and run the orchestrations - c = client.TaskHubGrpcClient(host_address=endpoint, metadata=metaData, secure_channel=True) + c = DurableTaskSchedulerClient(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) instance_id = c.schedule_new_orchestration(sequence) state = c.wait_for_orchestration_completion(instance_id, timeout=45) if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: diff --git a/examples/dts/dts_fanout_fanin.py b/examples/dts/dts_fanout_fanin.py new file mode 100644 index 0000000..9c8cc65 --- /dev/null +++ b/examples/dts/dts_fanout_fanin.py @@ -0,0 +1,99 @@ +"""End-to-end sample that demonstrates how to configure an orchestrator +that a dynamic number activity functions in parallel, waits for them all +to complete, and prints an aggregate summary of the outputs.""" +import random +import time +import os +from durabletask import client, task +from durabletask import client, task +from 
externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker +from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient +from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager + + +def get_work_items(ctx: task.ActivityContext, _) -> list[str]: + """Activity function that returns a list of work items""" + # return a random number of work items + count = random.randint(2, 10) + print(f'generating {count} work items...') + return [f'work item {i}' for i in range(count)] + + +def process_work_item(ctx: task.ActivityContext, item: str) -> int: + """Activity function that returns a result for a given work item""" + print(f'processing work item: {item}') + + # simulate some work that takes a variable amount of time + time.sleep(random.random() * 5) + + # return a result for the given work item, which is also a random number in this case + return random.randint(0, 10) + + +def orchestrator(ctx: task.OrchestrationContext, _): + """Orchestrator function that calls the 'get_work_items' and 'process_work_item' + activity functions in parallel, waits for them all to complete, and prints + an aggregate summary of the outputs""" + + work_items: list[str] = yield ctx.call_activity(get_work_items) + + # execute the work-items in parallel and wait for them all to return + tasks = [ctx.call_activity(process_work_item, input=item) for item in work_items] + results: list[int] = yield task.when_all(tasks) + + # return an aggregate summary of the results + return { + 'work_items': work_items, + 'results': results, + 'total': sum(results), + } + + +# Read the environment variable +taskhub_name = os.getenv("TASKHUB") + +# Check if the variable exists +if taskhub_name: + print(f"The value of TASKHUB is: {taskhub_name}") +else: + print("TASKHUB is not set. 
Please set the TASKHUB environment variable to the name of the taskhub you wish to use") + print("If you are using windows powershell, run the following: $env:TASKHUB=\"\"") + print("If you are using bash, run the following: export TASKHUB=\"\"") + exit() + +# Read the environment variable +endpoint = os.getenv("ENDPOINT") + +# Check if the variable exists +if endpoint: + print(f"The value of ENDPOINT is: {endpoint}") +else: + print("ENDPOINT is not set. Please set the ENDPOINT environment variable to the endpoint of the scheduler") + print("If you are using windows powershell, run the following: $env:ENDPOINT=\"\"") + print("If you are using bash, run the following: export ENDPOINT=\"\"") + exit() + +# Define the scope for Azure Resource Manager (ARM) +arm_scope = "https://durabletask.io/.default" +token_manager = AccessTokenManager(scope = arm_scope) + +meta_data: list[tuple[str, str]] = [ + ("taskhub", taskhub_name) +] + + +# configure and start the worker +with DurableTaskSchedulerWorker(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) as w: + w.add_orchestrator(orchestrator) + w.add_activity(process_work_item) + w.add_activity(get_work_items) + w.start() + + # create a client, start an orchestration, and wait for it to finish + c = DurableTaskSchedulerClient(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) + instance_id = c.schedule_new_orchestration(orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! 
Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py index 66094da..663eb5c 100644 --- a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py @@ -1,6 +1,64 @@ -from durabletask import TaskHubGrpcClient +from durabletask.client import TaskHubGrpcClient +from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager class DurableTaskSchedulerClient(TaskHubGrpcClient): - def __init__(self, *args, **kwargs): + def __init__(self, *args, access_token_manager: AccessTokenManager, **kwargs): # Initialize the base class - super().__init__(*args, **kwargs) \ No newline at end of file + super().__init__(*args, **kwargs) + self._access_token_manager = access_token_manager + self.__update_metadata_with_token() + + def __update_metadata_with_token(self): + """ + Add or update the `authorization` key in the metadata with the current access token. 
+ """ + if self._access_token_manager is not None: + token = self._access_token_manager.get_access_token() + + # Check if "authorization" already exists in the metadata + updated = False + for i, (key, _) in enumerate(self._metadata): + if key == "authorization": + self._metadata[i] = ("authorization", token) + updated = True + break + + # If not updated, add a new entry + if not updated: + self._metadata.append(("authorization", token)) + + def schedule_new_orchestration(self, *args, **kwargs) -> str: + self.__update_metadata_with_token() + return super().schedule_new_orchestration(*args, **kwargs) + + def get_orchestration_state(self, *args, **kwargs): + self.__update_metadata_with_token() + super().get_orchestration_state(*args, **kwargs) + + def wait_for_orchestration_start(self, *args, **kwargs): + self.__update_metadata_with_token() + super().wait_for_orchestration_start(*args, **kwargs) + + def wait_for_orchestration_completion(self, *args, **kwargs): + self.__update_metadata_with_token() + super().wait_for_orchestration_completion(*args, **kwargs) + + def raise_orchestration_event(self, *args, **kwargs): + self.__update_metadata_with_token() + super().raise_orchestration_event(*args, **kwargs) + + def terminate_orchestration(self, *args, **kwargs): + self.__update_metadata_with_token() + super().terminate_orchestration(*args, **kwargs) + + def suspend_orchestration(self, *args, **kwargs): + self.__update_metadata_with_token() + super().suspend_orchestration(*args, **kwargs) + + def resume_orchestration(self, *args, **kwargs): + self.__update_metadata_with_token() + super().resume_orchestration(*args, **kwargs) + + def purge_orchestration(self, *args, **kwargs): + self.__update_metadata_with_token() + super().purge_orchestration(*args, **kwargs) \ No newline at end of file diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py index f283661..7f44f67 100644 --- 
a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py @@ -10,7 +10,7 @@ from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager class DurableTaskSchedulerWorker(TaskHubGrpcWorker): - def __init__(self, *args, access_token_manager: AccessTokenManager = None, **kwargs): + def __init__(self, *args, access_token_manager: AccessTokenManager, **kwargs): # Initialize the base class super().__init__(*args, **kwargs) self._access_token_manager = access_token_manager From f8d79d380400cb8b60359b66d3bde14b69544742 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Wed, 29 Jan 2025 12:43:39 -0700 Subject: [PATCH 10/31] Cleaning up construction of dts objects and improving examples Signed-off-by: Ryan Lettieri --- examples/dts/dts_activity_sequence.py | 12 +- examples/dts/dts_fanout_fanin.py | 189 +++++++++--------- .../access_token_manager.py | 22 +- .../durabletask_scheduler_client.py | 48 +++-- .../durabletask_scheduler_worker.py | 49 +++-- 5 files changed, 170 insertions(+), 150 deletions(-) diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index e8d8c55..5c0ad41 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -47,22 +47,14 @@ def sequence(ctx: task.OrchestrationContext, _): exit() -# Define the scope for Azure Resource Manager (ARM) -arm_scope = "https://durabletask.io/.default" -token_manager = AccessTokenManager(scope = arm_scope) - -meta_data: list[tuple[str, str]] = [ - ("taskhub", taskhub_name) -] - # configure and start the worker -with DurableTaskSchedulerWorker(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) as w: +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, client_id="", taskhub=taskhub_name) as w: w.add_orchestrator(sequence) w.add_activity(hello) w.start() # 
Construct the client and run the orchestrations - c = DurableTaskSchedulerClient(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) instance_id = c.schedule_new_orchestration(sequence) state = c.wait_for_orchestration_completion(instance_id, timeout=45) if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: diff --git a/examples/dts/dts_fanout_fanin.py b/examples/dts/dts_fanout_fanin.py index 9c8cc65..d6d2917 100644 --- a/examples/dts/dts_fanout_fanin.py +++ b/examples/dts/dts_fanout_fanin.py @@ -1,99 +1,90 @@ -"""End-to-end sample that demonstrates how to configure an orchestrator -that a dynamic number activity functions in parallel, waits for them all -to complete, and prints an aggregate summary of the outputs.""" -import random -import time -import os -from durabletask import client, task -from durabletask import client, task -from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker -from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient -from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager - - -def get_work_items(ctx: task.ActivityContext, _) -> list[str]: - """Activity function that returns a list of work items""" - # return a random number of work items - count = random.randint(2, 10) - print(f'generating {count} work items...') - return [f'work item {i}' for i in range(count)] - - -def process_work_item(ctx: task.ActivityContext, item: str) -> int: - """Activity function that returns a result for a given work item""" - print(f'processing work item: {item}') - - # simulate some work that takes a variable amount of time - time.sleep(random.random() * 5) - - # return a result for the given work item, which is also a random number in this case - return 
random.randint(0, 10) - - -def orchestrator(ctx: task.OrchestrationContext, _): - """Orchestrator function that calls the 'get_work_items' and 'process_work_item' - activity functions in parallel, waits for them all to complete, and prints - an aggregate summary of the outputs""" - - work_items: list[str] = yield ctx.call_activity(get_work_items) - - # execute the work-items in parallel and wait for them all to return - tasks = [ctx.call_activity(process_work_item, input=item) for item in work_items] - results: list[int] = yield task.when_all(tasks) - - # return an aggregate summary of the results - return { - 'work_items': work_items, - 'results': results, - 'total': sum(results), - } - - -# Read the environment variable -taskhub_name = os.getenv("TASKHUB") - -# Check if the variable exists -if taskhub_name: - print(f"The value of TASKHUB is: {taskhub_name}") -else: - print("TASKHUB is not set. Please set the TASKHUB environment variable to the name of the taskhub you wish to use") - print("If you are using windows powershell, run the following: $env:TASKHUB=\"\"") - print("If you are using bash, run the following: export TASKHUB=\"\"") - exit() - -# Read the environment variable -endpoint = os.getenv("ENDPOINT") - -# Check if the variable exists -if endpoint: - print(f"The value of ENDPOINT is: {endpoint}") -else: - print("ENDPOINT is not set. 
Please set the ENDPOINT environment variable to the endpoint of the scheduler") - print("If you are using windows powershell, run the following: $env:ENDPOINT=\"\"") - print("If you are using bash, run the following: export ENDPOINT=\"\"") - exit() - -# Define the scope for Azure Resource Manager (ARM) -arm_scope = "https://durabletask.io/.default" -token_manager = AccessTokenManager(scope = arm_scope) - -meta_data: list[tuple[str, str]] = [ - ("taskhub", taskhub_name) -] - - -# configure and start the worker -with DurableTaskSchedulerWorker(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) as w: - w.add_orchestrator(orchestrator) - w.add_activity(process_work_item) - w.add_activity(get_work_items) - w.start() - - # create a client, start an orchestration, and wait for it to finish - c = DurableTaskSchedulerClient(host_address=endpoint, metadata=meta_data, secure_channel=True, access_token_manager=token_manager) - instance_id = c.schedule_new_orchestration(orchestrator) - state = c.wait_for_orchestration_completion(instance_id, timeout=30) - if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: - print(f'Orchestration completed! 
Result: {state.serialized_output}') - elif state: - print(f'Orchestration failed: {state.failure_details}') +"""End-to-end sample that demonstrates how to configure an orchestrator +that a dynamic number activity functions in parallel, waits for them all +to complete, and prints an aggregate summary of the outputs.""" +import random +import time +import os +from durabletask import client, task +from durabletask import client, task +from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker +from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient +from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager + + +def get_work_items(ctx: task.ActivityContext, _) -> list[str]: + """Activity function that returns a list of work items""" + # return a random number of work items + count = random.randint(2, 10) + print(f'generating {count} work items...') + return [f'work item {i}' for i in range(count)] + + +def process_work_item(ctx: task.ActivityContext, item: str) -> int: + """Activity function that returns a result for a given work item""" + print(f'processing work item: {item}') + + # simulate some work that takes a variable amount of time + time.sleep(random.random() * 5) + + # return a result for the given work item, which is also a random number in this case + return random.randint(0, 10) + + +def orchestrator(ctx: task.OrchestrationContext, _): + """Orchestrator function that calls the 'get_work_items' and 'process_work_item' + activity functions in parallel, waits for them all to complete, and prints + an aggregate summary of the outputs""" + + work_items: list[str] = yield ctx.call_activity(get_work_items) + + # execute the work-items in parallel and wait for them all to return + tasks = [ctx.call_activity(process_work_item, input=item) for item in work_items] + results: list[int] = yield task.when_all(tasks) + + # return an aggregate 
summary of the results + return { + 'work_items': work_items, + 'results': results, + 'total': sum(results), + } + + +# Read the environment variable +taskhub_name = os.getenv("TASKHUB") + +# Check if the variable exists +if taskhub_name: + print(f"The value of TASKHUB is: {taskhub_name}") +else: + print("TASKHUB is not set. Please set the TASKHUB environment variable to the name of the taskhub you wish to use") + print("If you are using windows powershell, run the following: $env:TASKHUB=\"\"") + print("If you are using bash, run the following: export TASKHUB=\"\"") + exit() + +# Read the environment variable +endpoint = os.getenv("ENDPOINT") + +# Check if the variable exists +if endpoint: + print(f"The value of ENDPOINT is: {endpoint}") +else: + print("ENDPOINT is not set. Please set the ENDPOINT environment variable to the endpoint of the scheduler") + print("If you are using windows powershell, run the following: $env:ENDPOINT=\"\"") + print("If you are using bash, run the following: export ENDPOINT=\"\"") + exit() + +# configure and start the worker +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, client_id="", taskhub=taskhub_name) as w: + w.add_orchestrator(orchestrator) + w.add_activity(process_work_item) + w.add_activity(get_work_items) + w.start() + + # create a client, start an orchestration, and wait for it to finish + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) + instance_id = c.schedule_new_orchestration(orchestrator) + state = c.wait_for_orchestration_completion(instance_id, timeout=30) + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! 
Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') diff --git a/externalpackages/durabletaskscheduler/access_token_manager.py b/externalpackages/durabletaskscheduler/access_token_manager.py index 7628355..bbe4de0 100644 --- a/externalpackages/durabletaskscheduler/access_token_manager.py +++ b/externalpackages/durabletaskscheduler/access_token_manager.py @@ -1,14 +1,26 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - -from azure.identity import DefaultAzureCredential +from azure.identity import DefaultAzureCredential, ManagedIdentityCredential from datetime import datetime, timedelta +from typing import Optional class AccessTokenManager: - def __init__(self, scope: str, refresh_buffer: int = 60): - self.scope = scope + def __init__(self, refresh_buffer: int = 60, use_managed_identity: bool = False, client_id: Optional[str] = None): + self.scope = "https://durabletask.io/.default" self.refresh_buffer = refresh_buffer - self.credential = DefaultAzureCredential() + + # Choose the appropriate credential based on use_managed_identity + if use_managed_identity: + if not client_id: + print("Using System Assigned Managed Identity for authentication.") + self.credential = ManagedIdentityCredential() + else: + print("Using User Assigned Managed Identity for authentication.") + self.credential = ManagedIdentityCredential(client_id) + else: + self.credential = DefaultAzureCredential() + print("Using Default Azure Credentials for authentication.") + self.token = None self.expiry_time = None diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py index 663eb5c..596254f 100644 --- a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py @@ -1,31 +1,43 @@ +from typing import Optional from durabletask.client 
import TaskHubGrpcClient from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager class DurableTaskSchedulerClient(TaskHubGrpcClient): - def __init__(self, *args, access_token_manager: AccessTokenManager, **kwargs): - # Initialize the base class - super().__init__(*args, **kwargs) - self._access_token_manager = access_token_manager + def __init__(self, *args, + metadata: Optional[list[tuple[str, str]]] = None, + client_id: Optional[str] = None, + taskhub: str, + **kwargs): + if metadata is None: + metadata = [] # Ensure metadata is initialized + self._metadata = metadata + self._client_id = client_id + self._metadata.append(("taskhub", taskhub)) + self._access_token_manager = AccessTokenManager(client_id=self._client_id) self.__update_metadata_with_token() + super().__init__(*args, metadata=self._metadata, **kwargs) def __update_metadata_with_token(self): """ Add or update the `authorization` key in the metadata with the current access token. """ - if self._access_token_manager is not None: - token = self._access_token_manager.get_access_token() - - # Check if "authorization" already exists in the metadata - updated = False - for i, (key, _) in enumerate(self._metadata): - if key == "authorization": - self._metadata[i] = ("authorization", token) - updated = True - break - - # If not updated, add a new entry - if not updated: - self._metadata.append(("authorization", token)) + token = self._access_token_manager.get_access_token() + + # Ensure that self._metadata is initialized + if self._metadata is None: + self._metadata = [] # Initialize it if it's still None + + # Check if "authorization" already exists in the metadata + updated = False + for i, (key, _) in enumerate(self._metadata): + if key == "authorization": + self._metadata[i] = ("authorization", token) + updated = True + break + + # If not updated, add a new entry + if not updated: + self._metadata.append(("authorization", token)) def schedule_new_orchestration(self, *args, 
**kwargs) -> str: self.__update_metadata_with_token() diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py index 7f44f67..9a45b36 100644 --- a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py @@ -5,35 +5,48 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared +from typing import Optional from durabletask.worker import TaskHubGrpcWorker from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager class DurableTaskSchedulerWorker(TaskHubGrpcWorker): - def __init__(self, *args, access_token_manager: AccessTokenManager, **kwargs): - # Initialize the base class - super().__init__(*args, **kwargs) - self._access_token_manager = access_token_manager + def __init__(self, *args, + metadata: Optional[list[tuple[str, str]]] = None, + client_id: Optional[str] = None, + taskhub: str, + **kwargs): + if metadata is None: + metadata = [] # Ensure metadata is initialized + self._metadata = metadata + self._client_id = client_id + self._metadata.append(("taskhub", taskhub)) + self._access_token_manager = AccessTokenManager(client_id=self._client_id) self.__update_metadata_with_token() + super().__init__(*args, metadata=self._metadata, **kwargs) + def __update_metadata_with_token(self): """ Add or update the `authorization` key in the metadata with the current access token. 
""" - if self._access_token_manager is not None: - token = self._access_token_manager.get_access_token() - - # Check if "authorization" already exists in the metadata - updated = False - for i, (key, _) in enumerate(self._metadata): - if key == "authorization": - self._metadata[i] = ("authorization", token) - updated = True - break - - # If not updated, add a new entry - if not updated: - self._metadata.append(("authorization", token)) + token = self._access_token_manager.get_access_token() + + # Ensure that self._metadata is initialized + if self._metadata is None: + self._metadata = [] # Initialize it if it's still None + + # Check if "authorization" already exists in the metadata + updated = False + for i, (key, _) in enumerate(self._metadata): + if key == "authorization": + self._metadata[i] = ("authorization", token) + updated = True + break + + # If not updated, add a new entry + if not updated: + self._metadata.append(("authorization", token)) def start(self): """Starts the worker on a background thread and begins listening for work items.""" From 1e676510507c3ae3ee154a2c4e34ecd80f207b84 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Tue, 4 Feb 2025 16:01:51 -0700 Subject: [PATCH 11/31] Migrating shared access token logic to new grpc class Signed-off-by: Ryan Lettieri --- durabletask/internal/shared.py | 8 +- examples/dts/dts_activity_sequence.py | 5 +- examples/dts/dts_fanout_fanin.py | 4 +- .../access_token_manager.py | 32 +++++-- .../durabletask_grpc_interceptor.py | 29 +++++++ .../durabletask_scheduler_client.py | 56 ++++-------- .../durabletask_scheduler_worker.py | 86 ++++--------------- 7 files changed, 94 insertions(+), 126 deletions(-) create mode 100644 externalpackages/durabletaskscheduler/durabletask_grpc_interceptor.py diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index c4f3aa4..0e3ee77 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -10,6 +10,7 @@ import grpc from 
durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl +from externalpackages.durabletaskscheduler.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl # Field name used to indicate that an object was automatically serialized # and should be deserialized as a SimpleNamespace @@ -50,7 +51,12 @@ def get_grpc_channel( channel = grpc.insecure_channel(host_address) if metadata is not None and len(metadata) > 0: - interceptors = [DefaultClientInterceptorImpl(metadata)] + for key, _ in metadata: + # Check if we are using DTS as the backend and if so, construct the DTS specific interceptors + if key == "dts": + interceptors = [DTSDefaultClientInterceptorImpl(metadata)] + else: + interceptors = [DefaultClientInterceptorImpl(metadata)] channel = grpc.intercept_channel(channel, *interceptors) return channel diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index 5c0ad41..8d52089 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -4,7 +4,6 @@ from durabletask import client, task from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient -from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager def hello(ctx: task.ActivityContext, name: str) -> str: """Activity function that returns a greeting""" @@ -48,7 +47,7 @@ def sequence(ctx: task.OrchestrationContext, _): # configure and start the worker -with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, client_id="", taskhub=taskhub_name) as w: +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, use_managed_identity=False, client_id="", taskhub=taskhub_name) as w: w.add_orchestrator(sequence) w.add_activity(hello) w.start() @@ -56,7 +55,7 @@ def sequence(ctx: 
task.OrchestrationContext, _): # Construct the client and run the orchestrations c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) instance_id = c.schedule_new_orchestration(sequence) - state = c.wait_for_orchestration_completion(instance_id, timeout=45) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: print(f'Orchestration completed! Result: {state.serialized_output}') elif state: diff --git a/examples/dts/dts_fanout_fanin.py b/examples/dts/dts_fanout_fanin.py index d6d2917..90cda73 100644 --- a/examples/dts/dts_fanout_fanin.py +++ b/examples/dts/dts_fanout_fanin.py @@ -8,7 +8,6 @@ from durabletask import client, task from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient -from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager def get_work_items(ctx: task.ActivityContext, _) -> list[str]: @@ -74,7 +73,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): exit() # configure and start the worker -with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, client_id="", taskhub=taskhub_name) as w: +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) as w: w.add_orchestrator(orchestrator) w.add_activity(process_work_item) w.add_activity(get_work_items) @@ -88,3 +87,4 @@ def orchestrator(ctx: task.OrchestrationContext, _): print(f'Orchestration completed! 
Result: {state.serialized_output}') elif state: print(f'Orchestration failed: {state.failure_details}') + exit() \ No newline at end of file diff --git a/externalpackages/durabletaskscheduler/access_token_manager.py b/externalpackages/durabletaskscheduler/access_token_manager.py index bbe4de0..d36b7c0 100644 --- a/externalpackages/durabletaskscheduler/access_token_manager.py +++ b/externalpackages/durabletaskscheduler/access_token_manager.py @@ -1,22 +1,33 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. from azure.identity import DefaultAzureCredential, ManagedIdentityCredential -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from typing import Optional +# By default, when there's 10minutes left before the token expires, refresh the token class AccessTokenManager: - def __init__(self, refresh_buffer: int = 60, use_managed_identity: bool = False, client_id: Optional[str] = None): + def __init__(self, refresh_buffer: int = 600, metadata: Optional[list[tuple[str, str]]] = None): self.scope = "https://durabletask.io/.default" self.refresh_buffer = refresh_buffer - + self._use_managed_identity = False + self._metadata = metadata + self._client_id = None + + if metadata: # Ensure metadata is not None + for key, value in metadata: + if key == "use_managed_identity": + self._use_managed_identity = value.lower() == "true" # Properly convert string to bool + elif key == "client_id": + self._client_id = value # Directly assign string + # Choose the appropriate credential based on use_managed_identity - if use_managed_identity: - if not client_id: + if self._use_managed_identity: + if not self._client_id: print("Using System Assigned Managed Identity for authentication.") self.credential = ManagedIdentityCredential() else: print("Using User Assigned Managed Identity for authentication.") - self.credential = ManagedIdentityCredential(client_id) + self.credential = 
ManagedIdentityCredential(client_id=self._client_id) else: self.credential = DefaultAzureCredential() print("Using Default Azure Credentials for authentication.") @@ -29,13 +40,18 @@ def get_access_token(self) -> str: self.refresh_token() return self.token + # Checks if the token is expired, or if it will expire in the next "refresh_buffer" seconds. + # For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes, + # We will grab a new token when there're 30minutes left on the lifespan of the token def is_token_expired(self) -> bool: if self.expiry_time is None: return True - return datetime.utcnow() >= (self.expiry_time - timedelta(seconds=self.refresh_buffer)) + return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_buffer)) def refresh_token(self): new_token = self.credential.get_token(self.scope) self.token = f"Bearer {new_token.token}" - self.expiry_time = datetime.utcnow() + timedelta(seconds=new_token.expires_on - int(datetime.utcnow().timestamp())) + + # Convert UNIX timestamp to timezone-aware datetime + self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc) print(f"Token refreshed. Expires at: {self.expiry_time}") \ No newline at end of file diff --git a/externalpackages/durabletaskscheduler/durabletask_grpc_interceptor.py b/externalpackages/durabletaskscheduler/durabletask_grpc_interceptor.py new file mode 100644 index 0000000..39c84d6 --- /dev/null +++ b/externalpackages/durabletaskscheduler/durabletask_grpc_interceptor.py @@ -0,0 +1,29 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl +from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager + +import grpc + +class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): + """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an + interceptor to add additional headers to all calls as needed.""" + + def __init__(self, metadata: list[tuple[str, str]]): + super().__init__(metadata) + self._token_manager = AccessTokenManager(metadata=self._metadata) + + def _intercept_call( + self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: + """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC + call details.""" + # Refresh the auth token if it is present and needed + if self._metadata is not None: + for i, (key, _) in enumerate(self._metadata): + if key.lower() == "authorization": # Ensure case-insensitive comparison + new_token = self._token_manager.get_access_token() # Get the new token + self._metadata[i] = ("authorization", new_token) # Update the token + + return super()._intercept_call(client_call_details) diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py index 596254f..972eab3 100644 --- a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py @@ -1,21 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ from typing import Optional from durabletask.client import TaskHubGrpcClient from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager class DurableTaskSchedulerClient(TaskHubGrpcClient): - def __init__(self, *args, + def __init__(self, + host_address: str, + secure_channel: bool, metadata: Optional[list[tuple[str, str]]] = None, + use_managed_identity: Optional[bool] = False, client_id: Optional[str] = None, - taskhub: str, + taskhub: str = None, **kwargs): if metadata is None: metadata = [] # Ensure metadata is initialized self._metadata = metadata + self._use_managed_identity = use_managed_identity self._client_id = client_id self._metadata.append(("taskhub", taskhub)) - self._access_token_manager = AccessTokenManager(client_id=self._client_id) + self._metadata.append(("dts", "True")) + self._metadata.append(("use_managed_identity", str(use_managed_identity))) + self._metadata.append(("client_id", str(client_id))) + self._access_token_manager = AccessTokenManager(metadata=self._metadata) self.__update_metadata_with_token() - super().__init__(*args, metadata=self._metadata, **kwargs) + super().__init__(host_address=host_address, secure_channel=secure_channel, metadata=self._metadata, **kwargs) def __update_metadata_with_token(self): """ @@ -37,40 +47,4 @@ def __update_metadata_with_token(self): # If not updated, add a new entry if not updated: - self._metadata.append(("authorization", token)) - - def schedule_new_orchestration(self, *args, **kwargs) -> str: - self.__update_metadata_with_token() - return super().schedule_new_orchestration(*args, **kwargs) - - def get_orchestration_state(self, *args, **kwargs): - self.__update_metadata_with_token() - super().get_orchestration_state(*args, **kwargs) - - def wait_for_orchestration_start(self, *args, **kwargs): - self.__update_metadata_with_token() - super().wait_for_orchestration_start(*args, **kwargs) - - def wait_for_orchestration_completion(self, *args, **kwargs): - 
self.__update_metadata_with_token() - super().wait_for_orchestration_completion(*args, **kwargs) - - def raise_orchestration_event(self, *args, **kwargs): - self.__update_metadata_with_token() - super().raise_orchestration_event(*args, **kwargs) - - def terminate_orchestration(self, *args, **kwargs): - self.__update_metadata_with_token() - super().terminate_orchestration(*args, **kwargs) - - def suspend_orchestration(self, *args, **kwargs): - self.__update_metadata_with_token() - super().suspend_orchestration(*args, **kwargs) - - def resume_orchestration(self, *args, **kwargs): - self.__update_metadata_with_token() - super().resume_orchestration(*args, **kwargs) - - def purge_orchestration(self, *args, **kwargs): - self.__update_metadata_with_token() - super().purge_orchestration(*args, **kwargs) \ No newline at end of file + self._metadata.append(("authorization", token)) \ No newline at end of file diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py index 9a45b36..5ba1afe 100644 --- a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py @@ -1,30 +1,32 @@ -import concurrent.futures -from threading import Thread -from google.protobuf import empty_pb2 -import grpc -import durabletask.internal.orchestrator_service_pb2 as pb -import durabletask.internal.orchestrator_service_pb2_grpc as stubs -import durabletask.internal.shared as shared -from typing import Optional +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+from typing import Optional from durabletask.worker import TaskHubGrpcWorker from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager +# Worker class used for Durable Task Scheduler (DTS) class DurableTaskSchedulerWorker(TaskHubGrpcWorker): - def __init__(self, *args, + def __init__(self, + host_address: str, + secure_channel: bool, metadata: Optional[list[tuple[str, str]]] = None, + use_managed_identity: Optional[bool] = False, client_id: Optional[str] = None, - taskhub: str, + taskhub: str = None, **kwargs): if metadata is None: metadata = [] # Ensure metadata is initialized self._metadata = metadata + self._use_managed_identity = use_managed_identity self._client_id = client_id self._metadata.append(("taskhub", taskhub)) - self._access_token_manager = AccessTokenManager(client_id=self._client_id) + self._metadata.append(("dts", "True")) + self._metadata.append(("use_managed_identity", str(use_managed_identity))) + self._metadata.append(("client_id", str(client_id))) + self._access_token_manager = AccessTokenManager(metadata=self._metadata) self.__update_metadata_with_token() - super().__init__(*args, metadata=self._metadata, **kwargs) - + super().__init__(host_address=host_address, secure_channel=secure_channel, metadata=self._metadata, **kwargs) def __update_metadata_with_token(self): """ @@ -47,61 +49,3 @@ def __update_metadata_with_token(self): # If not updated, add a new entry if not updated: self._metadata.append(("authorization", token)) - - def start(self): - """Starts the worker on a background thread and begins listening for work items.""" - channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel) - stub = stubs.TaskHubSidecarServiceStub(channel) - - if self._is_running: - raise RuntimeError('The worker is already running.') - - def run_loop(): - # TODO: Investigate whether asyncio could be used to enable greater concurrency for async activity - # functions. 
We'd need to know ahead of time whether a function is async or not. - # TODO: Max concurrency configuration settings - with concurrent.futures.ThreadPoolExecutor(max_workers=16) as executor: - while not self._shutdown.is_set(): - try: - self.__update_metadata_with_token() - # send a "Hello" message to the sidecar to ensure that it's listening - stub.Hello(empty_pb2.Empty()) - - # stream work items - self._response_stream = stub.GetWorkItems(pb.GetWorkItemsRequest()) - self._logger.info(f'Successfully connected to {self._host_address}. Waiting for work items...') - - # The stream blocks until either a work item is received or the stream is canceled - # by another thread (see the stop() method). - for work_item in self._response_stream: # type: ignore - request_type = work_item.WhichOneof('request') - self._logger.debug(f'Received "{request_type}" work item') - if work_item.HasField('orchestratorRequest'): - executor.submit(self._execute_orchestrator, work_item.orchestratorRequest, stub, work_item.completionToken) - elif work_item.HasField('activityRequest'): - executor.submit(self._execute_activity, work_item.activityRequest, stub, work_item.completionToken) - elif work_item.HasField('healthPing'): - pass # no-op - else: - self._logger.warning(f'Unexpected work item type: {request_type}') - - except grpc.RpcError as rpc_error: - if rpc_error.code() == grpc.StatusCode.CANCELLED: # type: ignore - self._logger.info(f'Disconnected from {self._host_address}') - elif rpc_error.code() == grpc.StatusCode.UNAVAILABLE: # type: ignore - self._logger.warning( - f'The sidecar at address {self._host_address} is unavailable - will continue retrying') - else: - self._logger.warning(f'Unexpected error: {rpc_error}') - except Exception as ex: - self._logger.warning(f'Unexpected error: {ex}') - - # CONSIDER: exponential backoff - self._shutdown.wait(5) - self._logger.info("No longer listening for work items") - return - - self._logger.info(f"Starting gRPC worker that connects to 
{self._host_address}") - self._runLoop = Thread(target=run_loop) - self._runLoop.start() - self._is_running = True From 6b1bfd2e13aba73621f73be2f37473e09ce34c51 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Wed, 5 Feb 2025 11:01:56 -0700 Subject: [PATCH 12/31] Adding log statements to access_token_manager Signed-off-by: Ryan Lettieri --- .../durabletaskscheduler/access_token_manager.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/externalpackages/durabletaskscheduler/access_token_manager.py b/externalpackages/durabletaskscheduler/access_token_manager.py index d36b7c0..63fd69e 100644 --- a/externalpackages/durabletaskscheduler/access_token_manager.py +++ b/externalpackages/durabletaskscheduler/access_token_manager.py @@ -3,6 +3,7 @@ from azure.identity import DefaultAzureCredential, ManagedIdentityCredential from datetime import datetime, timedelta, timezone from typing import Optional +import durabletask.internal.shared as shared # By default, when there's 10minutes left before the token expires, refresh the token class AccessTokenManager: @@ -12,6 +13,7 @@ def __init__(self, refresh_buffer: int = 600, metadata: Optional[list[tuple[str, self._use_managed_identity = False self._metadata = metadata self._client_id = None + self._logger = shared.get_logger("token_manager") if metadata: # Ensure metadata is not None for key, value in metadata: @@ -23,14 +25,14 @@ def __init__(self, refresh_buffer: int = 600, metadata: Optional[list[tuple[str, # Choose the appropriate credential based on use_managed_identity if self._use_managed_identity: if not self._client_id: - print("Using System Assigned Managed Identity for authentication.") + self._logger.debug("Using System Assigned Managed Identity for authentication.") self.credential = ManagedIdentityCredential() else: - print("Using User Assigned Managed Identity for authentication.") + self._logger.debug("Using User Assigned Managed Identity for authentication.") self.credential = 
ManagedIdentityCredential(client_id=self._client_id) else: self.credential = DefaultAzureCredential() - print("Using Default Azure Credentials for authentication.") + self._logger.debug("Using Default Azure Credentials for authentication.") self.token = None self.expiry_time = None @@ -54,4 +56,4 @@ def refresh_token(self): # Convert UNIX timestamp to timezone-aware datetime self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc) - print(f"Token refreshed. Expires at: {self.expiry_time}") \ No newline at end of file + self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}") \ No newline at end of file From bd56a35090c36467a8f95d06e2bde33e706f32a9 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Wed, 5 Feb 2025 12:42:17 -0700 Subject: [PATCH 13/31] breaking for loop when setting interceptors Signed-off-by: Ryan Lettieri --- durabletask/internal/shared.py | 1 + 1 file changed, 1 insertion(+) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 0e3ee77..34d5bf5 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -55,6 +55,7 @@ def get_grpc_channel( # Check if we are using DTS as the backend and if so, construct the DTS specific interceptors if key == "dts": interceptors = [DTSDefaultClientInterceptorImpl(metadata)] + break else: interceptors = [DefaultClientInterceptorImpl(metadata)] channel = grpc.intercept_channel(channel, *interceptors) From efc01463401dde2ef16cc5b7c05e89f85a35b279 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Fri, 7 Feb 2025 10:57:30 -0700 Subject: [PATCH 14/31] Removing changes to client.py and adding additional steps to readme.md Signed-off-by: Ryan Lettieri --- durabletask/client.py | 1 - examples/dts/README.md | 6 ++++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/durabletask/client.py b/durabletask/client.py index 74e51f5..31953ae 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -97,7 +97,6 @@ def 
__init__(self, *, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False): - self._metadata = metadata channel = shared.get_grpc_channel(host_address, metadata, secure_channel=secure_channel) self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, log_formatter) diff --git a/examples/dts/README.md b/examples/dts/README.md index f349cf2..feefa05 100644 --- a/examples/dts/README.md +++ b/examples/dts/README.md @@ -32,6 +32,12 @@ export ENDPOINT= python3 -m pip install azure-identity ``` +6. Install the correct pacakges from the top level of this repository, i.e. durabletask-python/ + +```sh +python3 -m pip install . +``` + ## Running the examples With one of the sidecars running, you can simply execute any of the examples in this directory using `python3`: From 3fd0b089ee77f76187b915a5becda25ccb17ae84 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Tue, 11 Feb 2025 11:18:38 -0700 Subject: [PATCH 15/31] Refactoring client and worker to pass around interceptors Signed-off-by: Ryan Lettieri --- durabletask/client.py | 21 +++++++++++-- durabletask/internal/shared.py | 20 +++++-------- durabletask/worker.py | 23 ++++++++++++-- .../durabletask_scheduler_client.py | 30 +++++++++++++------ .../durabletask_scheduler_worker.py | 27 ++++++++++++----- 5 files changed, 86 insertions(+), 35 deletions(-) diff --git a/durabletask/client.py b/durabletask/client.py index 31953ae..0820a72 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -15,6 +15,8 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl + from durabletask import task TInput = TypeVar('TInput') @@ -96,8 +98,23 @@ def __init__(self, *, metadata: Optional[list[tuple[str, str]]] = 
None, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False): - channel = shared.get_grpc_channel(host_address, metadata, secure_channel=secure_channel) + secure_channel: bool = False, + interceptors: Optional[list] = None): + + # Determine the interceptors to use + if interceptors is not None: + self._interceptors = interceptors + elif metadata: + self._interceptors = [DefaultClientInterceptorImpl(metadata)] + else: + self._interceptors = None + + channel = shared.get_grpc_channel( + host_address=host_address, + metadata=metadata, + secure_channel=secure_channel, + interceptors=self._interceptors + ) self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, log_formatter) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 34d5bf5..d57fbcc 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -9,9 +9,6 @@ import grpc -from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl -from externalpackages.durabletaskscheduler.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl - # Field name used to indicate that an object was automatically serialized # and should be deserialized as a SimpleNamespace AUTO_SERIALIZED = "__durabletask_autoobject__" @@ -26,8 +23,10 @@ def get_default_host_address() -> str: def get_grpc_channel( host_address: Optional[str], - metadata: Optional[list[tuple[str, str]]], - secure_channel: bool = False) -> grpc.Channel: + metadata: Optional[list[tuple[str, str]]] = None, + secure_channel: bool = False, + interceptors: Optional[list] = None) -> grpc.Channel: + if host_address is None: host_address = get_default_host_address() @@ -45,19 +44,14 @@ def get_grpc_channel( host_address = host_address[len(protocol):] break + # Create the base channel if secure_channel: channel = grpc.secure_channel(host_address, 
grpc.ssl_channel_credentials()) else: channel = grpc.insecure_channel(host_address) - if metadata is not None and len(metadata) > 0: - for key, _ in metadata: - # Check if we are using DTS as the backend and if so, construct the DTS specific interceptors - if key == "dts": - interceptors = [DTSDefaultClientInterceptorImpl(metadata)] - break - else: - interceptors = [DefaultClientInterceptorImpl(metadata)] + # Apply interceptors ONLY if they exist + if interceptors: channel = grpc.intercept_channel(channel, *interceptors) return channel diff --git a/durabletask/worker.py b/durabletask/worker.py index 51a62fd..6dce0e4 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -16,7 +16,9 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared + from durabletask import task +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') @@ -88,15 +90,25 @@ def __init__(self, *, metadata: Optional[list[tuple[str, str]]] = None, log_handler=None, log_formatter: Optional[logging.Formatter] = None, - secure_channel: bool = False): + secure_channel: bool = False, + interceptors: Optional[list[grpc.ServerInterceptor]] = None): # Add interceptors self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() - self._metadata = metadata + self._metadata = metadata or [] # Ensure metadata is never None self._logger = shared.get_logger("worker", log_handler, log_formatter) self._shutdown = Event() self._is_running = False self._secure_channel = secure_channel + # Determine the interceptors to use + if interceptors is not None: + self._interceptors = interceptors + elif self._metadata: + self._interceptors = [DefaultClientInterceptorImpl(self._metadata)] + else: + self._interceptors = None + + def __enter__(self): return self @@ 
-117,7 +129,12 @@ def add_activity(self, fn: task.Activity) -> str: def start(self): """Starts the worker on a background thread and begins listening for work items.""" - channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel) + + if self._metadata: + interceptors = [DefaultClientInterceptorImpl(self._metadata)] + else: + interceptors = None + channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel, interceptors) stub = stubs.TaskHubSidecarServiceStub(channel) if self._is_running: diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py index 972eab3..38e9278 100644 --- a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py @@ -4,28 +4,40 @@ from typing import Optional from durabletask.client import TaskHubGrpcClient from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager +from externalpackages.durabletaskscheduler.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl +# Client class used for Durable Task Scheduler (DTS) class DurableTaskSchedulerClient(TaskHubGrpcClient): def __init__(self, host_address: str, secure_channel: bool, - metadata: Optional[list[tuple[str, str]]] = None, + metadata: Optional[list[tuple[str, str]]] = [], use_managed_identity: Optional[bool] = False, client_id: Optional[str] = None, taskhub: str = None, **kwargs): - if metadata is None: - metadata = [] # Ensure metadata is initialized - self._metadata = metadata - self._use_managed_identity = use_managed_identity - self._client_id = client_id - self._metadata.append(("taskhub", taskhub)) + + # Ensure metadata is a list + metadata = metadata or [] + self._metadata = metadata.copy() # Use a copy to avoid modifying original + + # Append DurableTask-specific metadata + 
self._metadata.append(("taskhub", taskhub or "default-taskhub")) self._metadata.append(("dts", "True")) self._metadata.append(("use_managed_identity", str(use_managed_identity))) - self._metadata.append(("client_id", str(client_id))) + self._metadata.append(("client_id", str(client_id or "None"))) + self._access_token_manager = AccessTokenManager(metadata=self._metadata) self.__update_metadata_with_token() - super().__init__(host_address=host_address, secure_channel=secure_channel, metadata=self._metadata, **kwargs) + interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] + + super().__init__( + host_address=host_address, + secure_channel=secure_channel, + metadata=self._metadata, + interceptors=interceptors, # Now explicitly passing interceptors + **kwargs + ) def __update_metadata_with_token(self): """ diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py index 5ba1afe..9d830ee 100644 --- a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py @@ -4,6 +4,7 @@ from typing import Optional from durabletask.worker import TaskHubGrpcWorker from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager +from externalpackages.durabletaskscheduler.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl # Worker class used for Durable Task Scheduler (DTS) class DurableTaskSchedulerWorker(TaskHubGrpcWorker): @@ -15,18 +16,28 @@ def __init__(self, client_id: Optional[str] = None, taskhub: str = None, **kwargs): - if metadata is None: - metadata = [] # Ensure metadata is initialized - self._metadata = metadata - self._use_managed_identity = use_managed_identity - self._client_id = client_id - self._metadata.append(("taskhub", taskhub)) + + # Ensure metadata is a list + metadata = metadata or [] + self._metadata = metadata.copy() # Copy 
to prevent modifying input + + # Append DurableTask-specific metadata + self._metadata.append(("taskhub", taskhub or "default-taskhub")) self._metadata.append(("dts", "True")) self._metadata.append(("use_managed_identity", str(use_managed_identity))) - self._metadata.append(("client_id", str(client_id))) + self._metadata.append(("client_id", str(client_id or "None"))) + self._access_token_manager = AccessTokenManager(metadata=self._metadata) self.__update_metadata_with_token() - super().__init__(host_address=host_address, secure_channel=secure_channel, metadata=self._metadata, **kwargs) + interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] + + super().__init__( + host_address=host_address, + secure_channel=secure_channel, + metadata=self._metadata, + interceptors=interceptors, + **kwargs + ) def __update_metadata_with_token(self): """ From 4260d025c0415f6a0a7ec3c5d48b435c0a4e85a7 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Tue, 11 Feb 2025 11:28:04 -0700 Subject: [PATCH 16/31] Fixing import for DefaultClientInterceptorImpl Signed-off-by: Ryan Lettieri --- tests/test_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_client.py b/tests/test_client.py index caacf65..b3056d2 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -1,8 +1,8 @@ from unittest.mock import patch, ANY -from durabletask.internal.shared import (DefaultClientInterceptorImpl, - get_default_host_address, +from durabletask.internal.shared import (get_default_host_address, get_grpc_channel) +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl HOST_ADDRESS = 'localhost:50051' METADATA = [('key1', 'value1'), ('key2', 'value2')] From ec4617c08f25b7b24dcc9c0bd91ee41c7749f83b Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Tue, 11 Feb 2025 14:12:40 -0700 Subject: [PATCH 17/31] Adressing round 1 of feedback Signed-off-by: Ryan Lettieri --- durabletask/client.py | 2 +- durabletask/worker.py | 11 +++-------- 2 
files changed, 4 insertions(+), 9 deletions(-) diff --git a/durabletask/client.py b/durabletask/client.py index 0820a72..8b3ac94 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -99,7 +99,7 @@ def __init__(self, *, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False, - interceptors: Optional[list] = None): + interceptors: Optional[list[DefaultClientInterceptorImpl]] = None): # Determine the interceptors to use if interceptors is not None: diff --git a/durabletask/worker.py b/durabletask/worker.py index 6dce0e4..d06fb02 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -91,10 +91,10 @@ def __init__(self, *, log_handler=None, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False, - interceptors: Optional[list[grpc.ServerInterceptor]] = None): # Add interceptors + interceptors: Optional[list[DefaultClientInterceptorImpl]] = None): # Add interceptors self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() - self._metadata = metadata or [] # Ensure metadata is never None + self._metadata = metadata self._logger = shared.get_logger("worker", log_handler, log_formatter) self._shutdown = Event() self._is_running = False @@ -129,12 +129,7 @@ def add_activity(self, fn: task.Activity) -> str: def start(self): """Starts the worker on a background thread and begins listening for work items.""" - - if self._metadata: - interceptors = [DefaultClientInterceptorImpl(self._metadata)] - else: - interceptors = None - channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel, interceptors) + channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel, self._interceptors) stub = stubs.TaskHubSidecarServiceStub(channel) if self._is_running: From ed733ea7eaf6ab43d23fd3dfbd7c1b7ab1a0f04c Mon Sep 17 00:00:00 2001 From: Ryan Lettieri 
Date: Tue, 11 Feb 2025 20:16:42 -0700 Subject: [PATCH 18/31] Fixing interceptor issue Signed-off-by: Ryan Lettieri --- durabletask/client.py | 2 ++ durabletask/internal/shared.py | 3 ++- durabletask/worker.py | 2 ++ .../durabletask_scheduler_client.py | 10 ++++++---- .../durabletask_scheduler_worker.py | 4 +++- 5 files changed, 15 insertions(+), 6 deletions(-) diff --git a/durabletask/client.py b/durabletask/client.py index 8b3ac94..55a30e4 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -104,6 +104,8 @@ def __init__(self, *, # Determine the interceptors to use if interceptors is not None: self._interceptors = interceptors + if metadata: + self._interceptors.append(DefaultClientInterceptorImpl(metadata)) elif metadata: self._interceptors = [DefaultClientInterceptorImpl(metadata)] else: diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index d57fbcc..6327796 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -6,6 +6,7 @@ import logging from types import SimpleNamespace from typing import Any, Optional +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl import grpc @@ -25,7 +26,7 @@ def get_grpc_channel( host_address: Optional[str], metadata: Optional[list[tuple[str, str]]] = None, secure_channel: bool = False, - interceptors: Optional[list] = None) -> grpc.Channel: + interceptors: Optional[list[DefaultClientInterceptorImpl]] = None) -> grpc.Channel: if host_address is None: host_address = get_default_host_address() diff --git a/durabletask/worker.py b/durabletask/worker.py index d06fb02..e67e5ca 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -103,6 +103,8 @@ def __init__(self, *, # Determine the interceptors to use if interceptors is not None: self._interceptors = interceptors + if metadata: + self._interceptors.append(DefaultClientInterceptorImpl(metadata)) elif self._metadata: self._interceptors = 
[DefaultClientInterceptorImpl(self._metadata)] else: diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py index 38e9278..f42a8d8 100644 --- a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py @@ -11,7 +11,7 @@ class DurableTaskSchedulerClient(TaskHubGrpcClient): def __init__(self, host_address: str, secure_channel: bool, - metadata: Optional[list[tuple[str, str]]] = [], + metadata: Optional[list[tuple[str, str]]] = None, use_managed_identity: Optional[bool] = False, client_id: Optional[str] = None, taskhub: str = None, @@ -29,13 +29,15 @@ def __init__(self, self._access_token_manager = AccessTokenManager(metadata=self._metadata) self.__update_metadata_with_token() - interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] + self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] + # We pass in None for the metadata so we don't construct an additional interceptor in the parent class + # Since the parent class doesn't use anything metadata for anything else, we can set it as None super().__init__( host_address=host_address, secure_channel=secure_channel, - metadata=self._metadata, - interceptors=interceptors, # Now explicitly passing interceptors + metadata=None, + interceptors=self._interceptors, **kwargs ) diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py index 9d830ee..f6bd184 100644 --- a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py +++ b/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py @@ -31,10 +31,12 @@ def __init__(self, self.__update_metadata_with_token() interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] + # We pass in None for the metadata so we don't construct an additional 
interceptor in the parent class + # Since the parent class doesn't use anything metadata for anything else, we can set it as None super().__init__( host_address=host_address, secure_channel=secure_channel, - metadata=self._metadata, + metadata=None, interceptors=interceptors, **kwargs ) From 99f62d7453e66486599bd02651270eaad7bd744c Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Wed, 12 Feb 2025 13:06:43 -0700 Subject: [PATCH 19/31] Moving some files around to remove dependencies Signed-off-by: Ryan Lettieri --- durabletask-azuremanaged/__init__.py | 0 .../durabletask/azuremanaged/__init__.py | 0 .../azuremanaged}/access_token_manager.py | 116 ++++++++-------- .../durabletask_grpc_interceptor.py | 58 ++++---- .../durabletask_keep_alive_service.py | 28 ++++ .../durabletask_scheduler_client.py | 126 ++++++++--------- .../durabletask_scheduler_worker.py | 128 +++++++++--------- durabletask-azuremanaged/pyproject.toml | 41 ++++++ durabletask/client.py | 2 +- examples/dts/README.md | 8 +- examples/dts/dts_activity_sequence.py | 4 +- examples/dts/dts_fanout_fanin.py | 4 +- .../durabletaskscheduler/__init__.py | 7 - requirements.txt | 3 +- 14 files changed, 297 insertions(+), 228 deletions(-) create mode 100644 durabletask-azuremanaged/__init__.py create mode 100644 durabletask-azuremanaged/durabletask/azuremanaged/__init__.py rename {externalpackages/durabletaskscheduler => durabletask-azuremanaged/durabletask/azuremanaged}/access_token_manager.py (98%) rename {externalpackages/durabletaskscheduler => durabletask-azuremanaged/durabletask/azuremanaged}/durabletask_grpc_interceptor.py (92%) create mode 100644 durabletask-azuremanaged/durabletask/azuremanaged/durabletask_keep_alive_service.py rename {externalpackages/durabletaskscheduler => durabletask-azuremanaged/durabletask/azuremanaged}/durabletask_scheduler_client.py (90%) rename {externalpackages/durabletaskscheduler => durabletask-azuremanaged/durabletask/azuremanaged}/durabletask_scheduler_worker.py (90%) create 
mode 100644 durabletask-azuremanaged/pyproject.toml delete mode 100644 externalpackages/durabletaskscheduler/__init__.py diff --git a/durabletask-azuremanaged/__init__.py b/durabletask-azuremanaged/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/__init__.py b/durabletask-azuremanaged/durabletask/azuremanaged/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/externalpackages/durabletaskscheduler/access_token_manager.py b/durabletask-azuremanaged/durabletask/azuremanaged/access_token_manager.py similarity index 98% rename from externalpackages/durabletaskscheduler/access_token_manager.py rename to durabletask-azuremanaged/durabletask/azuremanaged/access_token_manager.py index 63fd69e..43e86dc 100644 --- a/externalpackages/durabletaskscheduler/access_token_manager.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/access_token_manager.py @@ -1,59 +1,59 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
-from azure.identity import DefaultAzureCredential, ManagedIdentityCredential -from datetime import datetime, timedelta, timezone -from typing import Optional -import durabletask.internal.shared as shared - -# By default, when there's 10minutes left before the token expires, refresh the token -class AccessTokenManager: - def __init__(self, refresh_buffer: int = 600, metadata: Optional[list[tuple[str, str]]] = None): - self.scope = "https://durabletask.io/.default" - self.refresh_buffer = refresh_buffer - self._use_managed_identity = False - self._metadata = metadata - self._client_id = None - self._logger = shared.get_logger("token_manager") - - if metadata: # Ensure metadata is not None - for key, value in metadata: - if key == "use_managed_identity": - self._use_managed_identity = value.lower() == "true" # Properly convert string to bool - elif key == "client_id": - self._client_id = value # Directly assign string - - # Choose the appropriate credential based on use_managed_identity - if self._use_managed_identity: - if not self._client_id: - self._logger.debug("Using System Assigned Managed Identity for authentication.") - self.credential = ManagedIdentityCredential() - else: - self._logger.debug("Using User Assigned Managed Identity for authentication.") - self.credential = ManagedIdentityCredential(client_id=self._client_id) - else: - self.credential = DefaultAzureCredential() - self._logger.debug("Using Default Azure Credentials for authentication.") - - self.token = None - self.expiry_time = None - - def get_access_token(self) -> str: - if self.token is None or self.is_token_expired(): - self.refresh_token() - return self.token - - # Checks if the token is expired, or if it will expire in the next "refresh_buffer" seconds. 
- # For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes, - # We will grab a new token when there're 30minutes left on the lifespan of the token - def is_token_expired(self) -> bool: - if self.expiry_time is None: - return True - return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_buffer)) - - def refresh_token(self): - new_token = self.credential.get_token(self.scope) - self.token = f"Bearer {new_token.token}" - - # Convert UNIX timestamp to timezone-aware datetime - self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc) +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from azure.identity import DefaultAzureCredential, ManagedIdentityCredential +from datetime import datetime, timedelta, timezone +from typing import Optional +import durabletask.internal.shared as shared + +# By default, when there's 10minutes left before the token expires, refresh the token +class AccessTokenManager: + def __init__(self, refresh_buffer: int = 600, metadata: Optional[list[tuple[str, str]]] = None): + self.scope = "https://durabletask.io/.default" + self.refresh_buffer = refresh_buffer + self._use_managed_identity = False + self._metadata = metadata + self._client_id = None + self._logger = shared.get_logger("token_manager") + + if metadata: # Ensure metadata is not None + for key, value in metadata: + if key == "use_managed_identity": + self._use_managed_identity = value.lower() == "true" # Properly convert string to bool + elif key == "client_id": + self._client_id = value # Directly assign string + + # Choose the appropriate credential based on use_managed_identity + if self._use_managed_identity: + if not self._client_id: + self._logger.debug("Using System Assigned Managed Identity for authentication.") + self.credential = ManagedIdentityCredential() + else: + self._logger.debug("Using User Assigned Managed Identity for 
authentication.") + self.credential = ManagedIdentityCredential(client_id=self._client_id) + else: + self.credential = DefaultAzureCredential() + self._logger.debug("Using Default Azure Credentials for authentication.") + + self.token = None + self.expiry_time = None + + def get_access_token(self) -> str: + if self.token is None or self.is_token_expired(): + self.refresh_token() + return self.token + + # Checks if the token is expired, or if it will expire in the next "refresh_buffer" seconds. + # For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes, + # We will grab a new token when there're 30minutes left on the lifespan of the token + def is_token_expired(self) -> bool: + if self.expiry_time is None: + return True + return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_buffer)) + + def refresh_token(self): + new_token = self.credential.get_token(self.scope) + self.token = f"Bearer {new_token.token}" + + # Convert UNIX timestamp to timezone-aware datetime + self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc) self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}") \ No newline at end of file diff --git a/externalpackages/durabletaskscheduler/durabletask_grpc_interceptor.py b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py similarity index 92% rename from externalpackages/durabletaskscheduler/durabletask_grpc_interceptor.py rename to durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py index 39c84d6..97897d9 100644 --- a/externalpackages/durabletaskscheduler/durabletask_grpc_interceptor.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py @@ -1,29 +1,29 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl -from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager - -import grpc - -class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): - """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an - interceptor to add additional headers to all calls as needed.""" - - def __init__(self, metadata: list[tuple[str, str]]): - super().__init__(metadata) - self._token_manager = AccessTokenManager(metadata=self._metadata) - - def _intercept_call( - self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: - """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC - call details.""" - # Refresh the auth token if it is present and needed - if self._metadata is not None: - for i, (key, _) in enumerate(self._metadata): - if key.lower() == "authorization": # Ensure case-insensitive comparison - new_token = self._token_manager.get_access_token() # Get the new token - self._metadata[i] = ("authorization", new_token) # Update the token - - return super()._intercept_call(client_call_details) +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl +from durabletask.azuremanaged.access_token_manager import AccessTokenManager + +import grpc + +class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): + """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an + interceptor to add additional headers to all calls as needed.""" + + def __init__(self, metadata: list[tuple[str, str]]): + super().__init__(metadata) + self._token_manager = AccessTokenManager(metadata=self._metadata) + + def _intercept_call( + self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: + """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC + call details.""" + # Refresh the auth token if it is present and needed + if self._metadata is not None: + for i, (key, _) in enumerate(self._metadata): + if key.lower() == "authorization": # Ensure case-insensitive comparison + new_token = self._token_manager.get_access_token() # Get the new token + self._metadata[i] = ("authorization", new_token) # Update the token + + return super()._intercept_call(client_call_details) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_keep_alive_service.py b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_keep_alive_service.py new file mode 100644 index 0000000..0720bc9 --- /dev/null +++ b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_keep_alive_service.py @@ -0,0 +1,28 @@ +import threading +import time +import requests # You could use grpc or another library depending on your setup + +class KeepAliveService: + def __init__(self, interval: int = 60, endpoint: str = "https://sdktest1-fgcac9hja3f8.northcentralus.durabletask.io"): + self.interval = interval # Time interval in seconds + self.endpoint = endpoint # The endpoint for sending 
no-op requests + self._keep_alive_thread = threading.Thread(target=self._send_noop_periodically) + self._keep_alive_thread.daemon = True # Makes sure it ends when the main program ends + self._keep_alive_thread.start() + + def _send_noop_periodically(self): + while True: + try: + # Send a simple GET or POST request to a "ping" or no-op endpoint + response = requests.get(self.endpoint) # Replace with the appropriate method + if response.status_code == 200: + print("No-op request sent successfully.") + else: + print(f"No-op failed with status code {response.status_code}") + except Exception as e: + print(f"Error sending no-op: {e}") + + time.sleep(self.interval) # Wait before sending another no-op + +# Example Usage +keep_alive_service = KeepAliveService(interval=60, endpoint="https://sdktest1-fgcac9hja3f8.northcentralus.durabletask.io") diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_client.py similarity index 90% rename from externalpackages/durabletaskscheduler/durabletask_scheduler_client.py rename to durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_client.py index f42a8d8..8f89dd1 100644 --- a/externalpackages/durabletaskscheduler/durabletask_scheduler_client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_client.py @@ -1,64 +1,64 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -from typing import Optional -from durabletask.client import TaskHubGrpcClient -from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager -from externalpackages.durabletaskscheduler.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl - -# Client class used for Durable Task Scheduler (DTS) -class DurableTaskSchedulerClient(TaskHubGrpcClient): - def __init__(self, - host_address: str, - secure_channel: bool, - metadata: Optional[list[tuple[str, str]]] = None, - use_managed_identity: Optional[bool] = False, - client_id: Optional[str] = None, - taskhub: str = None, - **kwargs): - - # Ensure metadata is a list - metadata = metadata or [] - self._metadata = metadata.copy() # Use a copy to avoid modifying original - - # Append DurableTask-specific metadata - self._metadata.append(("taskhub", taskhub or "default-taskhub")) - self._metadata.append(("dts", "True")) - self._metadata.append(("use_managed_identity", str(use_managed_identity))) - self._metadata.append(("client_id", str(client_id or "None"))) - - self._access_token_manager = AccessTokenManager(metadata=self._metadata) - self.__update_metadata_with_token() - self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] - - # We pass in None for the metadata so we don't construct an additional interceptor in the parent class - # Since the parent class doesn't use anything metadata for anything else, we can set it as None - super().__init__( - host_address=host_address, - secure_channel=secure_channel, - metadata=None, - interceptors=self._interceptors, - **kwargs - ) - - def __update_metadata_with_token(self): - """ - Add or update the `authorization` key in the metadata with the current access token. 
- """ - token = self._access_token_manager.get_access_token() - - # Ensure that self._metadata is initialized - if self._metadata is None: - self._metadata = [] # Initialize it if it's still None - - # Check if "authorization" already exists in the metadata - updated = False - for i, (key, _) in enumerate(self._metadata): - if key == "authorization": - self._metadata[i] = ("authorization", token) - updated = True - break - - # If not updated, add a new entry - if not updated: +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Optional +from durabletask.client import TaskHubGrpcClient +from durabletask.azuremanaged.access_token_manager import AccessTokenManager +from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl + +# Client class used for Durable Task Scheduler (DTS) +class DurableTaskSchedulerClient(TaskHubGrpcClient): + def __init__(self, + host_address: str, + secure_channel: bool, + metadata: Optional[list[tuple[str, str]]] = None, + use_managed_identity: Optional[bool] = False, + client_id: Optional[str] = None, + taskhub: str = None, + **kwargs): + + # Ensure metadata is a list + metadata = metadata or [] + self._metadata = metadata.copy() # Use a copy to avoid modifying original + + # Append DurableTask-specific metadata + self._metadata.append(("taskhub", taskhub or "default-taskhub")) + self._metadata.append(("dts", "True")) + self._metadata.append(("use_managed_identity", str(use_managed_identity))) + self._metadata.append(("client_id", str(client_id or "None"))) + + self._access_token_manager = AccessTokenManager(metadata=self._metadata) + self.__update_metadata_with_token() + self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] + + # We pass in None for the metadata so we don't construct an additional interceptor in the parent class + # Since the parent class doesn't use anything metadata for anything else, we can set it as None + 
super().__init__( + host_address=host_address, + secure_channel=secure_channel, + metadata=None, + interceptors=self._interceptors, + **kwargs + ) + + def __update_metadata_with_token(self): + """ + Add or update the `authorization` key in the metadata with the current access token. + """ + token = self._access_token_manager.get_access_token() + + # Ensure that self._metadata is initialized + if self._metadata is None: + self._metadata = [] # Initialize it if it's still None + + # Check if "authorization" already exists in the metadata + updated = False + for i, (key, _) in enumerate(self._metadata): + if key == "authorization": + self._metadata[i] = ("authorization", token) + updated = True + break + + # If not updated, add a new entry + if not updated: self._metadata.append(("authorization", token)) \ No newline at end of file diff --git a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_worker.py similarity index 90% rename from externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py rename to durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_worker.py index f6bd184..992ff01 100644 --- a/externalpackages/durabletaskscheduler/durabletask_scheduler_worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_worker.py @@ -1,64 +1,64 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -from typing import Optional -from durabletask.worker import TaskHubGrpcWorker -from externalpackages.durabletaskscheduler.access_token_manager import AccessTokenManager -from externalpackages.durabletaskscheduler.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl - -# Worker class used for Durable Task Scheduler (DTS) -class DurableTaskSchedulerWorker(TaskHubGrpcWorker): - def __init__(self, - host_address: str, - secure_channel: bool, - metadata: Optional[list[tuple[str, str]]] = None, - use_managed_identity: Optional[bool] = False, - client_id: Optional[str] = None, - taskhub: str = None, - **kwargs): - - # Ensure metadata is a list - metadata = metadata or [] - self._metadata = metadata.copy() # Copy to prevent modifying input - - # Append DurableTask-specific metadata - self._metadata.append(("taskhub", taskhub or "default-taskhub")) - self._metadata.append(("dts", "True")) - self._metadata.append(("use_managed_identity", str(use_managed_identity))) - self._metadata.append(("client_id", str(client_id or "None"))) - - self._access_token_manager = AccessTokenManager(metadata=self._metadata) - self.__update_metadata_with_token() - interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] - - # We pass in None for the metadata so we don't construct an additional interceptor in the parent class - # Since the parent class doesn't use anything metadata for anything else, we can set it as None - super().__init__( - host_address=host_address, - secure_channel=secure_channel, - metadata=None, - interceptors=interceptors, - **kwargs - ) - - def __update_metadata_with_token(self): - """ - Add or update the `authorization` key in the metadata with the current access token. 
- """ - token = self._access_token_manager.get_access_token() - - # Ensure that self._metadata is initialized - if self._metadata is None: - self._metadata = [] # Initialize it if it's still None - - # Check if "authorization" already exists in the metadata - updated = False - for i, (key, _) in enumerate(self._metadata): - if key == "authorization": - self._metadata[i] = ("authorization", token) - updated = True - break - - # If not updated, add a new entry - if not updated: - self._metadata.append(("authorization", token)) +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Optional +from durabletask.worker import TaskHubGrpcWorker +from durabletask.azuremanaged.access_token_manager import AccessTokenManager +from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl + +# Worker class used for Durable Task Scheduler (DTS) +class DurableTaskSchedulerWorker(TaskHubGrpcWorker): + def __init__(self, + host_address: str, + secure_channel: bool, + metadata: Optional[list[tuple[str, str]]] = None, + use_managed_identity: Optional[bool] = False, + client_id: Optional[str] = None, + taskhub: str = None, + **kwargs): + + # Ensure metadata is a list + metadata = metadata or [] + self._metadata = metadata.copy() # Copy to prevent modifying input + + # Append DurableTask-specific metadata + self._metadata.append(("taskhub", taskhub or "default-taskhub")) + self._metadata.append(("dts", "True")) + self._metadata.append(("use_managed_identity", str(use_managed_identity))) + self._metadata.append(("client_id", str(client_id or "None"))) + + self._access_token_manager = AccessTokenManager(metadata=self._metadata) + self.__update_metadata_with_token() + interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] + + # We pass in None for the metadata so we don't construct an additional interceptor in the parent class + # Since the parent class doesn't use anything metadata for anything else, 
we can set it as None + super().__init__( + host_address=host_address, + secure_channel=secure_channel, + metadata=None, + interceptors=interceptors, + **kwargs + ) + + def __update_metadata_with_token(self): + """ + Add or update the `authorization` key in the metadata with the current access token. + """ + token = self._access_token_manager.get_access_token() + + # Ensure that self._metadata is initialized + if self._metadata is None: + self._metadata = [] # Initialize it if it's still None + + # Check if "authorization" already exists in the metadata + updated = False + for i, (key, _) in enumerate(self._metadata): + if key == "authorization": + self._metadata[i] = ("authorization", token) + updated = True + break + + # If not updated, add a new entry + if not updated: + self._metadata.append(("authorization", token)) diff --git a/durabletask-azuremanaged/pyproject.toml b/durabletask-azuremanaged/pyproject.toml new file mode 100644 index 0000000..baa78cf --- /dev/null +++ b/durabletask-azuremanaged/pyproject.toml @@ -0,0 +1,41 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +# For more information on pyproject.toml, see https://peps.python.org/pep-0621/ + +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "durabletask.azuremanaged" +version = "0.1b1" +description = "An Azure Managed Backend for Durable Task Python SDK" +keywords = [ + "durable", + "task", + "workflow", + "azure" +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", +] +requires-python = ">=3.9" +license = {file = "LICENSE"} +readme = "README.md" +dependencies = [ + "durabletask", + "azure-identity" +] + +[project.urls] +repository = "https://github.com/microsoft/durabletask-python" +changelog = "https://github.com/microsoft/durabletask-python/blob/main/CHANGELOG.md" + +[tool.setuptools.packages.find] +include = ["durabletask.azuremanaged", "durabletask.azuremanaged.*"] + +[tool.pytest.ini_options] +minversion = "6.0" diff --git a/durabletask/client.py b/durabletask/client.py index 55a30e4..cc4a673 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -99,7 +99,7 @@ def __init__(self, *, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False, - interceptors: Optional[list[DefaultClientInterceptorImpl]] = None): + interceptors: Optional[list[Union[grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor]]] = None): # Determine the interceptors to use if interceptors is not None: diff --git a/examples/dts/README.md b/examples/dts/README.md index feefa05..7467447 100644 --- a/examples/dts/README.md +++ b/examples/dts/README.md @@ -38,9 +38,15 @@ python3 -m pip install azure-identity python3 -m pip install . ``` +7. Install the DTS specific packages from the durabletask-python/durabletask-azuremanaged directory + +```sh +pip3 install -e . 
+``` + ## Running the examples -With one of the sidecars running, you can simply execute any of the examples in this directory using `python3`: +Now, you can simply execute any of the examples in this directory using `python3`: ```sh python3 dts_activity_sequence.py diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index 8d52089..ba59742 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -2,8 +2,8 @@ that calls an activity function in a sequence and prints the outputs.""" import os from durabletask import client, task -from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker -from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient +from durabletask.azuremanaged.durabletask_scheduler_worker import DurableTaskSchedulerWorker +from durabletask.azuremanaged.durabletask_scheduler_client import DurableTaskSchedulerClient def hello(ctx: task.ActivityContext, name: str) -> str: """Activity function that returns a greeting""" diff --git a/examples/dts/dts_fanout_fanin.py b/examples/dts/dts_fanout_fanin.py index 90cda73..665853d 100644 --- a/examples/dts/dts_fanout_fanin.py +++ b/examples/dts/dts_fanout_fanin.py @@ -6,8 +6,8 @@ import os from durabletask import client, task from durabletask import client, task -from externalpackages.durabletaskscheduler.durabletask_scheduler_worker import DurableTaskSchedulerWorker -from externalpackages.durabletaskscheduler.durabletask_scheduler_client import DurableTaskSchedulerClient +import DurableTaskSchedulerWorker +import DurableTaskSchedulerClient def get_work_items(ctx: task.ActivityContext, _) -> list[str]: diff --git a/externalpackages/durabletaskscheduler/__init__.py b/externalpackages/durabletaskscheduler/__init__.py deleted file mode 100644 index e3941ba..0000000 --- a/externalpackages/durabletaskscheduler/__init__.py +++ /dev/null @@ -1,7 +0,0 
@@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -"""Durable Task SDK for Python""" - - -PACKAGE_NAME = "durabletaskscheduler" diff --git a/requirements.txt b/requirements.txt index 49896d3..36c461e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newe protobuf pytest pytest-cov -azure-identity \ No newline at end of file +azure-identity +durabletask-azuremanaged \ No newline at end of file From f9d55ab22b807ddb63e3935e60a8e69c30b32946 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Wed, 12 Feb 2025 14:25:14 -0700 Subject: [PATCH 20/31] Adressing more feedback Signed-off-by: Ryan Lettieri --- ...abletask_scheduler_client.py => client.py} | 20 ++++++------- .../durabletask_grpc_interceptor.py | 2 +- .../durabletask_keep_alive_service.py | 28 ------------------- .../{ => internal}/access_token_manager.py | 8 +++--- ...abletask_scheduler_worker.py => worker.py} | 16 +++++------ durabletask-azuremanaged/pyproject.toml | 2 +- examples/dts/dts_activity_sequence.py | 4 +-- examples/dts/dts_fanout_fanin.py | 5 ++-- requirements.txt | 3 +- 9 files changed, 29 insertions(+), 59 deletions(-) rename durabletask-azuremanaged/durabletask/azuremanaged/{durabletask_scheduler_client.py => client.py} (82%) delete mode 100644 durabletask-azuremanaged/durabletask/azuremanaged/durabletask_keep_alive_service.py rename durabletask-azuremanaged/durabletask/azuremanaged/{ => internal}/access_token_manager.py (90%) rename durabletask-azuremanaged/durabletask/azuremanaged/{durabletask_scheduler_worker.py => worker.py} (86%) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py similarity index 82% rename from durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_client.py rename to durabletask-azuremanaged/durabletask/azuremanaged/client.py 
index 8f89dd1..5c30819 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -3,26 +3,28 @@ from typing import Optional from durabletask.client import TaskHubGrpcClient -from durabletask.azuremanaged.access_token_manager import AccessTokenManager +from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl # Client class used for Durable Task Scheduler (DTS) class DurableTaskSchedulerClient(TaskHubGrpcClient): - def __init__(self, + def __init__(self, *, host_address: str, - secure_channel: bool, + taskhub: str, + secure_channel: Optional[bool] = True, metadata: Optional[list[tuple[str, str]]] = None, use_managed_identity: Optional[bool] = False, - client_id: Optional[str] = None, - taskhub: str = None, - **kwargs): + client_id: Optional[str] = None): + if taskhub == None: + raise ValueError("Taskhub value cannot be empty. 
Please provide a value for your taskhub") + # Ensure metadata is a list metadata = metadata or [] self._metadata = metadata.copy() # Use a copy to avoid modifying original # Append DurableTask-specific metadata - self._metadata.append(("taskhub", taskhub or "default-taskhub")) + self._metadata.append(("taskhub", taskhub)) self._metadata.append(("dts", "True")) self._metadata.append(("use_managed_identity", str(use_managed_identity))) self._metadata.append(("client_id", str(client_id or "None"))) @@ -37,9 +39,7 @@ def __init__(self, host_address=host_address, secure_channel=secure_channel, metadata=None, - interceptors=self._interceptors, - **kwargs - ) + interceptors=self._interceptors) def __update_metadata_with_token(self): """ diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py index 97897d9..eb809ee 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. 
from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl -from durabletask.azuremanaged.access_token_manager import AccessTokenManager +from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager import grpc diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_keep_alive_service.py b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_keep_alive_service.py deleted file mode 100644 index 0720bc9..0000000 --- a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_keep_alive_service.py +++ /dev/null @@ -1,28 +0,0 @@ -import threading -import time -import requests # You could use grpc or another library depending on your setup - -class KeepAliveService: - def __init__(self, interval: int = 60, endpoint: str = "https://sdktest1-fgcac9hja3f8.northcentralus.durabletask.io"): - self.interval = interval # Time interval in seconds - self.endpoint = endpoint # The endpoint for sending no-op requests - self._keep_alive_thread = threading.Thread(target=self._send_noop_periodically) - self._keep_alive_thread.daemon = True # Makes sure it ends when the main program ends - self._keep_alive_thread.start() - - def _send_noop_periodically(self): - while True: - try: - # Send a simple GET or POST request to a "ping" or no-op endpoint - response = requests.get(self.endpoint) # Replace with the appropriate method - if response.status_code == 200: - print("No-op request sent successfully.") - else: - print(f"No-op failed with status code {response.status_code}") - except Exception as e: - print(f"Error sending no-op: {e}") - - time.sleep(self.interval) # Wait before sending another no-op - -# Example Usage -keep_alive_service = KeepAliveService(interval=60, endpoint="https://sdktest1-fgcac9hja3f8.northcentralus.durabletask.io") diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/access_token_manager.py 
b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py similarity index 90% rename from durabletask-azuremanaged/durabletask/azuremanaged/access_token_manager.py rename to durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py index 43e86dc..d93fec6 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/access_token_manager.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py @@ -7,9 +7,9 @@ # By default, when there's 10minutes left before the token expires, refresh the token class AccessTokenManager: - def __init__(self, refresh_buffer: int = 600, metadata: Optional[list[tuple[str, str]]] = None): + def __init__(self, refresh_interval_seconds: int = 600, metadata: Optional[list[tuple[str, str]]] = None): self.scope = "https://durabletask.io/.default" - self.refresh_buffer = refresh_buffer + self.refresh_interval_seconds = refresh_interval_seconds self._use_managed_identity = False self._metadata = metadata self._client_id = None @@ -42,13 +42,13 @@ def get_access_token(self) -> str: self.refresh_token() return self.token - # Checks if the token is expired, or if it will expire in the next "refresh_buffer" seconds. + # Checks if the token is expired, or if it will expire in the next "refresh_interval_seconds" seconds. 
# For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes, # We will grab a new token when there're 30minutes left on the lifespan of the token def is_token_expired(self) -> bool: if self.expiry_time is None: return True - return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_buffer)) + return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_interval_seconds)) def refresh_token(self): new_token = self.credential.get_token(self.scope) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py similarity index 86% rename from durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_worker.py rename to durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 992ff01..2345f7c 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_scheduler_worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -3,20 +3,22 @@ from typing import Optional from durabletask.worker import TaskHubGrpcWorker -from durabletask.azuremanaged.access_token_manager import AccessTokenManager +from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl # Worker class used for Durable Task Scheduler (DTS) class DurableTaskSchedulerWorker(TaskHubGrpcWorker): - def __init__(self, + def __init__(self, *, host_address: str, + taskhub: str, secure_channel: bool, metadata: Optional[list[tuple[str, str]]] = None, use_managed_identity: Optional[bool] = False, - client_id: Optional[str] = None, - taskhub: str = None, - **kwargs): + client_id: Optional[str] = None): + if taskhub == None: + raise ValueError("Taskhub value cannot be empty. 
Please provide a value for your taskhub") + # Ensure metadata is a list metadata = metadata or [] self._metadata = metadata.copy() # Copy to prevent modifying input @@ -37,9 +39,7 @@ def __init__(self, host_address=host_address, secure_channel=secure_channel, metadata=None, - interceptors=interceptors, - **kwargs - ) + interceptors=interceptors) def __update_metadata_with_token(self): """ diff --git a/durabletask-azuremanaged/pyproject.toml b/durabletask-azuremanaged/pyproject.toml index baa78cf..db7f561 100644 --- a/durabletask-azuremanaged/pyproject.toml +++ b/durabletask-azuremanaged/pyproject.toml @@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta" [project] name = "durabletask.azuremanaged" version = "0.1b1" -description = "An Azure Managed Backend for Durable Task Python SDK" +description = "Extensions for the Durable Task Python SDK for integrating with the Durable Task Scheduler in Azure" keywords = [ "durable", "task", diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index ba59742..eea25a1 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -2,8 +2,8 @@ that calls an activity function in a sequence and prints the outputs.""" import os from durabletask import client, task -from durabletask.azuremanaged.durabletask_scheduler_worker import DurableTaskSchedulerWorker -from durabletask.azuremanaged.durabletask_scheduler_client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from durabletask.azuremanaged.client import DurableTaskSchedulerClient def hello(ctx: task.ActivityContext, name: str) -> str: """Activity function that returns a greeting""" diff --git a/examples/dts/dts_fanout_fanin.py b/examples/dts/dts_fanout_fanin.py index 665853d..05e3ef9 100644 --- a/examples/dts/dts_fanout_fanin.py +++ b/examples/dts/dts_fanout_fanin.py @@ -6,9 +6,8 @@ import os from durabletask import client, task from durabletask 
import client, task -import DurableTaskSchedulerWorker -import DurableTaskSchedulerClient - +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from durabletask.azuremanaged.client import DurableTaskSchedulerClient def get_work_items(ctx: task.ActivityContext, _) -> list[str]: """Activity function that returns a list of work items""" diff --git a/requirements.txt b/requirements.txt index 36c461e..49896d3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,5 +3,4 @@ grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newe protobuf pytest pytest-cov -azure-identity -durabletask-azuremanaged \ No newline at end of file +azure-identity \ No newline at end of file From ba1ac4f281e256d467871889ae9f82222644a86c Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Wed, 12 Feb 2025 15:40:29 -0700 Subject: [PATCH 21/31] More review feedback Signed-off-by: Ryan Lettieri --- .../durabletask/azuremanaged/client.py | 110 ++++++++---------- .../durabletask_grpc_interceptor.py | 71 ++++++----- .../internal/access_token_manager.py | 108 ++++++++--------- .../durabletask/azuremanaged/worker.py | 108 +++++++---------- durabletask-azuremanaged/pyproject.toml | 82 ++++++------- examples/dts/README.md | 2 + examples/dts/dts_activity_sequence.py | 6 +- 7 files changed, 228 insertions(+), 259 deletions(-) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py index 5c30819..7fad517 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -1,64 +1,46 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -from typing import Optional -from durabletask.client import TaskHubGrpcClient -from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager -from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl - -# Client class used for Durable Task Scheduler (DTS) -class DurableTaskSchedulerClient(TaskHubGrpcClient): - def __init__(self, *, - host_address: str, - taskhub: str, - secure_channel: Optional[bool] = True, - metadata: Optional[list[tuple[str, str]]] = None, - use_managed_identity: Optional[bool] = False, - client_id: Optional[str] = None): - - if taskhub == None: - raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub") - - # Ensure metadata is a list - metadata = metadata or [] - self._metadata = metadata.copy() # Use a copy to avoid modifying original - - # Append DurableTask-specific metadata - self._metadata.append(("taskhub", taskhub)) - self._metadata.append(("dts", "True")) - self._metadata.append(("use_managed_identity", str(use_managed_identity))) - self._metadata.append(("client_id", str(client_id or "None"))) - - self._access_token_manager = AccessTokenManager(metadata=self._metadata) - self.__update_metadata_with_token() - self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] - - # We pass in None for the metadata so we don't construct an additional interceptor in the parent class - # Since the parent class doesn't use anything metadata for anything else, we can set it as None - super().__init__( - host_address=host_address, - secure_channel=secure_channel, - metadata=None, - interceptors=self._interceptors) - - def __update_metadata_with_token(self): - """ - Add or update the `authorization` key in the metadata with the current access token. 
- """ - token = self._access_token_manager.get_access_token() - - # Ensure that self._metadata is initialized - if self._metadata is None: - self._metadata = [] # Initialize it if it's still None - - # Check if "authorization" already exists in the metadata - updated = False - for i, (key, _) in enumerate(self._metadata): - if key == "authorization": - self._metadata[i] = ("authorization", token) - updated = True - break - - # If not updated, add a new entry - if not updated: - self._metadata.append(("authorization", token)) \ No newline at end of file +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Optional +from durabletask.client import TaskHubGrpcClient, OrchestrationStatus +from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager +from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl +from azure.identity import DefaultAzureCredential + +# Client class used for Durable Task Scheduler (DTS) +class DurableTaskSchedulerClient(TaskHubGrpcClient): + def __init__(self, *, + host_address: str, + taskhub: str, + secure_channel: Optional[bool] = True, + metadata: Optional[list[tuple[str, str]]] = None, + use_managed_identity: Optional[bool] = False, + client_id: Optional[str] = None): + + if taskhub == None: + raise ValueError("Taskhub value cannot be empty. 
Please provide a value for your taskhub") + + # Ensure metadata is a list + metadata = metadata or [] + self._metadata = metadata.copy() # Use a copy to avoid modifying original + + # Append DurableTask-specific metadata + self._metadata.append(("taskhub", taskhub)) + self._metadata.append(("dts", "True")) + self._metadata.append(("use_managed_identity", str(use_managed_identity))) + self._metadata.append(("client_id", str(client_id or "None"))) + + self._access_token_manager = AccessTokenManager(use_managed_identity=use_managed_identity, + client_id=client_id) + token = self._access_token_manager.get_access_token() + self._metadata.append(("authorization", token)) + + self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] + + # We pass in None for the metadata so we don't construct an additional interceptor in the parent class + # Since the parent class doesn't use anything metadata for anything else, we can set it as None + super().__init__( + host_address=host_address, + secure_channel=secure_channel, + metadata=None, + interceptors=self._interceptors) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py index eb809ee..6ae00d9 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py @@ -1,29 +1,42 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl -from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager - -import grpc - -class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): - """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an - interceptor to add additional headers to all calls as needed.""" - - def __init__(self, metadata: list[tuple[str, str]]): - super().__init__(metadata) - self._token_manager = AccessTokenManager(metadata=self._metadata) - - def _intercept_call( - self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: - """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC - call details.""" - # Refresh the auth token if it is present and needed - if self._metadata is not None: - for i, (key, _) in enumerate(self._metadata): - if key.lower() == "authorization": # Ensure case-insensitive comparison - new_token = self._token_manager.get_access_token() # Get the new token - self._metadata[i] = ("authorization", new_token) # Update the token - - return super()._intercept_call(client_call_details) +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl +from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager + +import grpc + +class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): + """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an + interceptor to add additional headers to all calls as needed.""" + + def __init__(self, metadata: list[tuple[str, str]]): + super().__init__(metadata) + + use_managed_identity = False + client_id = None + + # Check what authentication we are using + if metadata: + for key, value in metadata: + if key.lower() == "use_managed_identity": + self.use_managed_identity = value.strip().lower() == "true" # Convert to boolean + elif key.lower() == "client_id": + self.client_id = value + + self._token_manager = AccessTokenManager(use_managed_identity=use_managed_identity, + client_id=client_id) + + def _intercept_call( + self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: + """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC + call details.""" + # Refresh the auth token if it is present and needed + if self._metadata is not None: + for i, (key, _) in enumerate(self._metadata): + if key.lower() == "authorization": # Ensure case-insensitive comparison + new_token = self._token_manager.get_access_token() # Get the new token + self._metadata[i] = ("authorization", new_token) # Update the token + + return super()._intercept_call(client_call_details) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py index d93fec6..c52a955 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py +++ 
b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py @@ -1,59 +1,51 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from azure.identity import DefaultAzureCredential, ManagedIdentityCredential -from datetime import datetime, timedelta, timezone -from typing import Optional -import durabletask.internal.shared as shared - -# By default, when there's 10minutes left before the token expires, refresh the token -class AccessTokenManager: - def __init__(self, refresh_interval_seconds: int = 600, metadata: Optional[list[tuple[str, str]]] = None): - self.scope = "https://durabletask.io/.default" - self.refresh_interval_seconds = refresh_interval_seconds - self._use_managed_identity = False - self._metadata = metadata - self._client_id = None - self._logger = shared.get_logger("token_manager") - - if metadata: # Ensure metadata is not None - for key, value in metadata: - if key == "use_managed_identity": - self._use_managed_identity = value.lower() == "true" # Properly convert string to bool - elif key == "client_id": - self._client_id = value # Directly assign string - - # Choose the appropriate credential based on use_managed_identity - if self._use_managed_identity: - if not self._client_id: - self._logger.debug("Using System Assigned Managed Identity for authentication.") - self.credential = ManagedIdentityCredential() - else: - self._logger.debug("Using User Assigned Managed Identity for authentication.") - self.credential = ManagedIdentityCredential(client_id=self._client_id) - else: - self.credential = DefaultAzureCredential() - self._logger.debug("Using Default Azure Credentials for authentication.") - - self.token = None - self.expiry_time = None - - def get_access_token(self) -> str: - if self.token is None or self.is_token_expired(): - self.refresh_token() - return self.token - - # Checks if the token is expired, or if it will expire in the next "refresh_interval_seconds" seconds. 
- # For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes, - # We will grab a new token when there're 30minutes left on the lifespan of the token - def is_token_expired(self) -> bool: - if self.expiry_time is None: - return True - return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_interval_seconds)) - - def refresh_token(self): - new_token = self.credential.get_token(self.scope) - self.token = f"Bearer {new_token.token}" - - # Convert UNIX timestamp to timezone-aware datetime - self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc) +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from azure.identity import DefaultAzureCredential, ManagedIdentityCredential +from datetime import datetime, timedelta, timezone +from typing import Optional +import durabletask.internal.shared as shared + +# By default, when there's 10minutes left before the token expires, refresh the token +class AccessTokenManager: + def __init__(self, refresh_interval_seconds: int = 600, use_managed_identity: bool = False, client_id: str = None): + self.scope = "https://durabletask.io/.default" + self.refresh_interval_seconds = refresh_interval_seconds + self._use_managed_identity = use_managed_identity + self._client_id = client_id + self._logger = shared.get_logger("token_manager") + + # Choose the appropriate credential based on use_managed_identity + if self._use_managed_identity: + if not self._client_id: + self._logger.debug("Using System Assigned Managed Identity for authentication.") + self.credential = ManagedIdentityCredential() + else: + self._logger.debug("Using User Assigned Managed Identity for authentication.") + self.credential = ManagedIdentityCredential(client_id=self._client_id) + else: + self.credential = DefaultAzureCredential() + self._logger.debug("Using Default Azure Credentials for authentication.") + + self.token = None + 
self.expiry_time = None + + def get_access_token(self) -> str: + if self.token is None or self.is_token_expired(): + self.refresh_token() + return self.token + + # Checks if the token is expired, or if it will expire in the next "refresh_interval_seconds" seconds. + # For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes, + # We will grab a new token when there're 30minutes left on the lifespan of the token + def is_token_expired(self) -> bool: + if self.expiry_time is None: + return True + return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_interval_seconds)) + + def refresh_token(self): + new_token = self.credential.get_token(self.scope) + self.token = f"Bearer {new_token.token}" + + # Convert UNIX timestamp to timezone-aware datetime + self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc) self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}") \ No newline at end of file diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 2345f7c..603918e 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -1,64 +1,44 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -from typing import Optional -from durabletask.worker import TaskHubGrpcWorker -from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager -from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl - -# Worker class used for Durable Task Scheduler (DTS) -class DurableTaskSchedulerWorker(TaskHubGrpcWorker): - def __init__(self, *, - host_address: str, - taskhub: str, - secure_channel: bool, - metadata: Optional[list[tuple[str, str]]] = None, - use_managed_identity: Optional[bool] = False, - client_id: Optional[str] = None): - - if taskhub == None: - raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub") - - # Ensure metadata is a list - metadata = metadata or [] - self._metadata = metadata.copy() # Copy to prevent modifying input - - # Append DurableTask-specific metadata - self._metadata.append(("taskhub", taskhub or "default-taskhub")) - self._metadata.append(("dts", "True")) - self._metadata.append(("use_managed_identity", str(use_managed_identity))) - self._metadata.append(("client_id", str(client_id or "None"))) - - self._access_token_manager = AccessTokenManager(metadata=self._metadata) - self.__update_metadata_with_token() - interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] - - # We pass in None for the metadata so we don't construct an additional interceptor in the parent class - # Since the parent class doesn't use anything metadata for anything else, we can set it as None - super().__init__( - host_address=host_address, - secure_channel=secure_channel, - metadata=None, - interceptors=interceptors) - - def __update_metadata_with_token(self): - """ - Add or update the `authorization` key in the metadata with the current access token. 
- """ - token = self._access_token_manager.get_access_token() - - # Ensure that self._metadata is initialized - if self._metadata is None: - self._metadata = [] # Initialize it if it's still None - - # Check if "authorization" already exists in the metadata - updated = False - for i, (key, _) in enumerate(self._metadata): - if key == "authorization": - self._metadata[i] = ("authorization", token) - updated = True - break - - # If not updated, add a new entry - if not updated: - self._metadata.append(("authorization", token)) +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typing import Optional +from durabletask.worker import TaskHubGrpcWorker +from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager +from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl + +# Worker class used for Durable Task Scheduler (DTS) +class DurableTaskSchedulerWorker(TaskHubGrpcWorker): + def __init__(self, *, + host_address: str, + taskhub: str, + secure_channel: bool, + metadata: Optional[list[tuple[str, str]]] = None, + use_managed_identity: Optional[bool] = False, + client_id: Optional[str] = None): + + if taskhub == None: + raise ValueError("Taskhub value cannot be empty. 
Please provide a value for your taskhub") + + # Ensure metadata is a list + metadata = metadata or [] + self._metadata = metadata.copy() # Copy to prevent modifying input + + # Append DurableTask-specific metadata + self._metadata.append(("taskhub", taskhub or "default-taskhub")) + self._metadata.append(("dts", "True")) + self._metadata.append(("use_managed_identity", str(use_managed_identity))) + self._metadata.append(("client_id", str(client_id or "None"))) + + self._access_token_manager = AccessTokenManager(use_managed_identity=use_managed_identity, + client_id=client_id) + token = self._access_token_manager.get_access_token() + self._metadata.append(("authorization", token)) + interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] + + # We pass in None for the metadata so we don't construct an additional interceptor in the parent class + # Since the parent class doesn't use anything metadata for anything else, we can set it as None + super().__init__( + host_address=host_address, + secure_channel=secure_channel, + metadata=None, + interceptors=interceptors) \ No newline at end of file diff --git a/durabletask-azuremanaged/pyproject.toml b/durabletask-azuremanaged/pyproject.toml index db7f561..a23e4a8 100644 --- a/durabletask-azuremanaged/pyproject.toml +++ b/durabletask-azuremanaged/pyproject.toml @@ -1,41 +1,41 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -# For more information on pyproject.toml, see https://peps.python.org/pep-0621/ - -[build-system] -requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta" - -[project] -name = "durabletask.azuremanaged" -version = "0.1b1" -description = "Extensions for the Durable Task Python SDK for integrating with the Durable Task Scheduler in Azure" -keywords = [ - "durable", - "task", - "workflow", - "azure" -] -classifiers = [ - "Development Status :: 3 - Alpha", - "Programming Language :: Python :: 3", - "License :: OSI Approved :: MIT License", -] -requires-python = ">=3.9" -license = {file = "LICENSE"} -readme = "README.md" -dependencies = [ - "durabletask", - "azure-identity" -] - -[project.urls] -repository = "https://github.com/microsoft/durabletask-python" -changelog = "https://github.com/microsoft/durabletask-python/blob/main/CHANGELOG.md" - -[tool.setuptools.packages.find] -include = ["durabletask.azuremanaged", "durabletask.azuremanaged.*"] - -[tool.pytest.ini_options] -minversion = "6.0" +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +# For more information on pyproject.toml, see https://peps.python.org/pep-0621/ + +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "durabletask.azuremanaged" +version = "0.1b1" +description = "Extensions for the Durable Task Python SDK for integrating with the Durable Task Scheduler in Azure" +keywords = [ + "durable", + "task", + "workflow", + "azure" +] +classifiers = [ + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", +] +requires-python = ">=3.9" +license = {file = "LICENSE"} +readme = "README.md" +dependencies = [ + "durabletask", + "azure-identity" +] + +[project.urls] +repository = "https://github.com/microsoft/durabletask-python" +changelog = "https://github.com/microsoft/durabletask-python/blob/main/CHANGELOG.md" + +[tool.setuptools.packages.find] +include = ["durabletask.azuremanaged", "durabletask.azuremanaged.*"] + +[tool.pytest.ini_options] +minversion = "6.0" diff --git a/examples/dts/README.md b/examples/dts/README.md index 7467447..5c34b03 100644 --- a/examples/dts/README.md +++ b/examples/dts/README.md @@ -44,6 +44,8 @@ python3 -m pip install . pip3 install -e . ``` +8. 
Grant yourself the DurableTaskDataContributor role over your scheduler + ## Running the examples Now, you can simply execute any of the examples in this directory using `python3`: diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index eea25a1..768be6f 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -1,9 +1,9 @@ """End-to-end sample that demonstrates how to configure an orchestrator that calls an activity function in a sequence and prints the outputs.""" import os -from durabletask import client, task +from durabletask import task from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker -from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.client import DurableTaskSchedulerClient, OrchestrationStatus def hello(ctx: task.ActivityContext, name: str) -> str: """Activity function that returns a greeting""" @@ -56,7 +56,7 @@ def sequence(ctx: task.OrchestrationContext, _): c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) instance_id = c.schedule_new_orchestration(sequence) state = c.wait_for_orchestration_completion(instance_id, timeout=60) - if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: + if state and state.runtime_status == OrchestrationStatus.COMPLETED: print(f'Orchestration completed! 
Result: {state.serialized_output}') elif state: print(f'Orchestration failed: {state.failure_details}') From 2c251ea7c19732af8e802057c4d15f6230fd2831 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Thu, 13 Feb 2025 08:09:49 -0700 Subject: [PATCH 22/31] Passing token credential as an argument rather than 2 strings Signed-off-by: Ryan Lettieri --- .../durabletask/azuremanaged/client.py | 14 ++------ .../durabletask_grpc_interceptor.py | 16 ++++----- .../internal/access_token_manager.py | 36 +++++++++---------- .../durabletask/azuremanaged/worker.py | 16 +++------ examples/dts/dts_activity_sequence.py | 2 +- 5 files changed, 32 insertions(+), 52 deletions(-) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py index 7fad517..e4e71d5 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -5,7 +5,7 @@ from durabletask.client import TaskHubGrpcClient, OrchestrationStatus from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl -from azure.identity import DefaultAzureCredential +from azure.core.credentials import TokenCredential # Client class used for Durable Task Scheduler (DTS) class DurableTaskSchedulerClient(TaskHubGrpcClient): @@ -14,8 +14,7 @@ def __init__(self, *, taskhub: str, secure_channel: Optional[bool] = True, metadata: Optional[list[tuple[str, str]]] = None, - use_managed_identity: Optional[bool] = False, - client_id: Optional[str] = None): + token_credential: Optional[TokenCredential] = None): if taskhub == None: raise ValueError("Taskhub value cannot be empty. 
Please provide a value for your taskhub") @@ -27,14 +26,7 @@ def __init__(self, *, # Append DurableTask-specific metadata self._metadata.append(("taskhub", taskhub)) self._metadata.append(("dts", "True")) - self._metadata.append(("use_managed_identity", str(use_managed_identity))) - self._metadata.append(("client_id", str(client_id or "None"))) - - self._access_token_manager = AccessTokenManager(use_managed_identity=use_managed_identity, - client_id=client_id) - token = self._access_token_manager.get_access_token() - self._metadata.append(("authorization", token)) - + self._metadata.append(("token_credential", token_credential)) self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] # We pass in None for the metadata so we don't construct an additional interceptor in the parent class diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py index 6ae00d9..e078f90 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py @@ -13,20 +13,18 @@ class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): def __init__(self, metadata: list[tuple[str, str]]): super().__init__(metadata) - - use_managed_identity = False - client_id = None + + self._token_credential = None # Check what authentication we are using if metadata: for key, value in metadata: - if key.lower() == "use_managed_identity": - self.use_managed_identity = value.strip().lower() == "true" # Convert to boolean - elif key.lower() == "client_id": - self.client_id = value + if key.lower() == "token_credential": + self._token_credential = value - self._token_manager = AccessTokenManager(use_managed_identity=use_managed_identity, - client_id=client_id) + self._token_manager = AccessTokenManager(token_credential=self._token_credential) + token = 
self._token_manager.get_access_token() + self._metadata.append(("authorization", token)) def _intercept_call( self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py index c52a955..d095032 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py @@ -4,35 +4,31 @@ from datetime import datetime, timedelta, timezone from typing import Optional import durabletask.internal.shared as shared +from azure.core.credentials import TokenCredential # By default, when there's 10minutes left before the token expires, refresh the token class AccessTokenManager: - def __init__(self, refresh_interval_seconds: int = 600, use_managed_identity: bool = False, client_id: str = None): - self.scope = "https://durabletask.io/.default" - self.refresh_interval_seconds = refresh_interval_seconds - self._use_managed_identity = use_managed_identity - self._client_id = client_id + def __init__(self, refresh_interval_seconds: int = 600, token_credential: TokenCredential = None): + self._scope = "https://durabletask.io/.default" + self._refresh_interval_seconds = refresh_interval_seconds self._logger = shared.get_logger("token_manager") - # Choose the appropriate credential based on use_managed_identity - if self._use_managed_identity: - if not self._client_id: - self._logger.debug("Using System Assigned Managed Identity for authentication.") - self.credential = ManagedIdentityCredential() - else: - self._logger.debug("Using User Assigned Managed Identity for authentication.") - self.credential = ManagedIdentityCredential(client_id=self._client_id) + # Choose the appropriate credential. 
+ # Both TokenCredential and DefaultAzureCredential get_token methods return an AccessToken + if token_credential: + self._logger.debug("Using user provided token credentials.") + self._credential = token_credential else: - self.credential = DefaultAzureCredential() + self._credential = DefaultAzureCredential() self._logger.debug("Using Default Azure Credentials for authentication.") - self.token = None + self._token = self._credential.get_token(self._scope) self.expiry_time = None def get_access_token(self) -> str: - if self.token is None or self.is_token_expired(): + if self._token is None or self.is_token_expired(): self.refresh_token() - return self.token + return self._token # Checks if the token is expired, or if it will expire in the next "refresh_interval_seconds" seconds. # For example, if the token is created to have a lifespan of 2 hours, and the refresh buffer is set to 30 minutes, @@ -40,11 +36,11 @@ def get_access_token(self) -> str: def is_token_expired(self) -> bool: if self.expiry_time is None: return True - return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self.refresh_interval_seconds)) + return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self._refresh_interval_seconds)) def refresh_token(self): - new_token = self.credential.get_token(self.scope) - self.token = f"Bearer {new_token.token}" + new_token = self._credential.get_token(self._scope) + self._token = f"Bearer {new_token.token}" # Convert UNIX timestamp to timezone-aware datetime self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 603918e..9bd667b 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -5,16 +5,16 @@ from durabletask.worker import TaskHubGrpcWorker from 
durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl +from azure.core.credentials import TokenCredential # Worker class used for Durable Task Scheduler (DTS) class DurableTaskSchedulerWorker(TaskHubGrpcWorker): def __init__(self, *, host_address: str, taskhub: str, - secure_channel: bool, + secure_channel: Optional[bool] = True, metadata: Optional[list[tuple[str, str]]] = None, - use_managed_identity: Optional[bool] = False, - client_id: Optional[str] = None): + token_credential: Optional[TokenCredential] = None): if taskhub == None: raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub") @@ -24,15 +24,9 @@ def __init__(self, *, self._metadata = metadata.copy() # Copy to prevent modifying input # Append DurableTask-specific metadata - self._metadata.append(("taskhub", taskhub or "default-taskhub")) + self._metadata.append(("taskhub", taskhub)) self._metadata.append(("dts", "True")) - self._metadata.append(("use_managed_identity", str(use_managed_identity))) - self._metadata.append(("client_id", str(client_id or "None"))) - - self._access_token_manager = AccessTokenManager(use_managed_identity=use_managed_identity, - client_id=client_id) - token = self._access_token_manager.get_access_token() - self._metadata.append(("authorization", token)) + self._metadata.append(("token_credential", token_credential)) interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] # We pass in None for the metadata so we don't construct an additional interceptor in the parent class diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index 768be6f..5e402fe 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -47,7 +47,7 @@ def sequence(ctx: task.OrchestrationContext, _): # configure and start the worker -with 
DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, use_managed_identity=False, client_id="", taskhub=taskhub_name) as w: +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) as w: w.add_orchestrator(sequence) w.add_activity(hello) w.start() From 9c65176129407ce9b552c44bd2eb4d6a139e204e Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Thu, 13 Feb 2025 14:59:34 -0700 Subject: [PATCH 23/31] More review feedback for token passing Signed-off-by: Ryan Lettieri --- .../durabletask/azuremanaged/client.py | 15 +++---------- .../durabletask_grpc_interceptor.py | 21 +++++++------------ .../internal/access_token_manager.py | 9 +------- .../durabletask/azuremanaged/worker.py | 15 +++---------- examples/dts/dts_activity_sequence.py | 8 +++++-- examples/dts/dts_fanout_fanin.py | 9 ++++++-- 6 files changed, 28 insertions(+), 49 deletions(-) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py index e4e71d5..16bbb9a 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -12,22 +12,13 @@ class DurableTaskSchedulerClient(TaskHubGrpcClient): def __init__(self, *, host_address: str, taskhub: str, - secure_channel: Optional[bool] = True, - metadata: Optional[list[tuple[str, str]]] = None, - token_credential: Optional[TokenCredential] = None): + token_credential: TokenCredential = None, + secure_channel: Optional[bool] = True): if taskhub == None: raise ValueError("Taskhub value cannot be empty. 
Please provide a value for your taskhub") - # Ensure metadata is a list - metadata = metadata or [] - self._metadata = metadata.copy() # Use a copy to avoid modifying original - - # Append DurableTask-specific metadata - self._metadata.append(("taskhub", taskhub)) - self._metadata.append(("dts", "True")) - self._metadata.append(("token_credential", token_credential)) - self._interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] + self._interceptors = [DTSDefaultClientInterceptorImpl(token_credential, taskhub)] # We pass in None for the metadata so we don't construct an additional interceptor in the parent class # Since the parent class doesn't use anything metadata for anything else, we can set it as None diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py index e078f90..280c878 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py @@ -3,7 +3,7 @@ from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager - +from azure.core.credentials import TokenCredential import grpc class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): @@ -11,20 +11,15 @@ class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an interceptor to add additional headers to all calls as needed.""" - def __init__(self, metadata: list[tuple[str, str]]): + def __init__(self, token_credential: TokenCredential, taskhub_name: str): + metadata = [("taskhub", taskhub_name)] super().__init__(metadata) - self._token_credential = None - - # Check what authentication we are using - if metadata: - for key, value in 
metadata: - if key.lower() == "token_credential": - self._token_credential = value - - self._token_manager = AccessTokenManager(token_credential=self._token_credential) - token = self._token_manager.get_access_token() - self._metadata.append(("authorization", token)) + if token_credential is not None: + self._token_credential = token_credential + self._token_manager = AccessTokenManager(token_credential=self._token_credential) + token = self._token_manager.get_access_token() + self._metadata.append(("authorization", token)) def _intercept_call( self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py index d095032..d12b29c 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py @@ -13,14 +13,7 @@ def __init__(self, refresh_interval_seconds: int = 600, token_credential: TokenC self._refresh_interval_seconds = refresh_interval_seconds self._logger = shared.get_logger("token_manager") - # Choose the appropriate credential. 
- # Both TokenCredential and DefaultAzureCredential get_token methods return an AccessToken - if token_credential: - self._logger.debug("Using user provided token credentials.") - self._credential = token_credential - else: - self._credential = DefaultAzureCredential() - self._logger.debug("Using Default Azure Credentials for authentication.") + self._credential = token_credential self._token = self._credential.get_token(self._scope) self.expiry_time = None diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 9bd667b..66b8fd2 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -12,22 +12,13 @@ class DurableTaskSchedulerWorker(TaskHubGrpcWorker): def __init__(self, *, host_address: str, taskhub: str, - secure_channel: Optional[bool] = True, - metadata: Optional[list[tuple[str, str]]] = None, - token_credential: Optional[TokenCredential] = None): + token_credential: TokenCredential = None, + secure_channel: Optional[bool] = True): if taskhub == None: raise ValueError("Taskhub value cannot be empty. 
Please provide a value for your taskhub") - # Ensure metadata is a list - metadata = metadata or [] - self._metadata = metadata.copy() # Copy to prevent modifying input - - # Append DurableTask-specific metadata - self._metadata.append(("taskhub", taskhub)) - self._metadata.append(("dts", "True")) - self._metadata.append(("token_credential", token_credential)) - interceptors = [DTSDefaultClientInterceptorImpl(self._metadata)] + interceptors = [DTSDefaultClientInterceptorImpl(token_credential, taskhub)] # We pass in None for the metadata so we don't construct an additional interceptor in the parent class # Since the parent class doesn't use anything metadata for anything else, we can set it as None diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index 5e402fe..7285ac0 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -4,6 +4,7 @@ from durabletask import task from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker from durabletask.azuremanaged.client import DurableTaskSchedulerClient, OrchestrationStatus +from azure.identity import DefaultAzureCredential def hello(ctx: task.ActivityContext, name: str) -> str: """Activity function that returns a greeting""" @@ -45,15 +46,18 @@ def sequence(ctx: task.OrchestrationContext, _): print("If you are using bash, run the following: export ENDPOINT=\"\"") exit() +credential = DefaultAzureCredential() # configure and start the worker -with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) as w: +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=credential) as w: w.add_orchestrator(sequence) w.add_activity(hello) w.start() # Construct the client and run the orchestrations - c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) + c = 
DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=credential) instance_id = c.schedule_new_orchestration(sequence) state = c.wait_for_orchestration_completion(instance_id, timeout=60) if state and state.runtime_status == OrchestrationStatus.COMPLETED: diff --git a/examples/dts/dts_fanout_fanin.py b/examples/dts/dts_fanout_fanin.py index 05e3ef9..eace14a 100644 --- a/examples/dts/dts_fanout_fanin.py +++ b/examples/dts/dts_fanout_fanin.py @@ -8,6 +8,7 @@ from durabletask import client, task from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from azure.identity import DefaultAzureCredential def get_work_items(ctx: task.ActivityContext, _) -> list[str]: """Activity function that returns a list of work items""" @@ -71,15 +72,19 @@ def orchestrator(ctx: task.OrchestrationContext, _): print("If you are using bash, run the following: export ENDPOINT=\"\"") exit() +credential = DefaultAzureCredential() + # configure and start the worker -with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) as w: +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=credential) as w: w.add_orchestrator(orchestrator) w.add_activity(process_work_item) w.add_activity(get_work_items) w.start() # create a client, start an orchestration, and wait for it to finish - c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, taskhub=taskhub_name) + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=credential) instance_id = c.schedule_new_orchestration(orchestrator) state = c.wait_for_orchestration_completion(instance_id, timeout=30) if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: From 877dabb7eb6bfbd62205f5a8abb45b1fefdd1e47 Mon Sep 
17 00:00:00 2001 From: Ryan Lettieri Date: Thu, 13 Feb 2025 15:15:14 -0700 Subject: [PATCH 24/31] Addressing None comment and using correct metadata Signed-off-by: Ryan Lettieri --- durabletask-azuremanaged/durabletask/azuremanaged/client.py | 2 +- .../durabletask/azuremanaged/durabletask_grpc_interceptor.py | 4 ++-- durabletask-azuremanaged/durabletask/azuremanaged/worker.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py index 16bbb9a..eec0095 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -12,7 +12,7 @@ class DurableTaskSchedulerClient(TaskHubGrpcClient): def __init__(self, *, host_address: str, taskhub: str, - token_credential: TokenCredential = None, + token_credential: TokenCredential, secure_channel: Optional[bool] = True): if taskhub == None: diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py index 280c878..5a63f4d 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py @@ -12,8 +12,8 @@ class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): interceptor to add additional headers to all calls as needed.""" def __init__(self, token_credential: TokenCredential, taskhub_name: str): - metadata = [("taskhub", taskhub_name)] - super().__init__(metadata) + self._metadata = [("taskhub", taskhub_name)] + super().__init__(self._metadata) if token_credential is not None: self._token_credential = token_credential diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index 66b8fd2..ff971b2 
100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -12,7 +12,7 @@ class DurableTaskSchedulerWorker(TaskHubGrpcWorker): def __init__(self, *, host_address: str, taskhub: str, - token_credential: TokenCredential = None, + token_credential: TokenCredential, secure_channel: Optional[bool] = True): if taskhub == None: From b39ffad43a8ca25715445fd9d70d9da0a54fad5b Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Thu, 13 Feb 2025 15:50:04 -0700 Subject: [PATCH 25/31] Updating unit tests Signed-off-by: Ryan Lettieri --- durabletask/internal/shared.py | 1 - tests/test_client.py | 30 +++++++++++++++--------------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 6327796..91a561d 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -24,7 +24,6 @@ def get_default_host_address() -> str: def get_grpc_channel( host_address: Optional[str], - metadata: Optional[list[tuple[str, str]]] = None, secure_channel: bool = False, interceptors: Optional[list[DefaultClientInterceptorImpl]] = None) -> grpc.Channel: diff --git a/tests/test_client.py b/tests/test_client.py index b3056d2..d5653a4 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,31 +6,31 @@ HOST_ADDRESS = 'localhost:50051' METADATA = [('key1', 'value1'), ('key2', 'value2')] - +INTERCEPTORS = DefaultClientInterceptorImpl(METADATA) def test_get_grpc_channel_insecure(): with patch('grpc.insecure_channel') as mock_channel: - get_grpc_channel(HOST_ADDRESS, METADATA, False) + get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(HOST_ADDRESS) def test_get_grpc_channel_secure(): with patch('grpc.secure_channel') as mock_channel, patch( 'grpc.ssl_channel_credentials') as mock_credentials: - get_grpc_channel(HOST_ADDRESS, METADATA, True) + get_grpc_channel(HOST_ADDRESS, 
True, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(HOST_ADDRESS, mock_credentials.return_value) def test_get_grpc_channel_default_host_address(): with patch('grpc.insecure_channel') as mock_channel: - get_grpc_channel(None, METADATA, False) + get_grpc_channel(None, False, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(get_default_host_address()) def test_get_grpc_channel_with_metadata(): with patch('grpc.insecure_channel') as mock_channel, patch( 'grpc.intercept_channel') as mock_intercept_channel: - get_grpc_channel(HOST_ADDRESS, METADATA, False) + get_grpc_channel(HOST_ADDRESS, False, interceptors=INTERCEPTORS) mock_channel.assert_called_once_with(HOST_ADDRESS) mock_intercept_channel.assert_called_once() @@ -48,41 +48,41 @@ def test_grpc_channel_with_host_name_protocol_stripping(): host_name = "myserver.com:1234" prefix = "grpc://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_insecure_channel.assert_called_with(host_name) prefix = "http://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_insecure_channel.assert_called_with(host_name) prefix = "HTTP://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_insecure_channel.assert_called_with(host_name) prefix = "GRPC://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_insecure_channel.assert_called_with(host_name) prefix = "" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_insecure_channel.assert_called_with(host_name) prefix = "grpcs://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_secure_channel.assert_called_with(host_name, ANY) prefix = "https://" - 
get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_secure_channel.assert_called_with(host_name, ANY) prefix = "HTTPS://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_secure_channel.assert_called_with(host_name, ANY) prefix = "GRPCS://" - get_grpc_channel(prefix + host_name, METADATA) + get_grpc_channel(prefix + host_name, interceptors=INTERCEPTORS) mock_secure_channel.assert_called_with(host_name, ANY) prefix = "" - get_grpc_channel(prefix + host_name, METADATA, True) + get_grpc_channel(prefix + host_name, True, interceptors=INTERCEPTORS) mock_secure_channel.assert_called_with(host_name, ANY) \ No newline at end of file From 33c8b1151780a9f99e44a1c5f25f5ddd7a8ed410 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Thu, 13 Feb 2025 16:06:47 -0700 Subject: [PATCH 26/31] Fixing the type for the unit test Signed-off-by: Ryan Lettieri --- tests/test_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_client.py b/tests/test_client.py index d5653a4..64bbec8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -6,7 +6,7 @@ HOST_ADDRESS = 'localhost:50051' METADATA = [('key1', 'value1'), ('key2', 'value2')] -INTERCEPTORS = DefaultClientInterceptorImpl(METADATA) +INTERCEPTORS = [DefaultClientInterceptorImpl(METADATA)] def test_get_grpc_channel_insecure(): with patch('grpc.insecure_channel') as mock_channel: From 1da819e027e018bb8292e1d021d6fd96b3780391 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Thu, 13 Feb 2025 16:15:07 -0700 Subject: [PATCH 27/31] Fixing grpc calls Signed-off-by: Ryan Lettieri --- durabletask/client.py | 1 - durabletask/worker.py | 2 +- examples/dts/dts_activity_sequence.py | 132 +++++++++++++------------- 3 files changed, 67 insertions(+), 68 deletions(-) diff --git a/durabletask/client.py b/durabletask/client.py index cc4a673..7c1a5b6 100644 --- 
a/durabletask/client.py +++ b/durabletask/client.py @@ -113,7 +113,6 @@ def __init__(self, *, channel = shared.get_grpc_channel( host_address=host_address, - metadata=metadata, secure_channel=secure_channel, interceptors=self._interceptors ) diff --git a/durabletask/worker.py b/durabletask/worker.py index e67e5ca..7e25a59 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -131,7 +131,7 @@ def add_activity(self, fn: task.Activity) -> str: def start(self): """Starts the worker on a background thread and begins listening for work items.""" - channel = shared.get_grpc_channel(self._host_address, self._metadata, self._secure_channel, self._interceptors) + channel = shared.get_grpc_channel(self._host_address, self._secure_channel, self._interceptors) stub = stubs.TaskHubSidecarServiceStub(channel) if self._is_running: diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index 7285ac0..1ee5a93 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -1,66 +1,66 @@ -"""End-to-end sample that demonstrates how to configure an orchestrator -that calls an activity function in a sequence and prints the outputs.""" -import os -from durabletask import task -from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker -from durabletask.azuremanaged.client import DurableTaskSchedulerClient, OrchestrationStatus -from azure.identity import DefaultAzureCredential - -def hello(ctx: task.ActivityContext, name: str) -> str: - """Activity function that returns a greeting""" - return f'Hello {name}!' 
- - -def sequence(ctx: task.OrchestrationContext, _): - """Orchestrator function that calls the 'hello' activity function in a sequence""" - # call "hello" activity function in a sequence - result1 = yield ctx.call_activity(hello, input='Tokyo') - result2 = yield ctx.call_activity(hello, input='Seattle') - result3 = yield ctx.call_activity(hello, input='London') - - # return an array of results - return [result1, result2, result3] - - -# Read the environment variable -taskhub_name = os.getenv("TASKHUB") - -# Check if the variable exists -if taskhub_name: - print(f"The value of TASKHUB is: {taskhub_name}") -else: - print("TASKHUB is not set. Please set the TASKHUB environment variable to the name of the taskhub you wish to use") - print("If you are using windows powershell, run the following: $env:TASKHUB=\"\"") - print("If you are using bash, run the following: export TASKHUB=\"\"") - exit() - -# Read the environment variable -endpoint = os.getenv("ENDPOINT") - -# Check if the variable exists -if endpoint: - print(f"The value of ENDPOINT is: {endpoint}") -else: - print("ENDPOINT is not set. 
Please set the ENDPOINT environment variable to the endpoint of the scheduler") - print("If you are using windows powershell, run the following: $env:ENDPOINT=\"\"") - print("If you are using bash, run the following: export ENDPOINT=\"\"") - exit() - -credential = DefaultAzureCredential() - -# configure and start the worker -with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, - taskhub=taskhub_name, token_credential=credential) as w: - w.add_orchestrator(sequence) - w.add_activity(hello) - w.start() - - # Construct the client and run the orchestrations - c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, - taskhub=taskhub_name, token_credential=credential) - instance_id = c.schedule_new_orchestration(sequence) - state = c.wait_for_orchestration_completion(instance_id, timeout=60) - if state and state.runtime_status == OrchestrationStatus.COMPLETED: - print(f'Orchestration completed! Result: {state.serialized_output}') - elif state: - print(f'Orchestration failed: {state.failure_details}') +"""End-to-end sample that demonstrates how to configure an orchestrator +that calls an activity function in a sequence and prints the outputs.""" +import os +from durabletask import task +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker +from durabletask.azuremanaged.client import DurableTaskSchedulerClient, OrchestrationStatus +from azure.identity import DefaultAzureCredential + +def hello(ctx: task.ActivityContext, name: str) -> str: + """Activity function that returns a greeting""" + return f'Hello {name}!' 
+ + +def sequence(ctx: task.OrchestrationContext, _): + """Orchestrator function that calls the 'hello' activity function in a sequence""" + # call "hello" activity function in a sequence + result1 = yield ctx.call_activity(hello, input='Tokyo') + result2 = yield ctx.call_activity(hello, input='Seattle') + result3 = yield ctx.call_activity(hello, input='London') + + # return an array of results + return [result1, result2, result3] + + +# Read the environment variable +taskhub_name = os.getenv("TASKHUB") + +# Check if the variable exists +if taskhub_name: + print(f"The value of TASKHUB is: {taskhub_name}") +else: + print("TASKHUB is not set. Please set the TASKHUB environment variable to the name of the taskhub you wish to use") + print("If you are using windows powershell, run the following: $env:TASKHUB=\"\"") + print("If you are using bash, run the following: export TASKHUB=\"\"") + exit() + +# Read the environment variable +endpoint = os.getenv("ENDPOINT") + +# Check if the variable exists +if endpoint: + print(f"The value of ENDPOINT is: {endpoint}") +else: + print("ENDPOINT is not set. 
Please set the ENDPOINT environment variable to the endpoint of the scheduler") + print("If you are using windows powershell, run the following: $env:ENDPOINT=\"\"") + print("If you are using bash, run the following: export ENDPOINT=\"\"") + exit() + +credential = DefaultAzureCredential() + +# configure and start the worker +with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=credential) as w: + w.add_orchestrator(sequence) + w.add_activity(hello) + w.start() + + # Construct the client and run the orchestrations + c = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True, + taskhub=taskhub_name, token_credential=credential) + instance_id = c.schedule_new_orchestration(sequence) + state = c.wait_for_orchestration_completion(instance_id, timeout=60) + if state and state.runtime_status == OrchestrationStatus.COMPLETED: + print(f'Orchestration completed! Result: {state.serialized_output}') + elif state: + print(f'Orchestration failed: {state.failure_details}') From 61422208eb634f67ee729d6f4e2645d74e31539b Mon Sep 17 00:00:00 2001 From: Chris Gillum Date: Thu, 13 Feb 2025 21:59:49 -0800 Subject: [PATCH 28/31] Fix linter errors and update documentation --- CHANGELOG.md | 1 + README.md | 7 +- .../durabletask/azuremanaged/client.py | 17 +++-- .../internal/access_token_manager.py | 35 +++++---- .../durabletask_grpc_interceptor.py | 76 ++++++++++--------- .../durabletask/azuremanaged/worker.py | 21 ++--- durabletask/client.py | 26 +++---- durabletask/internal/grpc_interceptor.py | 12 +-- durabletask/internal/shared.py | 17 +++-- durabletask/task.py | 7 +- durabletask/worker.py | 17 ++--- examples/dts/README.md | 46 +++++------ examples/dts/dts_activity_sequence.py | 11 ++- examples/dts/dts_fanout_fanin.py | 12 +-- requirements.txt | 1 + 15 files changed, 168 insertions(+), 138 deletions(-) rename durabletask-azuremanaged/durabletask/azuremanaged/{ => internal}/durabletask_grpc_interceptor.py 
(70%) diff --git a/CHANGELOG.md b/CHANGELOG.md index ee736f0..13b0e69 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added `set_custom_status` orchestrator API ([#31](https://github.com/microsoft/durabletask-python/pull/31)) - contributed by [@famarting](https://github.com/famarting) - Added `purge_orchestration` client API ([#34](https://github.com/microsoft/durabletask-python/pull/34)) - contributed by [@famarting](https://github.com/famarting) +- Added new `durabletask-azuremanaged` package for use with the [Durable Task Scheduler](https://techcommunity.microsoft.com/blog/appsonazureblog/announcing-limited-early-access-of-the-durable-task-scheduler-for-azure-durable-/4286526) - by [@RyanLettieri](https://github.com/RyanLettieri) ### Changes diff --git a/README.md b/README.md index 644635e..87af41d 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,14 @@ -# Durable Task Client SDK for Python +# Durable Task SDK for Python [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![Build Validation](https://github.com/microsoft/durabletask-python/actions/workflows/pr-validation.yml/badge.svg)](https://github.com/microsoft/durabletask-python/actions/workflows/pr-validation.yml) [![PyPI version](https://badge.fury.io/py/durabletask.svg)](https://badge.fury.io/py/durabletask) -This repo contains a Python client SDK for use with the [Durable Task Framework for Go](https://github.com/microsoft/durabletask-go) and [Dapr Workflow](https://docs.dapr.io/developing-applications/building-blocks/workflow/workflow-overview/). With this SDK, you can define, schedule, and manage durable orchestrations using ordinary Python code. 
+This repo contains a Python SDK for use with the [Azure Durable Task Scheduler](https://techcommunity.microsoft.com/blog/appsonazureblog/announcing-limited-early-access-of-the-durable-task-scheduler-for-azure-durable-/4286526) and the [Durable Task Framework for Go](https://github.com/microsoft/durabletask-go). With this SDK, you can define, schedule, and manage durable orchestrations using ordinary Python code. ⚠️ **This SDK is currently under active development and is not yet ready for production use.** ⚠️ -> Note that this project is **not** currently affiliated with the [Durable Functions](https://docs.microsoft.com/azure/azure-functions/durable/durable-functions-overview) project for Azure Functions. If you are looking for a Python SDK for Durable Functions, please see [this repo](https://github.com/Azure/azure-functions-durable-python). - +> Note that this SDK is **not** currently compatible with [Azure Durable Functions](https://docs.microsoft.com/azure/azure-functions/durable/durable-functions-overview). If you are looking for a Python SDK for Azure Durable Functions, please see [this repo](https://github.com/Azure/azure-functions-durable-python). ## Supported patterns diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/client.py b/durabletask-azuremanaged/durabletask/azuremanaged/client.py index eec0095..f641eae 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/client.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/client.py @@ -1,24 +1,25 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from typing import Optional -from durabletask.client import TaskHubGrpcClient, OrchestrationStatus -from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager -from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl from azure.core.credentials import TokenCredential +from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \ + DTSDefaultClientInterceptorImpl +from durabletask.client import TaskHubGrpcClient + + # Client class used for Durable Task Scheduler (DTS) class DurableTaskSchedulerClient(TaskHubGrpcClient): def __init__(self, *, host_address: str, taskhub: str, token_credential: TokenCredential, - secure_channel: Optional[bool] = True): + secure_channel: bool = True): - if taskhub == None: + if not taskhub: raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub") - self._interceptors = [DTSDefaultClientInterceptorImpl(token_credential, taskhub)] + interceptors = [DTSDefaultClientInterceptorImpl(token_credential, taskhub)] # We pass in None for the metadata so we don't construct an additional interceptor in the parent class # Since the parent class doesn't use anything metadata for anything else, we can set it as None @@ -26,4 +27,4 @@ def __init__(self, *, host_address=host_address, secure_channel=secure_channel, metadata=None, - interceptors=self._interceptors) + interceptors=interceptors) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py index d12b29c..f0e7a42 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/internal/access_token_manager.py @@ -1,24 +1,33 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-from azure.identity import DefaultAzureCredential, ManagedIdentityCredential from datetime import datetime, timedelta, timezone from typing import Optional + +from azure.core.credentials import AccessToken, TokenCredential + import durabletask.internal.shared as shared -from azure.core.credentials import TokenCredential + # By default, when there's 10minutes left before the token expires, refresh the token class AccessTokenManager: - def __init__(self, refresh_interval_seconds: int = 600, token_credential: TokenCredential = None): + + _token: Optional[AccessToken] + + def __init__(self, token_credential: Optional[TokenCredential], refresh_interval_seconds: int = 600): self._scope = "https://durabletask.io/.default" self._refresh_interval_seconds = refresh_interval_seconds self._logger = shared.get_logger("token_manager") self._credential = token_credential - - self._token = self._credential.get_token(self._scope) - self.expiry_time = None - def get_access_token(self) -> str: + if self._credential is not None: + self._token = self._credential.get_token(self._scope) + self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc) + else: + self._token = None + self.expiry_time = None + + def get_access_token(self) -> Optional[AccessToken]: if self._token is None or self.is_token_expired(): self.refresh_token() return self._token @@ -32,9 +41,9 @@ def is_token_expired(self) -> bool: return datetime.now(timezone.utc) >= (self.expiry_time - timedelta(seconds=self._refresh_interval_seconds)) def refresh_token(self): - new_token = self._credential.get_token(self._scope) - self._token = f"Bearer {new_token.token}" - - # Convert UNIX timestamp to timezone-aware datetime - self.expiry_time = datetime.fromtimestamp(new_token.expires_on, tz=timezone.utc) - self._logger.debug(f"Token refreshed. 
Expires at: {self.expiry_time}") \ No newline at end of file + if self._credential is not None: + self._token = self._credential.get_token(self._scope) + + # Convert UNIX timestamp to timezone-aware datetime + self.expiry_time = datetime.fromtimestamp(self._token.expires_on, tz=timezone.utc) + self._logger.debug(f"Token refreshed. Expires at: {self.expiry_time}") diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py b/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py similarity index 70% rename from durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py rename to durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py index 5a63f4d..a23cac9 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/durabletask_grpc_interceptor.py +++ b/durabletask-azuremanaged/durabletask/azuremanaged/internal/durabletask_grpc_interceptor.py @@ -1,35 +1,41 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -from durabletask.internal.grpc_interceptor import _ClientCallDetails, DefaultClientInterceptorImpl -from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager -from azure.core.credentials import TokenCredential -import grpc - -class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): - """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an - interceptor to add additional headers to all calls as needed.""" - - def __init__(self, token_credential: TokenCredential, taskhub_name: str): - self._metadata = [("taskhub", taskhub_name)] - super().__init__(self._metadata) - - if token_credential is not None: - self._token_credential = token_credential - self._token_manager = AccessTokenManager(token_credential=self._token_credential) - token = self._token_manager.get_access_token() - self._metadata.append(("authorization", token)) - - def _intercept_call( - self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: - """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC - call details.""" - # Refresh the auth token if it is present and needed - if self._metadata is not None: - for i, (key, _) in enumerate(self._metadata): - if key.lower() == "authorization": # Ensure case-insensitive comparison - new_token = self._token_manager.get_access_token() # Get the new token - self._metadata[i] = ("authorization", new_token) # Update the token - - return super()._intercept_call(client_call_details) +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import grpc +from azure.core.credentials import TokenCredential + +from durabletask.azuremanaged.internal.access_token_manager import \ + AccessTokenManager +from durabletask.internal.grpc_interceptor import ( + DefaultClientInterceptorImpl, _ClientCallDetails) + + +class DTSDefaultClientInterceptorImpl (DefaultClientInterceptorImpl): + """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, + StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an + interceptor to add additional headers to all calls as needed.""" + + def __init__(self, token_credential: TokenCredential, taskhub_name: str): + self._metadata = [("taskhub", taskhub_name)] + super().__init__(self._metadata) + + if token_credential is not None: + self._token_credential = token_credential + self._token_manager = AccessTokenManager(token_credential=self._token_credential) + access_token = self._token_manager.get_access_token() + if access_token is not None: + self._metadata.append(("authorization", f"Bearer {access_token.token}")) + + def _intercept_call( + self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: + """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC + call details.""" + # Refresh the auth token if it is present and needed + if self._metadata is not None: + for i, (key, _) in enumerate(self._metadata): + if key.lower() == "authorization": # Ensure case-insensitive comparison + new_token = self._token_manager.get_access_token() # Get the new token + if new_token is not None: + self._metadata[i] = ("authorization", f"Bearer {new_token.token}") # Update the token + + return super()._intercept_call(client_call_details) diff --git a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py index ff971b2..d10c2f7 100644 --- a/durabletask-azuremanaged/durabletask/azuremanaged/worker.py +++ 
b/durabletask-azuremanaged/durabletask/azuremanaged/worker.py @@ -1,22 +1,23 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Optional -from durabletask.worker import TaskHubGrpcWorker -from durabletask.azuremanaged.internal.access_token_manager import AccessTokenManager -from durabletask.azuremanaged.durabletask_grpc_interceptor import DTSDefaultClientInterceptorImpl from azure.core.credentials import TokenCredential +from durabletask.azuremanaged.internal.durabletask_grpc_interceptor import \ + DTSDefaultClientInterceptorImpl +from durabletask.worker import TaskHubGrpcWorker + + # Worker class used for Durable Task Scheduler (DTS) class DurableTaskSchedulerWorker(TaskHubGrpcWorker): def __init__(self, *, host_address: str, taskhub: str, token_credential: TokenCredential, - secure_channel: Optional[bool] = True): - - if taskhub == None: - raise ValueError("Taskhub value cannot be empty. Please provide a value for your taskhub") + secure_channel: bool = True): + + if not taskhub: + raise ValueError("The taskhub value cannot be empty.") interceptors = [DTSDefaultClientInterceptorImpl(token_credential, taskhub)] @@ -25,5 +26,5 @@ def __init__(self, *, super().__init__( host_address=host_address, secure_channel=secure_channel, - metadata=None, - interceptors=interceptors) \ No newline at end of file + metadata=None, + interceptors=interceptors) diff --git a/durabletask/client.py b/durabletask/client.py index 7c1a5b6..60e194f 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Optional, TypeVar, Union +from typing import Any, Optional, Sequence, TypeVar, Union import grpc from google.protobuf import wrappers_pb2 @@ -15,9 +15,8 @@ import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared 
as shared -from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl - from durabletask import task +from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl TInput = TypeVar('TInput') TOutput = TypeVar('TOutput') @@ -99,22 +98,23 @@ def __init__(self, *, log_handler: Optional[logging.Handler] = None, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False, - interceptors: Optional[list[Union[grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor]]] = None): + interceptors: Optional[Sequence[shared.ClientInterceptor]] = None): - # Determine the interceptors to use + # If the caller provided metadata, we need to create a new interceptor for it and + # add it to the list of interceptors. if interceptors is not None: - self._interceptors = interceptors - if metadata: - self._interceptors.append(DefaultClientInterceptorImpl(metadata)) - elif metadata: - self._interceptors = [DefaultClientInterceptorImpl(metadata)] + interceptors = list(interceptors) + if metadata is not None: + interceptors.append(DefaultClientInterceptorImpl(metadata)) + elif metadata is not None: + interceptors = [DefaultClientInterceptorImpl(metadata)] else: - self._interceptors = None + interceptors = None channel = shared.get_grpc_channel( host_address=host_address, secure_channel=secure_channel, - interceptors=self._interceptors + interceptors=interceptors ) self._stub = stubs.TaskHubSidecarServiceStub(channel) self._logger = shared.get_logger("client", log_handler, log_formatter) @@ -134,7 +134,7 @@ def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInpu scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, version=wrappers_pb2.StringValue(value=""), orchestrationIdReusePolicy=reuse_id_policy, - ) + ) self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.") res: pb.CreateInstanceResponse = self._stub.StartInstance(req) diff --git 
a/durabletask/internal/grpc_interceptor.py b/durabletask/internal/grpc_interceptor.py index 738fca9..69db3c5 100644 --- a/durabletask/internal/grpc_interceptor.py +++ b/durabletask/internal/grpc_interceptor.py @@ -19,10 +19,10 @@ class _ClientCallDetails( class DefaultClientInterceptorImpl ( - grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, - grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): + grpc.UnaryUnaryClientInterceptor, grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, grpc.StreamStreamClientInterceptor): """The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor, - StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an + StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an interceptor to add additional headers to all calls as needed.""" def __init__(self, metadata: list[tuple[str, str]]): @@ -30,17 +30,17 @@ def __init__(self, metadata: list[tuple[str, str]]): self._metadata = metadata def _intercept_call( - self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: + self, client_call_details: _ClientCallDetails) -> grpc.ClientCallDetails: """Internal intercept_call implementation which adds metadata to grpc metadata in the RPC call details.""" if self._metadata is None: return client_call_details - + if client_call_details.metadata is not None: metadata = list(client_call_details.metadata) else: metadata = [] - + metadata.extend(self._metadata) client_call_details = _ClientCallDetails( client_call_details.method, client_call_details.timeout, metadata, diff --git a/durabletask/internal/shared.py b/durabletask/internal/shared.py index 91a561d..1872ad4 100644 --- a/durabletask/internal/shared.py +++ b/durabletask/internal/shared.py @@ -5,11 +5,17 @@ import json import logging from types import SimpleNamespace -from typing import Any, Optional -from durabletask.internal.grpc_interceptor import 
DefaultClientInterceptorImpl +from typing import Any, Optional, Sequence, Union import grpc +ClientInterceptor = Union[ + grpc.UnaryUnaryClientInterceptor, + grpc.UnaryStreamClientInterceptor, + grpc.StreamUnaryClientInterceptor, + grpc.StreamStreamClientInterceptor +] + # Field name used to indicate that an object was automatically serialized # and should be deserialized as a SimpleNamespace AUTO_SERIALIZED = "__durabletask_autoobject__" @@ -25,8 +31,8 @@ def get_default_host_address() -> str: def get_grpc_channel( host_address: Optional[str], secure_channel: bool = False, - interceptors: Optional[list[DefaultClientInterceptorImpl]] = None) -> grpc.Channel: - + interceptors: Optional[Sequence[ClientInterceptor]] = None) -> grpc.Channel: + if host_address is None: host_address = get_default_host_address() @@ -55,6 +61,7 @@ def get_grpc_channel( channel = grpc.intercept_channel(channel, *interceptors) return channel + def get_logger( name_suffix: str, log_handler: Optional[logging.Handler] = None, @@ -99,7 +106,7 @@ def default(self, obj): if dataclasses.is_dataclass(obj): # Dataclasses are not serializable by default, so we convert them to a dict and mark them for # automatic deserialization by the receiver - d = dataclasses.asdict(obj) # type: ignore + d = dataclasses.asdict(obj) # type: ignore d[AUTO_SERIALIZED] = True return d elif isinstance(obj, SimpleNamespace): diff --git a/durabletask/task.py b/durabletask/task.py index a40602b..9e8a08a 100644 --- a/durabletask/task.py +++ b/durabletask/task.py @@ -277,6 +277,7 @@ def get_tasks(self) -> list[Task]: def on_child_completed(self, task: Task[T]): pass + class WhenAllTask(CompositeTask[list[T]]): """A task that completes when all of its child tasks complete.""" @@ -333,7 +334,7 @@ class RetryableTask(CompletableTask[T]): """A task that can be retried according to a retry policy.""" def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction, - start_time:datetime, is_sub_orch: bool) -> None: + 
start_time: datetime, is_sub_orch: bool) -> None: super().__init__() self._action = action self._retry_policy = retry_policy @@ -343,7 +344,7 @@ def __init__(self, retry_policy: RetryPolicy, action: pb.OrchestratorAction, def increment_attempt_count(self) -> None: self._attempt_count += 1 - + def compute_next_delay(self) -> Optional[timedelta]: if self._attempt_count >= self._retry_policy.max_number_of_attempts: return None @@ -351,7 +352,7 @@ def compute_next_delay(self) -> Optional[timedelta]: retry_expiration: datetime = datetime.max if self._retry_policy.retry_timeout is not None and self._retry_policy.retry_timeout != datetime.max: retry_expiration = self._start_time + self._retry_policy.retry_timeout - + if self._retry_policy.backoff_coefficient is None: backoff_coefficient = 1.0 else: diff --git a/durabletask/worker.py b/durabletask/worker.py index 7e25a59..2c31e52 100644 --- a/durabletask/worker.py +++ b/durabletask/worker.py @@ -9,14 +9,13 @@ from typing import Any, Generator, Optional, Sequence, TypeVar, Union import grpc -from google.protobuf import empty_pb2, wrappers_pb2 +from google.protobuf import empty_pb2 import durabletask.internal.helpers as ph import durabletask.internal.helpers as pbh import durabletask.internal.orchestrator_service_pb2 as pb import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared - from durabletask import task from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl @@ -84,6 +83,7 @@ class ActivityNotRegisteredError(ValueError): class TaskHubGrpcWorker: _response_stream: Optional[grpc.Future] = None + _interceptors: Optional[list[shared.ClientInterceptor]] = None def __init__(self, *, host_address: Optional[str] = None, @@ -91,10 +91,9 @@ def __init__(self, *, log_handler=None, log_formatter: Optional[logging.Formatter] = None, secure_channel: bool = False, - interceptors: Optional[list[DefaultClientInterceptorImpl]] = None): # Add interceptors + 
interceptors: Optional[Sequence[shared.ClientInterceptor]] = None): self._registry = _Registry() self._host_address = host_address if host_address else shared.get_default_host_address() - self._metadata = metadata self._logger = shared.get_logger("worker", log_handler, log_formatter) self._shutdown = Event() self._is_running = False @@ -102,15 +101,14 @@ def __init__(self, *, # Determine the interceptors to use if interceptors is not None: - self._interceptors = interceptors + self._interceptors = list(interceptors) if metadata: self._interceptors.append(DefaultClientInterceptorImpl(metadata)) - elif self._metadata: - self._interceptors = [DefaultClientInterceptorImpl(self._metadata)] + elif metadata: + self._interceptors = [DefaultClientInterceptorImpl(metadata)] else: self._interceptors = None - def __enter__(self): return self @@ -161,7 +159,7 @@ def run_loop(): elif work_item.HasField('activityRequest'): executor.submit(self._execute_activity, work_item.activityRequest, stub, work_item.completionToken) elif work_item.HasField('healthPing'): - pass # no-op + pass # no-op else: self._logger.warning(f'Unexpected work item type: {request_type}') @@ -490,6 +488,7 @@ def __init__(self, actions: list[pb.OrchestratorAction], encoded_custom_status: self.actions = actions self.encoded_custom_status = encoded_custom_status + class _OrchestrationExecutor: _generator: Optional[task.Orchestrator] = None diff --git a/examples/dts/README.md b/examples/dts/README.md index 5c34b03..b75feff 100644 --- a/examples/dts/README.md +++ b/examples/dts/README.md @@ -11,40 +11,40 @@ The simplest way to create a taskhub is by using the az cli commands: 1. Create a scheduler: az durabletask scheduler create --resource-group --name --location --ip-allowlist "[0.0.0.0/0]" --sku-capacity 1 --sku-name "Dedicated" --tags "{}" -2. Create your taskhub - az durabletask taskhub create --resource-group --scheduler-name --name +1. Create your taskhub -3. Retrieve the endpoint for the scheduler. 
This can be done by locating the taskhub in the portal. + ```bash + az durabletask taskhub create --resource-group --scheduler-name --name + ``` -4. Set the appropriate environment variables for the TASKHUB and ENDPOINT +1. Retrieve the endpoint for the scheduler. This can be done by locating the taskhub in the portal. -```sh -export TASKHUB= -``` +1. Set the appropriate environment variables for the TASKHUB and ENDPOINT -```sh -export ENDPOINT= -``` + ```bash + export TASKHUB= + export ENDPOINT= + ``` -5. Since the samples rely on azure identity, ensure the package is installed and up-to-date +1. Since the samples rely on azure identity, ensure the package is installed and up-to-date -```sh -python3 -m pip install azure-identity -``` + ```bash + python3 -m pip install azure-identity + ``` -6. Install the correct pacakges from the top level of this repository, i.e. durabletask-python/ +1. Install the correct packages from the top level of this repository, i.e. durabletask-python/ -```sh -python3 -m pip install . -``` + ```bash + python3 -m pip install . + ``` -7. Install the DTS specific packages from the durabletask-python/durabletask-azuremanaged directory +1. Install the DTS specific packages from the durabletask-python/durabletask-azuremanaged directory -```sh -pip3 install -e . -``` + ```bash + pip3 install -e . + ``` -8. Grant yourself the DurableTaskDataContributor role over your scheduler +1. 
Grant yourself the `Durable Task Data Contributor` role over your scheduler ## Running the examples diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index 1ee5a93..dd19e83 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -1,11 +1,14 @@ """End-to-end sample that demonstrates how to configure an orchestrator that calls an activity function in a sequence and prints the outputs.""" import os -from durabletask import task -from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker -from durabletask.azuremanaged.client import DurableTaskSchedulerClient, OrchestrationStatus + from azure.identity import DefaultAzureCredential +from durabletask import client, task +from durabletask.azuremanaged.client import DurableTaskSchedulerClient +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + + def hello(ctx: task.ActivityContext, name: str) -> str: """Activity function that returns a greeting""" return f'Hello {name}!' @@ -60,7 +63,7 @@ def sequence(ctx: task.OrchestrationContext, _): taskhub=taskhub_name, token_credential=credential) instance_id = c.schedule_new_orchestration(sequence) state = c.wait_for_orchestration_completion(instance_id, timeout=60) - if state and state.runtime_status == OrchestrationStatus.COMPLETED: + if state and state.runtime_status == client.OrchestrationStatus.COMPLETED: print(f'Orchestration completed! 
Result: {state.serialized_output}') elif state: print(f'Orchestration failed: {state.failure_details}') diff --git a/examples/dts/dts_fanout_fanin.py b/examples/dts/dts_fanout_fanin.py index eace14a..8ab68df 100644 --- a/examples/dts/dts_fanout_fanin.py +++ b/examples/dts/dts_fanout_fanin.py @@ -1,14 +1,16 @@ """End-to-end sample that demonstrates how to configure an orchestrator that a dynamic number activity functions in parallel, waits for them all to complete, and prints an aggregate summary of the outputs.""" +import os import random import time -import os -from durabletask import client, task + +from azure.identity import DefaultAzureCredential + from durabletask import client, task -from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker from durabletask.azuremanaged.client import DurableTaskSchedulerClient -from azure.identity import DefaultAzureCredential +from durabletask.azuremanaged.worker import DurableTaskSchedulerWorker + def get_work_items(ctx: task.ActivityContext, _) -> list[str]: """Activity function that returns a list of work items""" @@ -91,4 +93,4 @@ def orchestrator(ctx: task.OrchestrationContext, _): print(f'Orchestration completed! 
Result: {state.serialized_output}') elif state: print(f'Orchestration failed: {state.failure_details}') - exit() \ No newline at end of file + exit() diff --git a/requirements.txt b/requirements.txt index 49896d3..0da7d46 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,4 +3,5 @@ grpcio>=1.60.0 # 1.60.0 is the version introducing protobuf 1.25.X support, newe protobuf pytest pytest-cov +azure-core azure-identity \ No newline at end of file From 58f4f93607eb30954e211d44262fe3d1cdc9b820 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Mon, 17 Feb 2025 19:31:09 -0700 Subject: [PATCH 29/31] Specifying version requirement for pyproject.toml Signed-off-by: Ryan Lettieri --- durabletask-azuremanaged/pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/durabletask-azuremanaged/pyproject.toml b/durabletask-azuremanaged/pyproject.toml index a23e4a8..ac6be6f 100644 --- a/durabletask-azuremanaged/pyproject.toml +++ b/durabletask-azuremanaged/pyproject.toml @@ -26,8 +26,8 @@ requires-python = ">=3.9" license = {file = "LICENSE"} readme = "README.md" dependencies = [ - "durabletask", - "azure-identity" + "durabletask>=0.2.0", + "azure-identity>=1.19.0" ] [project.urls] From d82c1b78c87d4988ee5734f7b04f63e3eb854b0c Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Mon, 17 Feb 2025 19:34:53 -0700 Subject: [PATCH 30/31] Updating README Signed-off-by: Ryan Lettieri --- examples/dts/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dts/README.md b/examples/dts/README.md index b75feff..9b4a3fd 100644 --- a/examples/dts/README.md +++ b/examples/dts/README.md @@ -1,6 +1,6 @@ # Examples -This directory contains examples of how to author durable orchestrations using the Durable Task Python SDK in conjunction with the Durable Task Scheduler (DTS). +This directory contains examples of how to author durable orchestrations using the Durable Task Python SDK in conjunction with the Durable Task Scheduler (DTS). 
Please note that the installation instructions provided below will use the version of DTS directly from your branch rather than installing through PyPI. ## Prerequisites From b3a099e81dbeaa37d29795354d4789eb40a4a496 Mon Sep 17 00:00:00 2001 From: Ryan Lettieri Date: Mon, 17 Feb 2025 19:37:20 -0700 Subject: [PATCH 31/31] Adding comment for credential type Signed-off-by: Ryan Lettieri --- examples/dts/dts_activity_sequence.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/dts/dts_activity_sequence.py b/examples/dts/dts_activity_sequence.py index dd19e83..2ff3c22 100644 --- a/examples/dts/dts_activity_sequence.py +++ b/examples/dts/dts_activity_sequence.py @@ -49,6 +49,8 @@ def sequence(ctx: task.OrchestrationContext, _): print("If you are using bash, run the following: export ENDPOINT=\"\"") exit() +# Note that any azure-identity credential type and configuration can be used here as DTS supports various credential
+# types such as Managed Identities credential = DefaultAzureCredential() # configure and start the worker