Skip to content

Commit 49e37a5

Browse files
committed
Fix activity input serialization
1 parent 3877efc commit 49e37a5

File tree

3 files changed

+87
-5
lines changed

3 files changed

+87
-5
lines changed

samples-v2/openai_agents/_invoke_model_activity.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,49 @@ def to_json(self) -> str:
162162
"""Convert the ActivityModelInput to a JSON string."""
163163
return to_json(self).decode('utf-8')
164164

165+
@classmethod
166+
def from_json(cls, json_str: str) -> 'ActivityModelInput':
167+
"""Create an ActivityModelInput instance from a JSON string."""
168+
import json
169+
data = json.loads(json_str)
170+
171+
# Convert complex types back from dictionaries
172+
if 'model_settings' in data and isinstance(data['model_settings'], dict):
173+
data['model_settings'] = ModelSettings(**data['model_settings'])
174+
175+
if 'tracing' in data and isinstance(data['tracing'], int):
176+
data['tracing'] = ModelTracingInput(data['tracing'])
177+
178+
# Convert tool inputs back to proper types
179+
if 'tools' in data and data['tools']:
180+
converted_tools = []
181+
for tool_data in data['tools']:
182+
if isinstance(tool_data, dict):
183+
# Check the tool type and convert accordingly
184+
if 'name' in tool_data and 'description' in tool_data and 'params_json_schema' in tool_data:
185+
# FunctionToolInput
186+
converted_tools.append(FunctionToolInput(**tool_data))
187+
elif 'tool_config' in tool_data:
188+
# HostedMCPToolInput
189+
converted_tools.append(HostedMCPToolInput(**tool_data))
190+
else:
191+
# For other tool types like FileSearchTool, WebSearchTool, etc.
192+
# These might be already properly serialized/deserialized by pydantic_core
193+
converted_tools.append(tool_data)
194+
else:
195+
converted_tools.append(tool_data)
196+
data['tools'] = converted_tools
197+
198+
# Convert handoffs back to proper types
199+
if 'handoffs' in data and data['handoffs']:
200+
data['handoffs'] = [HandoffInput(**handoff) for handoff in data['handoffs']]
201+
202+
# Convert output_schema back to proper type
203+
if 'output_schema' in data and data['output_schema'] and isinstance(data['output_schema'], dict):
204+
data['output_schema'] = AgentOutputSchemaInput(**data['output_schema'])
205+
206+
return cls(**data)
207+
165208

166209
class ModelActivity:
167210
"""Class wrapper for model invocation activities to allow model customization. By default, we use an OpenAIProvider with retries disabled.

samples-v2/openai_agents/durable_model_stub.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,17 @@
11
from __future__ import annotations
2-
from _invoke_model_activity import ActivityModelInput, ModelTracingInput
2+
from _invoke_model_activity import (
3+
ActivityModelInput,
4+
ModelTracingInput,
5+
ToolInput,
6+
HostedMCPToolInput,
7+
FunctionToolInput,
8+
HandoffInput,
9+
AgentOutputSchemaInput
10+
)
11+
from orchestrator_exceptions import OrchestratorYielded
312
import azure.durable_functions as df
413

14+
import json
515
import logging
616
from typing import Optional
717

@@ -119,9 +129,12 @@ def make_tool_info(tool: Tool) -> ToolInput:
119129
prompt=prompt,
120130
)
121131

122-
activity_output =self.durable_orchestration_context.call_activity(
132+
# Serialize activity_input to JSON
133+
activity_input_json = activity_input.to_json()
134+
135+
activity_output = self.durable_orchestration_context.call_activity(
123136
"invoke_model_activity",
124-
activity_input
137+
activity_input_json
125138
)
126139

127140
raise OrchestratorYielded(activity_output)

samples-v2/openai_agents/function_app.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,27 @@
44

55
from agents.run import set_default_agent_runner
66
from durable_openai_runner import DurableOpenAIRunner
7+
from orchestrator_exceptions import OrchestratorYielded
8+
9+
from agents import (
10+
AgentOutputSchemaBase,
11+
CodeInterpreterTool,
12+
FileSearchTool,
13+
FunctionTool,
14+
Handoff,
15+
HostedMCPTool,
16+
ImageGenerationTool,
17+
ModelProvider,
18+
ModelResponse,
19+
ModelSettings,
20+
ModelTracing,
21+
OpenAIProvider,
22+
RunContextWrapper,
23+
Tool,
24+
TResponseInputItem,
25+
UserError,
26+
WebSearchTool,
27+
)
728

829
app = func.FunctionApp(http_auth_level=func.AuthLevel.FUNCTION)
930

@@ -28,8 +49,13 @@ def basic_hello_world_orchestrator(context):
2849
yield e.activity_output
2950

3051
@app.activity_trigger(input_name="input")
31-
def invoke_model_activity(input: Any):
52+
async def invoke_model_activity(input: str):
3253
# Instantiate ModelActivity
3354
from _invoke_model_activity import ModelActivity, ActivityModelInput
55+
56+
# Deserialize input string into ActivityModelInput object
57+
activity_input = ActivityModelInput.from_json(input)
58+
3459
model_activity = ModelActivity()
35-
return model_activity.invoke_model_activity(input)
60+
return await model_activity.invoke_model_activity(activity_input)
61+

0 commit comments

Comments
 (0)