
Commit acc41aa

Add _DurableModelStub
1 parent 5fa7e69 commit acc41aa

3 files changed: +171 −12 lines
samples-v2/openai_agents/durable_model_stub.py

Lines changed: 141 additions & 0 deletions (new file)
from __future__ import annotations

import logging
from typing import Any, AsyncIterator, Optional, Union, cast

import azure.durable_functions as df
from agents import (
    AgentOutputSchema,
    AgentOutputSchemaBase,
    CodeInterpreterTool,
    FileSearchTool,
    FunctionTool,
    Handoff,
    HostedMCPTool,
    ImageGenerationTool,
    Model,
    ModelResponse,
    ModelSettings,
    ModelTracing,
    Tool,
    TResponseInputItem,
    WebSearchTool,
)
from agents.items import TResponseStreamEvent
from openai.types.responses.response_prompt_param import ResponsePromptParam

# NOTE: ToolInput, FunctionToolInput, HostedMCPToolInput, HandoffInput,
# AgentOutputSchemaInput, ActivityModelInput, and ModelTracingInput are referenced
# below but are not defined or imported in this commit; they are expected to come
# from a serializable-input module that is not part of this change.

logger = logging.getLogger(__name__)


class _DurableModelStub(Model):
    """Model implementation that forwards LLM calls to a Durable Functions activity
    instead of calling the model provider directly from the orchestrator."""

    def __init__(
        self,
        model_name: Optional[str],
        durable_orchestration_context: df.DurableOrchestrationContext,
    ) -> None:
        self.model_name = model_name
        self.durable_orchestration_context = durable_orchestration_context

    async def get_response(
        self,
        system_instructions: Optional[str],
        input: Union[str, list[TResponseInputItem]],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: Optional[AgentOutputSchemaBase],
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: Optional[str],
        prompt: Optional[ResponsePromptParam],
    ) -> ModelResponse:
        def make_tool_info(tool: Tool) -> ToolInput:
            # Hosted tools are already serializable and are passed through as-is;
            # MCP and function tools are converted to their input representations.
            if isinstance(
                tool,
                (
                    FileSearchTool,
                    WebSearchTool,
                    ImageGenerationTool,
                    CodeInterpreterTool,
                ),
            ):
                return tool
            elif isinstance(tool, HostedMCPTool):
                return HostedMCPToolInput(tool_config=tool.tool_config)
            elif isinstance(tool, FunctionTool):
                return FunctionToolInput(
                    name=tool.name,
                    description=tool.description,
                    params_json_schema=tool.params_json_schema,
                    strict_json_schema=tool.strict_json_schema,
                )
            else:
                raise ValueError(f"Unsupported tool type: {tool.name}")

        tool_infos = [make_tool_info(x) for x in tools]
        handoff_infos = [
            HandoffInput(
                tool_name=x.tool_name,
                tool_description=x.tool_description,
                input_json_schema=x.input_json_schema,
                agent_name=x.agent_name,
                strict_json_schema=x.strict_json_schema,
            )
            for x in handoffs
        ]
        if output_schema is not None and not isinstance(
            output_schema, AgentOutputSchema
        ):
            raise TypeError(
                f"Only AgentOutputSchema is supported by Durable Model, got {type(output_schema).__name__}"
            )
        agent_output_schema = output_schema
        output_schema_input = (
            None
            if agent_output_schema is None
            else AgentOutputSchemaInput(
                output_type_name=agent_output_schema.name(),
                is_wrapped=agent_output_schema._is_wrapped,
                output_schema=agent_output_schema.json_schema()
                if not agent_output_schema.is_plain_text()
                else None,
                strict_json_schema=agent_output_schema.is_strict_json_schema(),
            )
        )

        activity_input = ActivityModelInput(
            model_name=self.model_name,
            system_instructions=system_instructions,
            input=cast(Union[str, list[TResponseInputItem]], input),
            model_settings=model_settings,
            tools=tool_infos,
            output_schema=output_schema_input,
            handoffs=handoff_infos,
            tracing=ModelTracingInput(tracing.value),
            previous_response_id=previous_response_id,
            prompt=prompt,
        )

        # NOTE: call_activity returns a Durable Functions Task rather than a
        # ModelResponse; resolving that task to the result of "InvokeModelActivity"
        # is not wired up in this commit.
        activity_output = self.durable_orchestration_context.call_activity(
            "InvokeModelActivity",
            activity_input,
        )

        return activity_output

    def stream_response(
        self,
        system_instructions: Optional[str],
        input: Union[str, list[TResponseInputItem]],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: Optional[AgentOutputSchemaBase],
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: Optional[str],
        prompt: Optional[ResponsePromptParam],
    ) -> AsyncIterator[TResponseStreamEvent]:
        raise NotImplementedError("Durable model doesn't support streams yet")
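
Note: the serializable input types referenced above (ToolInput, FunctionToolInput, HostedMCPToolInput, HandoffInput, AgentOutputSchemaInput, ActivityModelInput, ModelTracingInput) are not defined or imported anywhere in this change. Purely as an illustration of the shape the stub assumes, inferred from the fields it populates rather than from any actual definition, they could look something like the following dataclasses:

# Illustrative sketch only -- these definitions are NOT part of this commit.
# The shapes below are inferred from the fields _DurableModelStub populates.
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Optional, Union

from agents import (
    CodeInterpreterTool,
    FileSearchTool,
    ImageGenerationTool,
    ModelSettings,
    TResponseInputItem,
    WebSearchTool,
)


@dataclass
class FunctionToolInput:
    name: str
    description: str
    params_json_schema: dict[str, Any]
    strict_json_schema: bool


@dataclass
class HostedMCPToolInput:
    tool_config: Any


# Hosted tools are forwarded unchanged, so ToolInput is a union of the hosted tool
# classes and the serializable wrappers above.
ToolInput = Union[
    FileSearchTool,
    WebSearchTool,
    ImageGenerationTool,
    CodeInterpreterTool,
    HostedMCPToolInput,
    FunctionToolInput,
]


@dataclass
class HandoffInput:
    tool_name: str
    tool_description: str
    input_json_schema: dict[str, Any]
    agent_name: str
    strict_json_schema: bool


@dataclass
class AgentOutputSchemaInput:
    output_type_name: str
    is_wrapped: bool
    output_schema: Optional[dict[str, Any]]
    strict_json_schema: bool


@dataclass
class ModelTracingInput:
    value: int  # the integer value of the ModelTracing enum member


@dataclass
class ActivityModelInput:
    model_name: Optional[str]
    system_instructions: Optional[str]
    input: Union[str, list[TResponseInputItem]]
    model_settings: ModelSettings
    tools: list[ToolInput]
    output_schema: Optional[AgentOutputSchemaInput]
    handoffs: list[HandoffInput]
    tracing: ModelTracingInput
    previous_response_id: Optional[str]
    prompt: Optional[Any]

How these objects are JSON-serialized for call_activity is also left open by this commit.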

samples-v2/openai_agents/durable_openai_runner.py

Lines changed: 28 additions & 8 deletions
@@ -2,6 +2,7 @@
 import logging
 from typing import Any, Callable, Optional
 import azure.functions as func
+import azure.durable_functions as df
 
 import json
 import typing
@@ -21,11 +22,14 @@
 from agents.run import DEFAULT_AGENT_RUNNER, DEFAULT_MAX_TURNS, AgentRunner
 from pydantic_core import to_json
 
+from durable_model_stub import _DurableModelStub
+
 logger = logging.getLogger(__name__)
 
 class DurableOpenAIRunner:
-    def __init__(self) -> None:
+    def __init__(self, durable_orchestration_context: df.DurableOrchestrationContext) -> None:
         self._runner = DEFAULT_AGENT_RUNNER or AgentRunner()
+        self.durable_orchestration_context = durable_orchestration_context
 
     def run_sync(
         self,
@@ -53,13 +57,14 @@ def run_sync(
             raise ValueError(
                 "Durable Functions require a model name to be a string in the run config and/or agent."
             )
-        # updated_run_config = replace(
-        #     run_config,
-        #     model=_TemporalModelStub(
-        #         model_name=model_name
-        #     ),
-        # )
-        updated_run_config = run_config
+
+        updated_run_config = replace(
+            run_config,
+            model=_DurableModelStub(
+                model_name=model_name,
+                durable_orchestration_context=self.durable_orchestration_context
+            ),
+        )
 
         return self._runner.run_sync(
             starting_agent=starting_agent,
@@ -87,3 +92,18 @@ def run_streamed(
         **kwargs: Any,
     ) -> RunResultStreaming:
         raise RuntimeError("Durable Functions do not support streaming.")
+    def run(
+        self,
+        starting_agent: Agent[TContext],
+        input: Union[str, list[TResponseInputItem]],
+        **kwargs: Any,
+    ) -> RunResult:
+        raise RuntimeError("Durable Functions do not support asynchronous runs.")
+
+    def run_streamed(
+        self,
+        starting_agent: Agent[TContext],
+        input: Union[str, list[TResponseInputItem]],
+        **kwargs: Any,
+    ) -> RunResultStreaming:
+        raise RuntimeError("Durable Functions do not support streaming.")
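
Note (illustrative, not part of this diff): once an orchestrator installs this runner via set_default_agent_runner, agent code keeps calling the public Runner API unchanged; a call like the one below is routed through run_sync above, which swaps the configured model for _DurableModelStub before delegating to the stock AgentRunner. The agent name, instructions, and model below are hypothetical examples.

# Hypothetical usage, assuming set_default_agent_runner(DurableOpenAIRunner(...))
# has already been called inside the orchestrator (see function_app.py below).
from agents import Agent, Runner

agent = Agent(
    name="Assistant",
    instructions="Reply with one short sentence.",
    model="gpt-4o-mini",  # run_sync requires the model name to be a string
)
result = Runner.run_sync(agent, "Say hello from a durable orchestration.")
print(result.final_output)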

samples-v2/openai_agents/function_app.py

Lines changed: 2 additions & 4 deletions
@@ -4,9 +4,6 @@
 from agents.run import set_default_agent_runner
 from durable_openai_runner import DurableOpenAIRunner
 
-set_default_agent_runner(DurableOpenAIRunner())
-
-
 app = func.FunctionApp(http_auth_level=func.AuthLevel.FUNCTION)
 
 @app.route(route="orchestrators/{functionName}")
@@ -20,7 +17,8 @@ async def hello_orchestration_starter(req: func.HttpRequest, client):
 
 @app.orchestration_trigger(context_name="context")
 def basic_hello_world_orchestrator(context):
+    set_default_agent_runner(DurableOpenAIRunner(durable_orchestration_context=context))
+
     from basic.hello_world import main
     result = main()
     return result
-
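
Note: _DurableModelStub schedules an activity named "InvokeModelActivity", but no such activity is registered in this commit. Assuming the same v2 programming-model decorators already used in function_app.py, a follow-up registration might look roughly like the placeholder below; the actual model invocation is not part of this change.

# Hypothetical sketch -- "InvokeModelActivity" is referenced by _DurableModelStub
# but is not registered anywhere in this commit.
@app.activity_trigger(input_name="activity_input")
def InvokeModelActivity(activity_input):
    # Deserialize the ActivityModelInput payload, perform the real (non-deterministic)
    # model call here outside the orchestrator, and return a serializable result.
    ...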
