Skip to content

Commit b460110

Browse files
committed
Switch to ActivityCallTracker
1 parent 5f6ad04 commit b460110

File tree

6 files changed

+57
-37
lines changed

6 files changed

+57
-37
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from azure.durable_functions.models.DurableOrchestrationContext import DurableOrchestrationContext
2+
from yield_exception import YieldException
3+
4+
class ActivityCallTracker:
5+
def __init__(self, context: DurableOrchestrationContext):
6+
self.context = context
7+
self.activities_called = 0
8+
self.tasks_to_yield = []
9+
10+
def call_activity(self, activity_name, input: str):
11+
task = self.context.call_activity(activity_name, input)
12+
13+
self.activities_called += 1
14+
15+
histories = self.context.histories
16+
completed_tasks = [entry for entry in histories if entry.event_type == 5]
17+
if len(completed_tasks) < self.activities_called:
18+
# yield immediately
19+
raise YieldException(task)
20+
else:
21+
# yield later
22+
self.tasks_to_yield.append(task)
23+
return completed_tasks[self.activities_called - 1].Result

samples-v2/openai_agents/durable_model_stub.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
HandoffInput,
99
AgentOutputSchemaInput
1010
)
11-
from orchestrator_exceptions import OrchestratorYielded
11+
from activity_call_tracker import ActivityCallTracker
1212
import azure.durable_functions as df
1313

1414
import json
@@ -44,10 +44,10 @@ class _DurableModelStub(Model):
4444
def __init__(
4545
self,
4646
model_name: Optional[str],
47-
durable_orchestration_context: df.DurableOrchestrationContext,
47+
activity_call_tracker: ActivityCallTracker,
4848
) -> None:
4949
self.model_name = model_name
50-
self.durable_orchestration_context = durable_orchestration_context
50+
self.activity_call_tracker = activity_call_tracker
5151

5252
async def get_response(
5353
self,
@@ -129,18 +129,9 @@ def make_tool_info(tool: Tool) -> ToolInput:
129129
prompt=prompt,
130130
)
131131

132-
# Serialize activity_input to JSON
133132
activity_input_json = activity_input.to_json()
134133

135-
task = self.durable_orchestration_context.call_activity(
136-
"invoke_model_activity",
137-
activity_input_json
138-
)
139-
140-
if not self.durable_orchestration_context.is_replaying:
141-
raise OrchestratorYielded(task)
142-
143-
result = task.result
134+
result = self.activity_call_tracker.call_activity("invoke_model_activity", activity_input_json)
144135
return result
145136

146137
def stream_response(

samples-v2/openai_agents/durable_openai_runner.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import logging
33
from typing import Any, Callable, Optional
4+
from activity_call_tracker import ActivityCallTracker
45
import azure.functions as func
56
import azure.durable_functions as df
67

@@ -27,9 +28,9 @@
2728
logger = logging.getLogger(__name__)
2829

2930
class DurableOpenAIRunner:
30-
def __init__(self, durable_orchestration_context: df.DurableOrchestrationContext) -> None:
31+
def __init__(self, activity_call_tracker: ActivityCallTracker) -> None:
3132
self._runner = DEFAULT_AGENT_RUNNER or AgentRunner()
32-
self.durable_orchestration_context = durable_orchestration_context
33+
self.activity_call_tracker = activity_call_tracker
3334

3435
def run_sync(
3536
self,
@@ -60,9 +61,9 @@ def run_sync(
6061

6162
updated_run_config = replace(
6263
run_config,
63-
model=_DurableModelStub(
64-
model_name=model_name,
65-
durable_orchestration_context=self.durable_orchestration_context
64+
model = _DurableModelStub(
65+
model_name = model_name,
66+
activity_call_tracker = self.activity_call_tracker,
6667
),
6768
)
6869

samples-v2/openai_agents/function_app.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,6 @@
22
import azure.functions as func
33
import logging
44

5-
from agents.run import set_default_agent_runner
6-
from durable_openai_runner import DurableOpenAIRunner
7-
from orchestrator_exceptions import OrchestratorYielded
8-
95
from agents import (
106
AgentOutputSchemaBase,
117
CodeInterpreterTool,
@@ -26,6 +22,11 @@
2622
WebSearchTool,
2723
)
2824

25+
from agents.run import set_default_agent_runner
26+
from durable_openai_runner import DurableOpenAIRunner
27+
from yield_exception import YieldException
28+
from activity_call_tracker import ActivityCallTracker
29+
2930
app = func.FunctionApp(http_auth_level=func.AuthLevel.FUNCTION)
3031

3132
@app.route(route="orchestrators/{functionName}")
@@ -39,14 +40,24 @@ async def hello_orchestration_starter(req: func.HttpRequest, client):
3940

4041
@app.orchestration_trigger(context_name="context")
4142
def basic_hello_world_orchestrator(context):
42-
set_default_agent_runner(DurableOpenAIRunner(durable_orchestration_context=context))
43+
activity_call_tracker = ActivityCallTracker(context)
44+
durable_openai_runner = DurableOpenAIRunner(activity_call_tracker=activity_call_tracker)
45+
set_default_agent_runner(durable_openai_runner)
4346

4447
try:
4548
from basic.hello_world import main
4649
result = main()
4750
return result
48-
except OrchestratorYielded as e:
49-
yield e.activity_output
51+
except YieldException as e:
52+
for task in activity_call_tracker.tasks_to_yield:
53+
yield task
54+
activity_call_tracker.tasks_to_yield.clear()
55+
yield e.task
56+
finally:
57+
for task in activity_call_tracker.tasks_to_yield:
58+
yield task
59+
activity_call_tracker.tasks_to_yield.clear()
60+
5061

5162
@app.activity_trigger(input_name="input")
5263
async def invoke_model_activity(input: str):

samples-v2/openai_agents/orchestrator_exceptions.py

Lines changed: 0 additions & 12 deletions
This file was deleted.
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from azure.durable_functions.models.Task import TaskBase
2+
3+
class YieldException(Exception):
4+
def __init__(self, task: TaskBase):
5+
super().__init__("Orchestrator should yield.")
6+
self.task = task

0 commit comments

Comments
 (0)