"""A durable activity that invokes an LLM.

Implements mapping of OpenAI data structures to Pydantic-friendly types.
"""

import enum
import json
from dataclasses import dataclass
from datetime import timedelta
from typing import Any, Optional, Union

from agents import (
    AgentOutputSchemaBase,
    CodeInterpreterTool,
    FileSearchTool,
    FunctionTool,
    Handoff,
    HostedMCPTool,
    ImageGenerationTool,
    ModelProvider,
    ModelResponse,
    ModelSettings,
    ModelTracing,
    OpenAIProvider,
    RunContextWrapper,
    Tool,
    TResponseInputItem,
    UserError,
    WebSearchTool,
)
from openai import (
    APIStatusError,
    AsyncOpenAI,
)
from openai.types.responses.tool_param import Mcp
from pydantic_core import to_json
try:
    from azure.durable_functions import ApplicationError
except ImportError:
    # Fallback if ApplicationError is not available in the installed package.
    class ApplicationError(Exception):
        def __init__(
            self,
            message: str,
            non_retryable: bool = False,
            next_retry_delay: Optional[timedelta] = None,
        ):
            super().__init__(message)
            self.non_retryable = non_retryable
            self.next_retry_delay = next_retry_delay


@dataclass
class HandoffInput:
    """Data-conversion-friendly representation of a Handoff.

    Contains only the fields the model execution needs to decide what to hand
    off to; the actual handoff invocation remains in the workflow context.
    """

    tool_name: str
    tool_description: str
    input_json_schema: dict[str, Any]
    agent_name: str
    strict_json_schema: bool = True


@dataclass
class FunctionToolInput:
    """Data-conversion-friendly representation of a FunctionTool.

    Contains only the fields the model execution needs to decide which tool to
    call; the actual tool invocation remains in the workflow context.
    """

    name: str
    description: str
    params_json_schema: dict[str, Any]
    strict_json_schema: bool = True


@dataclass
class HostedMCPToolInput:
    """Data-conversion-friendly representation of a HostedMCPTool.

    Contains only the fields the model execution needs to decide which tool to
    call; the actual tool invocation remains in the workflow context.
    """

    tool_config: Mcp


ToolInput = Union[
    FunctionToolInput,
    FileSearchTool,
    WebSearchTool,
    ImageGenerationTool,
    CodeInterpreterTool,
    HostedMCPToolInput,
]


@dataclass
class AgentOutputSchemaInput(AgentOutputSchemaBase):
    """Data-conversion-friendly representation of AgentOutputSchema."""

    output_type_name: Optional[str]
    is_wrapped: bool
    output_schema: Optional[dict[str, Any]]
    strict_json_schema: bool

    def is_plain_text(self) -> bool:
        """Whether the output type is plain text (versus a JSON object)."""
        return self.output_type_name is None or self.output_type_name == "str"

    def is_strict_json_schema(self) -> bool:
        """Whether the JSON schema is in strict mode."""
        return self.strict_json_schema

    def json_schema(self) -> dict[str, Any]:
        """The JSON schema of the output type."""
        if self.is_plain_text():
            raise UserError("Output type is plain text, so no JSON schema is available")
        if self.output_schema is None:
            raise UserError("Output schema is not defined")
        return self.output_schema

    def validate_json(self, json_str: str) -> Any:
        """Validate the JSON string against the schema."""
        raise NotImplementedError()

    def name(self) -> str:
        """Get the name of the output type."""
        if self.output_type_name is None:
            raise ValueError("output_type_name is None")
        return self.output_type_name


class ModelTracingInput(enum.IntEnum):
    """Conversion-friendly representation of ModelTracing.

    Needed because ModelTracing subclasses enum.Enum rather than enum.IntEnum.
    """

    DISABLED = 0
    ENABLED = 1
    ENABLED_WITHOUT_DATA = 2


@dataclass
class ActivityModelInput:
    """Input for the invoke_model_activity activity."""

    input: Union[str, list[TResponseInputItem]]
    model_settings: ModelSettings
    tracing: ModelTracingInput
    model_name: Optional[str] = None
    system_instructions: Optional[str] = None
    tools: Optional[list[ToolInput]] = None
    output_schema: Optional[AgentOutputSchemaInput] = None
    handoffs: Optional[list[HandoffInput]] = None
    previous_response_id: Optional[str] = None
    prompt: Optional[Any] = None

    def __post_init__(self):
        """Initialize default values for the list fields."""
        if self.tools is None:
            self.tools = []
        if self.handoffs is None:
            self.handoffs = []

    def to_json(self) -> str:
        """Convert the ActivityModelInput to a JSON string."""
        return to_json(self).decode("utf-8")


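# A minimal usage sketch, shown as a comment (the caller, model name, and
# default ModelSettings below are illustrative assumptions, not part of this
# module):
#
#     activity = ModelActivity()
#     response = await activity.invoke_model_activity(
#         ActivityModelInput(
#             input="Say hello",
#             model_settings=ModelSettings(),
#             tracing=ModelTracingInput.DISABLED,
#             model_name="gpt-4o",
#         )
#     )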
class ModelActivity:
    """Class wrapper for model invocation activities to allow model customization.

    By default an OpenAIProvider with retries disabled is used; disabling
    retries in your model of choice is recommended so that activity retries
    define the retry policy.
    """

    def __init__(self, model_provider: Optional[ModelProvider] = None):
        """Initialize the activity with a model provider."""
        self._model_provider = model_provider or OpenAIProvider(
            openai_client=AsyncOpenAI(max_retries=0)
        )

    async def invoke_model_activity(self, input: ActivityModelInput) -> ModelResponse:
        """Activity that invokes a model with the given input."""
        model = self._model_provider.get_model(input.model_name)

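        # The real tool and handoff implementations stay in the workflow
        # context; the activity only needs their schemas for the model call,
        # so no-op callbacks are sufficient here.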
        async def empty_on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
            return ""

        async def empty_on_invoke_handoff(
            ctx: RunContextWrapper[Any], input: str
        ) -> Any:
            return None

        # Workaround for https://github.com/pydantic/pydantic/issues/9541
        # (a ValidatorIterator may be returned): round-trip through JSON to
        # materialize the input items.
        input_json = to_json(input.input)
        input_input = json.loads(input_json)

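        # Rebuild concrete Tool objects from their serializable counterparts.
        # Hosted tools pass through unchanged; FunctionTool gets a no-op
        # on_invoke_tool because the real invocation happens in the workflow.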
        def make_tool(tool: ToolInput) -> Tool:
            if isinstance(
                tool,
                (
                    FileSearchTool,
                    WebSearchTool,
                    ImageGenerationTool,
                    CodeInterpreterTool,
                ),
            ):
                return tool
            elif isinstance(tool, HostedMCPToolInput):
                return HostedMCPTool(
                    tool_config=tool.tool_config,
                )
            elif isinstance(tool, FunctionToolInput):
                return FunctionTool(
                    name=tool.name,
                    description=tool.description,
                    params_json_schema=tool.params_json_schema,
                    on_invoke_tool=empty_on_invoke_tool,
                    strict_json_schema=tool.strict_json_schema,
                )
            else:
                raise UserError(f"Unknown tool type: {type(tool).__name__}")

        tools = [make_tool(x) for x in input.tools]
        handoffs: list[Handoff[Any, Any]] = [
            Handoff(
                tool_name=x.tool_name,
                tool_description=x.tool_description,
                input_json_schema=x.input_json_schema,
                agent_name=x.agent_name,
                strict_json_schema=x.strict_json_schema,
                on_invoke_handoff=empty_on_invoke_handoff,
            )
            for x in input.handoffs
        ]

        try:
            return await model.get_response(
                system_instructions=input.system_instructions,
                input=input_input,
                model_settings=input.model_settings,
                tools=tools,
                output_schema=input.output_schema,
                handoffs=handoffs,
                tracing=ModelTracing(input.tracing),
                previous_response_id=input.previous_response_id,
                prompt=input.prompt,
            )
        except APIStatusError as e:
            # Listen to the server's retry hints.
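            # Prefer the millisecond-precision "retry-after-ms" header and
            # fall back to "retry-after" (seconds) when it is absent.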
            retry_after = None
            retry_after_ms_header = e.response.headers.get("retry-after-ms")
            if retry_after_ms_header is not None:
                retry_after = timedelta(milliseconds=float(retry_after_ms_header))

            if retry_after is None:
                retry_after_header = e.response.headers.get("retry-after")
                if retry_after_header is not None:
                    retry_after = timedelta(seconds=float(retry_after_header))

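            # An explicit x-should-retry header from the server overrides the
            # status-code-based classification below.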
            should_retry_header = e.response.headers.get("x-should-retry")
            if should_retry_header == "true":
                raise e
            if should_retry_header == "false":
                raise ApplicationError(
                    "Non retryable OpenAI error",
                    non_retryable=True,
                    next_retry_delay=retry_after,
                ) from e

            # Specifically retryable status codes
            if e.response.status_code in [408, 409, 429, 500]:
                raise ApplicationError(
                    "Retryable OpenAI status code",
                    non_retryable=False,
                    next_retry_delay=retry_after,
                ) from e

            raise ApplicationError(
                "Non retryable OpenAI status code",
                non_retryable=True,
                next_retry_delay=retry_after,
            ) from e