Skip to content

Commit 552a2dd

Browse files
committed
Working orchestrators + activities
1 parent 7ff1525 commit 552a2dd

File tree

10 files changed

+306
-62
lines changed

10 files changed

+306
-62
lines changed
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import json
5+
6+
from datetime import timedelta
7+
from typing import Any, Optional
8+
import azure.functions as func
9+
10+
from durabletask.entities import EntityInstanceId
11+
from durabletask.client import TaskHubGrpcClient
12+
from durabletask.azurefunctions.internal.azurefunctions_grpc_interceptor import AzureFunctionsDefaultClientInterceptorImpl
13+
14+
15+
# Client class used for Durable Functions
16+
class DurableFunctionsClient(TaskHubGrpcClient):
17+
taskHubName: str
18+
connectionName: str
19+
creationUrls: dict[str, str]
20+
managementUrls: dict[str, str]
21+
baseUrl: str
22+
requiredQueryStringParameters: str
23+
rpcBaseUrl: str
24+
httpBaseUrl: str
25+
maxGrpcMessageSizeInBytes: int
26+
grpcHttpClientTimeout: timedelta
27+
28+
def __init__(self, client_as_string: str):
29+
client = json.loads(client_as_string)
30+
31+
self.taskHubName = client.get("taskHubName", "")
32+
self.connectionName = client.get("connectionName", "")
33+
self.creationUrls = client.get("creationUrls", {})
34+
self.managementUrls = client.get("managementUrls", {})
35+
self.baseUrl = client.get("baseUrl", "")
36+
self.requiredQueryStringParameters = client.get("requiredQueryStringParameters", "")
37+
self.rpcBaseUrl = client.get("rpcBaseUrl", "")
38+
self.httpBaseUrl = client.get("httpBaseUrl", "")
39+
self.maxGrpcMessageSizeInBytes = client.get("maxGrpcMessageSizeInBytes", 0)
40+
# TODO: convert the string value back to timedelta - annoying regex?
41+
self.grpcHttpClientTimeout = client.get("grpcHttpClientTimeout", timedelta(seconds=30))
42+
interceptors = [AzureFunctionsDefaultClientInterceptorImpl(self.taskHubName, self.requiredQueryStringParameters)]
43+
44+
# We pass in None for the metadata so we don't construct an additional interceptor in the parent class
45+
# Since the parent class doesn't use anything metadata for anything else, we can set it as None
46+
super().__init__(
47+
host_address=self.rpcBaseUrl,
48+
secure_channel=False,
49+
metadata=None,
50+
interceptors=interceptors)
51+
52+
def create_check_status_response(self, request: func.HttpRequest, instance_id: str) -> func.HttpResponse:
53+
"""Creates an HTTP response for checking the status of a Durable Function instance.
54+
55+
Args:
56+
request (func.HttpRequest): The incoming HTTP request.
57+
instance_id (str): The ID of the Durable Function instance.
58+
"""
59+
raise NotImplementedError("This method is not implemented yet.")
60+
61+
def create_http_management_payload(self, instance_id: str) -> dict[str, str]:
62+
"""Creates an HTTP management payload for a Durable Function instance.
63+
64+
Args:
65+
instance_id (str): The ID of the Durable Function instance.
66+
"""
67+
raise NotImplementedError("This method is not implemented yet.")
68+
69+
def read_entity_state(
70+
self,
71+
entity_id: EntityInstanceId,
72+
task_hub_name: Optional[str],
73+
connection_name: Optional[str]
74+
) -> tuple[bool, Any]:
75+
"""Reads the state of a Durable Entity.
76+
77+
Args:
78+
entity_id (str): The ID of the Durable Entity.
79+
task_hub_name (Optional[str]): The name of the task hub.
80+
connection_name (Optional[str]): The name of the connection.
81+
82+
Returns:
83+
(bool, Any): A tuple containing a boolean indicating if the entity exists and its state.
84+
"""
85+
raise NotImplementedError("This method is not implemented yet.")

durabletask-azurefunctions/durabletask/azurefunctions/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Constants used to determine the local running context."""
2-
# Todo: Remove unused constants after module is complete
2+
# TODO: Remove unused constants after module is complete
33
DEFAULT_LOCAL_HOST: str = 'localhost:7071'
44
DEFAULT_LOCAL_ORIGIN: str = f'http://{DEFAULT_LOCAL_HOST}'
55
DATETIME_STRING_FORMAT = '%Y-%m-%dT%H:%M:%S.%fZ'

durabletask-azurefunctions/durabletask/azurefunctions/decorators/durable_app.py

Lines changed: 82 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
11
# Copyright (c) Microsoft Corporation. All rights reserved.
22
# Licensed under the MIT License.
3+
import base64
4+
from functools import wraps
5+
6+
from durabletask.internal.orchestrator_service_pb2 import OrchestratorRequest, OrchestratorResponse
37
from .metadata import OrchestrationTrigger, ActivityTrigger, EntityTrigger, \
48
DurableClient
59
from typing import Callable, Optional
610
from typing import Union
7-
from azure.functions import FunctionRegister, TriggerApi, BindingApi, AuthLevel, OrchestrationContext
11+
from azure.functions import FunctionRegister, TriggerApi, BindingApi, AuthLevel
12+
13+
# TODO: Use __init__.py to optimize imports
14+
from durabletask.azurefunctions.client import DurableFunctionsClient
15+
from durabletask.azurefunctions.worker import DurableFunctionsWorker
16+
from durabletask.azurefunctions.internal.azurefunctions_null_stub import AzureFunctionsNullStub
817

918

1019
class Blueprint(TriggerApi, BindingApi):
@@ -37,9 +46,6 @@ def __init__(self,
3746
def _configure_orchestrator_callable(self, wrap) -> Callable:
3847
"""Obtain decorator to construct an Orchestrator class from a user-defined Function.
3948
40-
In the old programming model, this decorator's logic was unavoidable boilerplate
41-
in user-code. Now, this is handled internally by the framework.
42-
4349
Parameters
4450
----------
4551
wrap: Callable
@@ -54,14 +60,31 @@ def _configure_orchestrator_callable(self, wrap) -> Callable:
5460
def decorator(orchestrator_func):
5561
# Construct an orchestrator based on the end-user code
5662

57-
# TODO: Extract this logic (?)
58-
def handle(context: OrchestrationContext) -> str:
63+
# TODO: Move this logic somewhere better
64+
def handle(context) -> str:
5965
context_body = getattr(context, "body", None)
6066
if context_body is None:
6167
context_body = context
6268
orchestration_context = context_body
63-
# TODO: Run the orchestration using the context
64-
return ""
69+
request = OrchestratorRequest()
70+
request.ParseFromString(base64.b64decode(orchestration_context))
71+
stub = AzureFunctionsNullStub()
72+
worker = DurableFunctionsWorker()
73+
response: Optional[OrchestratorResponse] = None
74+
75+
def stub_complete(stub_response):
76+
nonlocal response
77+
response = stub_response
78+
stub.CompleteOrchestratorTask = stub_complete
79+
execution_started_events = [e for e in [e1 for e1 in request.newEvents] + [e2 for e2 in request.pastEvents] if e.HasField("executionStarted")]
80+
function_name = execution_started_events[-1].executionStarted.name
81+
worker.add_named_orchestrator(function_name, orchestrator_func)
82+
worker._execute_orchestrator(request, stub, None)
83+
84+
if response is None:
85+
raise Exception("Orchestrator execution did not produce a response.")
86+
# The Python worker returns the input as type "json", so double-encoding is necessary
87+
return '"' + base64.b64encode(response.SerializeToString()).decode('utf-8') + '"'
6588

6689
handle.orchestrator_function = orchestrator_func
6790

@@ -71,6 +94,55 @@ def handle(context: OrchestrationContext) -> str:
7194

7295
return decorator
7396

97+
def _configure_entity_callable(self, wrap) -> Callable:
98+
"""Obtain decorator to construct an Entity class from a user-defined Function.
99+
100+
Parameters
101+
----------
102+
wrap: Callable
103+
The next decorator to be applied.
104+
105+
Returns
106+
-------
107+
Callable
108+
The function to construct an Entity class from the user-defined Function,
109+
wrapped by the next decorator in the sequence.
110+
"""
111+
def decorator(entity_func):
112+
# TODO: Implement entity support - similar to orchestrators (?)
113+
raise NotImplementedError()
114+
115+
return decorator
116+
117+
def _add_rich_client(self, fb, parameter_name,
118+
client_constructor):
119+
# Obtain user-code and force type annotation on the client-binding parameter to be `str`.
120+
# This ensures a passing type-check of that specific parameter,
121+
# circumventing a limitation of the worker in type-checking rich DF Client objects.
122+
# TODO: Once rich-binding type checking is possible, remove the annotation change.
123+
user_code = fb._function._func
124+
user_code.__annotations__[parameter_name] = str
125+
126+
# `wraps` This ensures we re-export the same method-signature as the decorated method
127+
@wraps(user_code)
128+
async def df_client_middleware(*args, **kwargs):
129+
130+
# Obtain JSON-string currently passed as DF Client,
131+
# construct rich object from it,
132+
# and assign parameter to that rich object
133+
starter = kwargs[parameter_name]
134+
client = client_constructor(starter)
135+
kwargs[parameter_name] = client
136+
137+
# Invoke user code with rich DF Client binding
138+
return await user_code(*args, **kwargs)
139+
140+
# TODO: Is there a better way to support retrieving the unwrapped user code?
141+
df_client_middleware.client_function = fb._function._func # type: ignore
142+
143+
user_code_with_rich_client = df_client_middleware
144+
fb._function._func = user_code_with_rich_client
145+
74146
def orchestration_trigger(self, context_name: str,
75147
orchestration: Optional[str] = None):
76148
"""Register an Orchestrator Function.
@@ -133,6 +205,7 @@ def entity_trigger(self, context_name: str,
133205
Name of Entity Function.
134206
The value is None by default, in which case the name of the method is used.
135207
"""
208+
@self._configure_entity_callable
136209
@self._configure_function_builder
137210
def wrap(fb):
138211
def decorator():
@@ -171,7 +244,7 @@ def durable_client_input(self,
171244
@self._configure_function_builder
172245
def wrap(fb):
173246
def decorator():
174-
# self._add_rich_client(fb, client_name, DurableOrchestrationClient)
247+
self._add_rich_client(fb, client_name, DurableFunctionsClient)
175248

176249
fb.add_binding(
177250
binding=DurableClient(name=client_name,

durabletask-azurefunctions/durabletask/azurefunctions/internal/DurableClientConverter.py

Lines changed: 0 additions & 46 deletions
This file was deleted.
Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +0,0 @@
1-
from .DurableClientConverter import DurableClientConverter
2-
3-
__all__ = ["DurableClientConverter"]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from importlib.metadata import version
5+
6+
from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl
7+
8+
9+
class AzureFunctionsDefaultClientInterceptorImpl (DefaultClientInterceptorImpl):
10+
"""The class implements a UnaryUnaryClientInterceptor, UnaryStreamClientInterceptor,
11+
StreamUnaryClientInterceptor and StreamStreamClientInterceptor from grpc to add an
12+
interceptor to add additional headers to all calls as needed."""
13+
required_query_string_parameters: str
14+
15+
def __init__(self, taskhub_name: str, required_query_string_parameters: str):
16+
self.required_query_string_parameters = required_query_string_parameters
17+
try:
18+
# Get the version of the azurefunctions package
19+
sdk_version = version('durabletask-azurefunctions')
20+
except Exception:
21+
# Fallback if version cannot be determined
22+
sdk_version = "unknown"
23+
user_agent = f"durabletask-python/{sdk_version}"
24+
self._metadata = [
25+
("taskhub", taskhub_name),
26+
("x-user-agent", user_agent)] # 'user-agent' is a reserved header in grpc, so we use 'x-user-agent' instead
27+
super().__init__(self._metadata)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
2+
from durabletask.internal.ProtoTaskHubSidecarServiceStub import ProtoTaskHubSidecarServiceStub
3+
4+
5+
class AzureFunctionsNullStub(ProtoTaskHubSidecarServiceStub):
6+
"""Missing associated documentation comment in .proto file."""
7+
8+
def __init__(self):
9+
"""Constructor.
10+
11+
Args:
12+
channel: A grpc.Channel.
13+
"""
14+
self.Hello = lambda *args, **kwargs: None
15+
self.StartInstance = lambda *args, **kwargs: None
16+
self.GetInstance = lambda *args, **kwargs: None
17+
self.RewindInstance = lambda *args, **kwargs: None
18+
self.WaitForInstanceStart = lambda *args, **kwargs: None
19+
self.WaitForInstanceCompletion = lambda *args, **kwargs: None
20+
self.RaiseEvent = lambda *args, **kwargs: None
21+
self.TerminateInstance = lambda *args, **kwargs: None
22+
self.SuspendInstance = lambda *args, **kwargs: None
23+
self.ResumeInstance = lambda *args, **kwargs: None
24+
self.QueryInstances = lambda *args, **kwargs: None
25+
self.PurgeInstances = lambda *args, **kwargs: None
26+
self.GetWorkItems = lambda *args, **kwargs: None
27+
self.CompleteActivityTask = lambda *args, **kwargs: None
28+
self.CompleteOrchestratorTask = lambda *args, **kwargs: None
29+
self.CompleteEntityTask = lambda *args, **kwargs: None
30+
self.StreamInstanceHistory = lambda *args, **kwargs: None
31+
self.CreateTaskHub = lambda *args, **kwargs: None
32+
self.DeleteTaskHub = lambda *args, **kwargs: None
33+
self.SignalEntity = lambda *args, **kwargs: None
34+
self.GetEntity = lambda *args, **kwargs: None
35+
self.QueryEntities = lambda *args, **kwargs: None
36+
self.CleanEntityStorage = lambda *args, **kwargs: None
37+
self.AbandonTaskActivityWorkItem = lambda *args, **kwargs: None
38+
self.AbandonTaskOrchestratorWorkItem = lambda *args, **kwargs: None
39+
self.AbandonTaskEntityWorkItem = lambda *args, **kwargs: None
Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,32 @@
1-
class TempClass:
2-
pass
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
from threading import Event
5+
from durabletask.worker import _Registry, ConcurrencyOptions
6+
from durabletask.internal import shared
7+
from durabletask.worker import TaskHubGrpcWorker
8+
9+
10+
# Worker class used for Durable Task Scheduler (DTS)
11+
class DurableFunctionsWorker(TaskHubGrpcWorker):
12+
"""TOOD: Docs
13+
"""
14+
15+
def __init__(self):
16+
# Don't call the parent constructor - we don't actually want to start an AsyncWorkerLoop
17+
# or recieve work items from anywhere but the method that is creating this worker
18+
self._registry = _Registry()
19+
self._host_address = ""
20+
self._logger = shared.get_logger("worker")
21+
self._shutdown = Event()
22+
self._is_running = False
23+
self._secure_channel = False
24+
25+
self._concurrency_options = ConcurrencyOptions()
26+
27+
self._interceptors = None
28+
29+
def add_named_orchestrator(self, name: str, func):
30+
"""TOOD: Docs
31+
"""
32+
self._registry.add_named_orchestrator(name, func)

0 commit comments

Comments
 (0)