from __future__ import annotations
import enum
import json
import logging
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, AsyncIterator, Optional, Union, cast

import azure.functions as func
from pydantic import BaseModel, Field
from agents import (
    AgentOutputSchema,
    AgentOutputSchemaBase,
    CodeInterpreterTool,
    FileSearchTool,
    FunctionTool,
    Handoff,
    HostedMCPTool,
    ImageGenerationTool,
    Model,
    ModelProvider,
    ModelResponse,
    ModelSettings,
    # ... (a few more imports are elided in this excerpt)
    UserError,
    WebSearchTool,
)
from agents.items import TResponseStreamEvent
from openai import (
    APIStatusError,
    AsyncOpenAI,
)
from openai.types.responses.tool_param import Mcp
from openai.types.responses.response_prompt_param import ResponsePromptParam

from .context import DurableAIAgentContext

try:
    from azure.durable_functions import ApplicationError
except ImportError:
@@ -39,6 +49,8 @@ def __init__(self, message: str, non_retryable: bool = False, next_retry_delay =
            self.non_retryable = non_retryable
            self.next_retry_delay = next_retry_delay

logger = logging.getLogger(__name__)


class HandoffInput(BaseModel):
    """Data conversion friendly representation of a Handoff. Contains only the fields which are needed by the model
@@ -259,3 +271,132 @@ def make_tool(tool: ToolInput) -> Tool:
            non_retryable=True,
            next_retry_delay=retry_after,
        ) from e


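# A Model implementation that never calls OpenAI directly: it converts each
# request into a serializable ActivityModelInput and dispatches it to the
# "invoke_model_activity" activity via the durable context, presumably so the
# non-deterministic model call runs in an activity rather than in the
# orchestrator's replayed code.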
class _DurableModelStub(Model):
    def __init__(
        self,
        model_name: Optional[str],
        context: DurableAIAgentContext,
    ) -> None:
        self.model_name = model_name
        self.context = context

    async def get_response(
        self,
        system_instructions: Optional[str],
        input: Union[str, list[TResponseInputItem]],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: Optional[AgentOutputSchemaBase],
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: Optional[str],
        prompt: Optional[ResponsePromptParam],
    ) -> ModelResponse:
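        # Hosted tools (file search, web search, image generation, code
        # interpreter) are already plain configuration objects and pass through
        # unchanged; HostedMCPTool and FunctionTool are reduced to the fields
        # the activity needs. Anything else cannot be serialized and is rejected.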
        def make_tool_info(tool: Tool) -> ToolInput:
            if isinstance(
                tool,
                (
                    FileSearchTool,
                    WebSearchTool,
                    ImageGenerationTool,
                    CodeInterpreterTool,
                ),
            ):
                return tool
            elif isinstance(tool, HostedMCPTool):
                return HostedMCPToolInput(tool_config=tool.tool_config)
            elif isinstance(tool, FunctionTool):
                return FunctionToolInput(
                    name=tool.name,
                    description=tool.description,
                    params_json_schema=tool.params_json_schema,
                    strict_json_schema=tool.strict_json_schema,
                )
            else:
                raise ValueError(f"Unsupported tool type: {tool.name}")

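        # Handoffs are flattened to plain data the same way. Only
        # AgentOutputSchema (not arbitrary AgentOutputSchemaBase subclasses)
        # can be converted, hence the explicit isinstance check below.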
        tool_infos = [make_tool_info(x) for x in tools]
        handoff_infos = [
            HandoffInput(
                tool_name=x.tool_name,
                tool_description=x.tool_description,
                input_json_schema=x.input_json_schema,
                agent_name=x.agent_name,
                strict_json_schema=x.strict_json_schema,
            )
            for x in handoffs
        ]
        if output_schema is not None and not isinstance(
            output_schema, AgentOutputSchema
        ):
            raise TypeError(
                f"Only AgentOutputSchema is supported by Durable Model, got {type(output_schema).__name__}"
            )
        agent_output_schema = output_schema
        output_schema_input = (
            None
            if agent_output_schema is None
            else AgentOutputSchemaInput(
                output_type_name=agent_output_schema.name(),
                is_wrapped=agent_output_schema._is_wrapped,
                output_schema=agent_output_schema.json_schema()
                if not agent_output_schema.is_plain_text()
                else None,
                strict_json_schema=agent_output_schema.is_strict_json_schema(),
            )
        )

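        # Bundle everything into the serializable activity payload. Tracing is
        # currently hard-coded to DISABLED; the commented-out
        # ModelTracingInput(tracing.value) suggests pass-through tracing is not
        # wired up yet.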
        activity_input = ActivityModelInput(
            model_name=self.model_name,
            system_instructions=system_instructions,
            input=cast(Union[str, list[TResponseInputItem]], input),
            model_settings=model_settings,
            tools=tool_infos,
            output_schema=output_schema_input,
            handoffs=handoff_infos,
            tracing=ModelTracingInput.DISABLED,  # ModelTracingInput(tracing.value),
            previous_response_id=previous_response_id,
            prompt=prompt,
        )

        activity_input_json = activity_input.to_json()

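        # _get_activity_call_result presumably returns the recorded JSON result
        # of the "invoke_model_activity" call from the durable orchestration
        # context; it is parsed back into the agents SDK's ModelResponse.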
        response = self.context._get_activity_call_result(
            "invoke_model_activity", activity_input_json
        )
        json_response = json.loads(response)
        model_response = ModelResponse(**json_response)
        return model_response

    def stream_response(
        self,
        system_instructions: Optional[str],
        input: Union[str, list[TResponseInputItem]],
        model_settings: ModelSettings,
        tools: list[Tool],
        output_schema: Optional[AgentOutputSchemaBase],
        handoffs: list[Handoff],
        tracing: ModelTracing,
        *,
        previous_response_id: Optional[str],
        prompt: ResponsePromptParam | None,
    ) -> AsyncIterator[TResponseStreamEvent]:
        raise NotImplementedError("Durable model doesn't support streams yet")


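# Factory that registers the activity on an existing FunctionApp so the
# Functions host can discover it; the activity itself round-trips
# ActivityModelInput JSON through ModelInvoker and returns the resulting
# ModelResponse serialized with pydantic.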
def create_invoke_model_activity(app: func.FunctionApp):
    """Create and register the invoke_model_activity function with the provided FunctionApp."""

    @app.activity_trigger(input_name="input")
    async def invoke_model_activity(input: str):
        """Activity that handles OpenAI model invocations."""
        activity_input = ActivityModelInput.from_json(input)

        model_invoker = ModelInvoker()
        result = await model_invoker.invoke_model_activity(activity_input)

        json_obj = ModelResponse.__pydantic_serializer__.to_json(result)
        return json_obj.decode()

    return invoke_model_activity
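
# Minimal wiring sketch (assumed usage, not shown in this module): register the
# activity on the same FunctionApp that hosts the durable orchestration, e.g.
#
#     import azure.functions as func
#
#     app = func.FunctionApp()
#     create_invoke_model_activity(app)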