Skip to content

Commit f5de04f

Browse files
andystaplesCopilot
andauthored
Allow calling sub-orchestrator by name (#69)
* Allow calling sub-orchestrator by name * Apply suggestions from code review Co-authored-by: Copilot <[email protected]> --------- Co-authored-by: Copilot <[email protected]>
1 parent eba1e8e commit f5de04f

File tree

4 files changed

+60
-3
lines changed

4 files changed

+60
-3
lines changed

durabletask/task.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> Task[EntityLock]:
201201
pass
202202

203203
@abstractmethod
204-
def call_sub_orchestrator(self, orchestrator: Orchestrator[TInput, TOutput], *,
204+
def call_sub_orchestrator(self, orchestrator: Union[Orchestrator[TInput, TOutput], str], *,
205205
input: Optional[TInput] = None,
206206
instance_id: Optional[str] = None,
207207
retry_policy: Optional[RetryPolicy] = None,

durabletask/worker.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,15 +1029,18 @@ def lock_entities(self, entities: list[EntityInstanceId]) -> task.Task[EntityLoc
10291029

10301030
def call_sub_orchestrator(
10311031
self,
1032-
orchestrator: task.Orchestrator[TInput, TOutput],
1032+
orchestrator: Union[task.Orchestrator[TInput, TOutput], str],
10331033
*,
10341034
input: Optional[TInput] = None,
10351035
instance_id: Optional[str] = None,
10361036
retry_policy: Optional[task.RetryPolicy] = None,
10371037
version: Optional[str] = None,
10381038
) -> task.Task[TOutput]:
10391039
id = self.next_sequence_number()
1040-
orchestrator_name = task.get_name(orchestrator)
1040+
if isinstance(orchestrator, str):
1041+
orchestrator_name = orchestrator
1042+
else:
1043+
orchestrator_name = task.get_name(orchestrator)
10411044
default_version = self._registry.versioning.default_version if self._registry.versioning else None
10421045
orchestrator_version = version if version else default_version
10431046
self.call_activity_function_helper(

tests/durabletask-azuremanaged/test_dts_orchestration_e2e.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,34 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int):
175175
assert activity_counter == 30
176176

177177

178+
def test_sub_orchestrator_by_name():
179+
sub_orchestrator_counter = 0
180+
181+
def orchestrator_child(ctx: task.OrchestrationContext, _):
182+
nonlocal sub_orchestrator_counter
183+
sub_orchestrator_counter += 1
184+
185+
def parent_orchestrator(ctx: task.OrchestrationContext, _):
186+
yield ctx.call_sub_orchestrator("orchestrator_child")
187+
188+
# Start a worker, which will connect to the sidecar in a background thread
189+
with DurableTaskSchedulerWorker(host_address=endpoint, secure_channel=True,
190+
taskhub=taskhub_name, token_credential=None) as w:
191+
w.add_orchestrator(orchestrator_child)
192+
w.add_orchestrator(parent_orchestrator)
193+
w.start()
194+
195+
task_hub_client = DurableTaskSchedulerClient(host_address=endpoint, secure_channel=True,
196+
taskhub=taskhub_name, token_credential=None)
197+
id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=None)
198+
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
199+
200+
assert state is not None
201+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
202+
assert state.failure_details is None
203+
assert sub_orchestrator_counter == 1
204+
205+
178206
def test_wait_for_multiple_external_events():
179207
def orchestrator(ctx: task.OrchestrationContext, _):
180208
a = yield ctx.wait_for_external_event('A')

tests/durabletask/test_orchestration_e2e.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,32 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int):
162162
assert activity_counter == 30
163163

164164

165+
def test_sub_orchestrator_by_name():
166+
sub_orchestrator_counter = 0
167+
168+
def orchestrator_child(ctx: task.OrchestrationContext, _):
169+
nonlocal sub_orchestrator_counter
170+
sub_orchestrator_counter += 1
171+
172+
def parent_orchestrator(ctx: task.OrchestrationContext, _):
173+
yield ctx.call_sub_orchestrator("orchestrator_child")
174+
175+
# Start a worker, which will connect to the sidecar in a background thread
176+
with worker.TaskHubGrpcWorker() as w:
177+
w.add_orchestrator(orchestrator_child)
178+
w.add_orchestrator(parent_orchestrator)
179+
w.start()
180+
181+
task_hub_client = client.TaskHubGrpcClient()
182+
id = task_hub_client.schedule_new_orchestration(parent_orchestrator, input=None)
183+
state = task_hub_client.wait_for_orchestration_completion(id, timeout=30)
184+
185+
assert state is not None
186+
assert state.runtime_status == client.OrchestrationStatus.COMPLETED
187+
assert state.failure_details is None
188+
assert sub_orchestrator_counter == 1
189+
190+
165191
def test_wait_for_multiple_external_events():
166192
def orchestrator(ctx: task.OrchestrationContext, _):
167193
a = yield ctx.wait_for_external_event('A')

0 commit comments

Comments
 (0)