1
+ from __future__ import annotations
1
2
import enum
2
3
import json
4
+ import logging
3
5
from dataclasses import dataclass
4
6
from datetime import timedelta
5
- from typing import Any , Optional , Union
7
+ from typing import Any , AsyncIterator , Optional , Union , cast
6
8
9
+ import azure .functions as func
7
10
from pydantic import BaseModel , Field
8
11
from agents import (
12
+ AgentOutputSchema ,
9
13
AgentOutputSchemaBase ,
10
14
CodeInterpreterTool ,
11
15
FileSearchTool ,
12
16
FunctionTool ,
13
17
Handoff ,
14
18
HostedMCPTool ,
15
19
ImageGenerationTool ,
20
+ Model ,
16
21
ModelProvider ,
17
22
ModelResponse ,
18
23
ModelSettings ,
24
29
UserError ,
25
30
WebSearchTool ,
26
31
)
32
+ from agents .items import TResponseStreamEvent
27
33
from openai import (
28
34
APIStatusError ,
29
35
AsyncOpenAI ,
30
36
)
31
37
from openai .types .responses .tool_param import Mcp
38
+ from openai .types .responses .response_prompt_param import ResponsePromptParam
39
+
40
+ from .context import DurableAIAgentContext
41
+
32
42
try :
33
43
from azure .durable_functions import ApplicationError
34
44
except ImportError :
@@ -39,6 +49,8 @@ def __init__(self, message: str, non_retryable: bool = False, next_retry_delay =
39
49
self .non_retryable = non_retryable
40
50
self .next_retry_delay = next_retry_delay
41
51
52
+ logger = logging .getLogger (__name__ )
53
+
42
54
43
55
class HandoffInput (BaseModel ):
44
56
"""Data conversion friendly representation of a Handoff. Contains only the fields which are needed by the model
@@ -259,3 +271,132 @@ def make_tool(tool: ToolInput) -> Tool:
259
271
non_retryable = True ,
260
272
next_retry_delay = retry_after ,
261
273
) from e
274
+
275
+
276
+ class _DurableModelStub (Model ):
277
+ def __init__ (
278
+ self ,
279
+ model_name : Optional [str ],
280
+ context : DurableAIAgentContext ,
281
+ ) -> None :
282
+ self .model_name = model_name
283
+ self .context = context
284
+
285
+ async def get_response (
286
+ self ,
287
+ system_instructions : Optional [str ],
288
+ input : Union [str , list [TResponseInputItem ]],
289
+ model_settings : ModelSettings ,
290
+ tools : list [Tool ],
291
+ output_schema : Optional [AgentOutputSchemaBase ],
292
+ handoffs : list [Handoff ],
293
+ tracing : ModelTracing ,
294
+ * ,
295
+ previous_response_id : Optional [str ],
296
+ prompt : Optional [ResponsePromptParam ],
297
+ ) -> ModelResponse :
298
+ def make_tool_info (tool : Tool ) -> ToolInput :
299
+ if isinstance (
300
+ tool ,
301
+ (
302
+ FileSearchTool ,
303
+ WebSearchTool ,
304
+ ImageGenerationTool ,
305
+ CodeInterpreterTool ,
306
+ ),
307
+ ):
308
+ return tool
309
+ elif isinstance (tool , HostedMCPTool ):
310
+ return HostedMCPToolInput (tool_config = tool .tool_config )
311
+ elif isinstance (tool , FunctionTool ):
312
+ return FunctionToolInput (
313
+ name = tool .name ,
314
+ description = tool .description ,
315
+ params_json_schema = tool .params_json_schema ,
316
+ strict_json_schema = tool .strict_json_schema ,
317
+ )
318
+ else :
319
+ raise ValueError (f"Unsupported tool type: { tool .name } " )
320
+
321
+ tool_infos = [make_tool_info (x ) for x in tools ]
322
+ handoff_infos = [
323
+ HandoffInput (
324
+ tool_name = x .tool_name ,
325
+ tool_description = x .tool_description ,
326
+ input_json_schema = x .input_json_schema ,
327
+ agent_name = x .agent_name ,
328
+ strict_json_schema = x .strict_json_schema ,
329
+ )
330
+ for x in handoffs
331
+ ]
332
+ if output_schema is not None and not isinstance (
333
+ output_schema , AgentOutputSchema
334
+ ):
335
+ raise TypeError (
336
+ f"Only AgentOutputSchema is supported by Durable Model, got { type (output_schema ).__name__ } "
337
+ )
338
+ agent_output_schema = output_schema
339
+ output_schema_input = (
340
+ None
341
+ if agent_output_schema is None
342
+ else AgentOutputSchemaInput (
343
+ output_type_name = agent_output_schema .name (),
344
+ is_wrapped = agent_output_schema ._is_wrapped ,
345
+ output_schema = agent_output_schema .json_schema ()
346
+ if not agent_output_schema .is_plain_text ()
347
+ else None ,
348
+ strict_json_schema = agent_output_schema .is_strict_json_schema (),
349
+ )
350
+ )
351
+
352
+ activity_input = ActivityModelInput (
353
+ model_name = self .model_name ,
354
+ system_instructions = system_instructions ,
355
+ input = cast (Union [str , list [TResponseInputItem ]], input ),
356
+ model_settings = model_settings ,
357
+ tools = tool_infos ,
358
+ output_schema = output_schema_input ,
359
+ handoffs = handoff_infos ,
360
+ tracing = ModelTracingInput .DISABLED , # ModelTracingInput(tracing.value),
361
+ previous_response_id = previous_response_id ,
362
+ prompt = prompt ,
363
+ )
364
+
365
+ activity_input_json = activity_input .to_json ()
366
+
367
+ response = self .context ._get_activity_call_result ("invoke_model_activity" , activity_input_json )
368
+ json_response = json .loads (response )
369
+ model_response = ModelResponse (** json_response )
370
+ return model_response
371
+
372
+ def stream_response (
373
+ self ,
374
+ system_instructions : Optional [str ],
375
+ input : Union [str , list [TResponseInputItem ]],
376
+ model_settings : ModelSettings ,
377
+ tools : list [Tool ],
378
+ output_schema : Optional [AgentOutputSchemaBase ],
379
+ handoffs : list [Handoff ],
380
+ tracing : ModelTracing ,
381
+ * ,
382
+ previous_response_id : Optional [str ],
383
+ prompt : ResponsePromptParam | None ,
384
+ ) -> AsyncIterator [TResponseStreamEvent ]:
385
+ raise NotImplementedError ("Durable model doesn't support streams yet" )
386
+
387
+
388
+ def create_invoke_model_activity (app : func .FunctionApp ):
389
+ """Create and register the invoke_model_activity function with the provided FunctionApp."""
390
+
391
+ @app .activity_trigger (input_name = "input" )
392
+ async def invoke_model_activity (input : str ):
393
+ """Activity that handles OpenAI model invocations."""
394
+ activity_input = ActivityModelInput .from_json (input )
395
+
396
+ model_invoker = ModelInvoker ()
397
+ result = await model_invoker .invoke_model_activity (activity_input )
398
+
399
+ json_obj = ModelResponse .__pydantic_serializer__ .to_json (result )
400
+ return json_obj .decode ()
401
+
402
+ return invoke_model_activity
0 commit comments