Skip to content

Commit 2de5f84

Browse files
committed
Add _invoke_model_activity
1 parent f9ae167 commit 2de5f84

File tree

1 file changed

+279
-0
lines changed

1 file changed

+279
-0
lines changed
Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
"""A durable activity that invokes a LLM model.
2+
3+
Implements mapping of OpenAI datastructures to Pydantic friendly types.
4+
"""
5+
6+
import enum
7+
import json
8+
from dataclasses import dataclass
9+
from datetime import timedelta
10+
from typing import Any, Optional, Union
11+
12+
from agents import (
13+
AgentOutputSchemaBase,
14+
CodeInterpreterTool,
15+
FileSearchTool,
16+
FunctionTool,
17+
Handoff,
18+
HostedMCPTool,
19+
ImageGenerationTool,
20+
ModelProvider,
21+
ModelResponse,
22+
ModelSettings,
23+
ModelTracing,
24+
OpenAIProvider,
25+
RunContextWrapper,
26+
Tool,
27+
TResponseInputItem,
28+
UserError,
29+
WebSearchTool,
30+
)
31+
from openai import (
32+
APIStatusError,
33+
AsyncOpenAI,
34+
)
35+
from openai.types.responses.tool_param import Mcp
36+
from pydantic_core import to_json
37+
try:
    from azure.durable_functions import ApplicationError
except ImportError:
    # Fallback if ApplicationError is not available
    class ApplicationError(Exception):
        """Minimal stand-in for ``azure.durable_functions.ApplicationError``.

        Mirrors the fields this module relies on so the rest of the code can
        raise it unconditionally, whether or not the durable runtime is
        installed.
        """

        def __init__(
            self,
            message: str,
            non_retryable: bool = False,
            next_retry_delay: Optional[timedelta] = None,
        ):
            """Store retry hints alongside the standard exception message.

            Args:
                message: Human-readable error description.
                non_retryable: True when the caller should not retry the
                    failed activity.
                next_retry_delay: Optional server-suggested delay before the
                    next retry attempt.
            """
            super().__init__(message)
            self.non_retryable = non_retryable
            self.next_retry_delay = next_retry_delay
46+
47+
48+
@dataclass
class HandoffInput:
    """Data conversion friendly representation of a Handoff. Contains only the fields which are needed by the model
    execution to determine what to handoff to, not the actual handoff invocation, which remains in the workflow context.
    """

    # Name of the tool the model calls to trigger this handoff.
    tool_name: str
    # Description presented to the model for the handoff tool.
    tool_description: str
    # JSON schema for the arguments the model supplies when invoking the handoff.
    input_json_schema: dict[str, Any]
    # Name of the agent that receives control when the handoff fires.
    agent_name: str
    # Mirrors Handoff.strict_json_schema; True enforces strict schema validation.
    strict_json_schema: bool = True
59+
60+
61+
@dataclass
class FunctionToolInput:
    """Data conversion friendly representation of a FunctionTool. Contains only the fields which are needed by the model
    execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
    """

    # Tool name exposed to the model.
    name: str
    # Description presented to the model alongside the tool.
    description: str
    # JSON schema for the tool's parameters.
    params_json_schema: dict[str, Any]
    # Mirrors FunctionTool.strict_json_schema; True enforces strict schema validation.
    strict_json_schema: bool = True
71+
72+
73+
@dataclass
class HostedMCPToolInput:
    """Data conversion friendly representation of a HostedMCPTool. Contains only the fields which are needed by the model
    execution to determine what tool to call, not the actual tool invocation, which remains in the workflow context.
    """

    # MCP server configuration passed straight through to HostedMCPTool.
    tool_config: Mcp
80+
81+
82+
# Union of every tool representation that can be serialized into an
# ActivityModelInput. Hosted tool types from the agents SDK are plain data and
# pass through as-is; FunctionToolInput/HostedMCPToolInput are the
# conversion-friendly stand-ins defined above.
ToolInput = Union[
    FunctionToolInput,
    FileSearchTool,
    WebSearchTool,
    ImageGenerationTool,
    CodeInterpreterTool,
    HostedMCPToolInput,
]
90+
91+
92+
@dataclass
class AgentOutputSchemaInput(AgentOutputSchemaBase):
    """Data conversion friendly representation of AgentOutputSchema.

    Carries just enough schema metadata for the model call; actual JSON
    validation stays in the workflow side and is deliberately unimplemented
    here.
    """

    output_type_name: Optional[str]
    is_wrapped: bool
    output_schema: Optional[dict[str, Any]]
    strict_json_schema: bool

    def is_plain_text(self) -> bool:
        """Whether the output type is plain text (versus a JSON object)."""
        return self.output_type_name in (None, "str")

    def is_strict_json_schema(self) -> bool:
        """Whether the JSON schema is in strict mode."""
        return self.strict_json_schema

    def json_schema(self) -> dict[str, Any]:
        """The JSON schema of the output type.

        Raises:
            UserError: If the output is plain text or no schema was provided.
        """
        if self.is_plain_text():
            raise UserError("Output type is plain text, so no JSON schema is available")
        schema = self.output_schema
        if schema is None:
            raise UserError("Output schema is not defined")
        return schema

    def validate_json(self, json_str: str) -> Any:
        """Validate the JSON string against the schema.

        Not supported in this serializable representation.
        """
        raise NotImplementedError()

    def name(self) -> str:
        """Get the name of the output type.

        Raises:
            ValueError: If no output type name was recorded.
        """
        type_name = self.output_type_name
        if type_name is None:
            raise ValueError("output_type_name is None")
        return type_name
126+
127+
128+
class ModelTracingInput(enum.IntEnum):
    """Conversion friendly representation of ModelTracing.

    Needed as ModelTracing is enum.Enum instead of IntEnum
    """

    # Values match ModelTracing so ModelTracing(input.tracing) round-trips.
    DISABLED = 0
    ENABLED = 1
    ENABLED_WITHOUT_DATA = 2
137+
138+
139+
@dataclass
class ActivityModelInput:
    """Input for the invoke_model_activity activity.

    All fields are serialization-friendly; ``tools``/``handoffs`` accept None
    (e.g. from a payload that omits them) and are coerced to empty lists in
    ``__post_init__``.
    """

    # Model input: a plain prompt string or a list of response input items.
    input: Union[str, list[TResponseInputItem]]
    model_settings: ModelSettings
    tracing: ModelTracingInput
    model_name: Optional[str] = None
    system_instructions: Optional[str] = None
    # Optional[...] (not bare list) because None is a legal default; it is
    # normalized to [] below.
    tools: Optional[list[ToolInput]] = None
    output_schema: Optional[AgentOutputSchemaInput] = None
    handoffs: Optional[list[HandoffInput]] = None
    previous_response_id: Optional[str] = None
    prompt: Optional[Any] = None

    def __post_init__(self):
        """Initialize default values for list fields."""
        if self.tools is None:
            self.tools = []
        if self.handoffs is None:
            self.handoffs = []

    def to_json(self) -> str:
        """Convert the ActivityModelInput to a JSON string."""
        return to_json(self).decode('utf-8')
164+
165+
166+
class ModelActivity:
    """Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled.
    Disabling retries in your model of choice is recommended to allow activity retries to define the retry model.
    """

    def __init__(self, model_provider: Optional[ModelProvider] = None):
        """Initialize the activity with a model provider."""
        # max_retries=0: the activity's own retry policy is the single source
        # of truth; the OpenAI client must not retry on its own.
        self._model_provider = model_provider or OpenAIProvider(
            openai_client=AsyncOpenAI(max_retries=0)
        )

    async def invoke_model_activity(self, input: ActivityModelInput) -> ModelResponse:
        """Activity that invokes a model with the given input.

        Rebuilds real SDK Tool/Handoff objects from their serializable
        stand-ins, calls the model once, and translates APIStatusError into
        ApplicationError carrying retry hints for the durable runtime.
        """
        model = self._model_provider.get_model(input.model_name)

        # Tool/handoff invocation happens back in the workflow, not in this
        # activity; these no-op callbacks only satisfy the constructors.
        async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
            return ""

        async def empty_on_invoke_handoff(
            ctx: RunContextWrapper[Any], input: str
        ) -> Any:
            return None

        # workaround for https://github.com/pydantic/pydantic/issues/9541
        # ValidatorIterator returned
        input_json = to_json(input.input)
        input_input = json.loads(input_json)

        def make_tool(tool: ToolInput) -> Tool:
            # Convert a serializable ToolInput back into an SDK Tool.
            if isinstance(
                tool,
                (
                    FileSearchTool,
                    WebSearchTool,
                    ImageGenerationTool,
                    CodeInterpreterTool,
                ),
            ):
                # Hosted tool types are already plain data: pass through.
                return tool
            elif isinstance(tool, HostedMCPToolInput):
                return HostedMCPTool(
                    tool_config=tool.tool_config,
                )
            elif isinstance(tool, FunctionToolInput):
                return FunctionTool(
                    name=tool.name,
                    description=tool.description,
                    params_json_schema=tool.params_json_schema,
                    on_invoke_tool=empty_on_invoke_tool,
                    strict_json_schema=tool.strict_json_schema,
                )
            else:
                # NOTE(review): an unknown type may not have a .name attribute,
                # which would raise AttributeError here instead of UserError —
                # consider type(tool).__name__.
                raise UserError(f"Unknown tool type: {tool.name}")

        tools = [make_tool(x) for x in input.tools]
        handoffs: list[Handoff[Any, Any]] = [
            Handoff(
                tool_name=x.tool_name,
                tool_description=x.tool_description,
                input_json_schema=x.input_json_schema,
                agent_name=x.agent_name,
                strict_json_schema=x.strict_json_schema,
                on_invoke_handoff=empty_on_invoke_handoff,
            )
            for x in input.handoffs
        ]

        try:
            return await model.get_response(
                system_instructions=input.system_instructions,
                input=input_input,
                model_settings=input.model_settings,
                tools=tools,
                output_schema=input.output_schema,
                handoffs=handoffs,
                # ModelTracingInput values match ModelTracing, so this
                # round-trips the enum.
                tracing=ModelTracing(input.tracing),
                previous_response_id=input.previous_response_id,
                prompt=input.prompt,
            )
        except APIStatusError as e:
            # Listen to server hints
            retry_after = None
            # Millisecond-precision header takes precedence over retry-after.
            retry_after_ms_header = e.response.headers.get("retry-after-ms")
            if retry_after_ms_header is not None:
                retry_after = timedelta(milliseconds=float(retry_after_ms_header))

            if retry_after is None:
                retry_after_header = e.response.headers.get("retry-after")
                if retry_after_header is not None:
                    # assumes a numeric seconds value, not an HTTP-date form
                    # of Retry-After — TODO confirm
                    retry_after = timedelta(seconds=float(retry_after_header))

            # An explicit server directive overrides the status-code
            # heuristics below.
            should_retry_header = e.response.headers.get("x-should-retry")
            if should_retry_header == "true":
                # Re-raise the original error so the activity retry policy
                # handles it as retryable.
                raise e
            if should_retry_header == "false":
                raise ApplicationError(
                    "Non retryable OpenAI error",
                    non_retryable=True,
                    next_retry_delay=retry_after,
                ) from e

            # Specifically retryable status codes
            if e.response.status_code in [408, 409, 429, 500]:
                raise ApplicationError(
                    "Retryable OpenAI status code",
                    non_retryable=False,
                    next_retry_delay=retry_after,
                ) from e

            # Any other status (e.g. remaining 4xx) is treated as permanent.
            raise ApplicationError(
                "Non retryable OpenAI status code",
                non_retryable=True,
                next_retry_delay=retry_after,
            ) from e

0 commit comments

Comments
 (0)