Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 11 additions & 51 deletions azure/durable_functions/openai_agents/model_invocation_activity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import enum
import json
import logging
from datetime import timedelta
from typing import Any, AsyncIterator, Optional, Union, cast

from azure.durable_functions.models.RetryOptions import RetryOptions
Expand All @@ -28,9 +27,6 @@
WebSearchTool,
)
from agents.items import TResponseStreamEvent
from openai import (
APIStatusError,
)
from openai.types.responses.tool_param import Mcp
from openai.types.responses.response_prompt_param import ResponsePromptParam

Expand Down Expand Up @@ -243,53 +239,17 @@ def make_tool(tool: ToolInput) -> Tool:
for x in input.handoffs
]

try:
return await model.get_response(
system_instructions=input.system_instructions,
input=input_input,
model_settings=input.model_settings,
tools=tools,
output_schema=input.output_schema,
handoffs=handoffs,
tracing=ModelTracing(input.tracing),
previous_response_id=input.previous_response_id,
prompt=input.prompt,
)
except APIStatusError as e:
# Listen to server hints
retry_after = None
retry_after_ms_header = e.response.headers.get("retry-after-ms")
if retry_after_ms_header is not None:
retry_after = timedelta(milliseconds=float(retry_after_ms_header))

if retry_after is None:
retry_after_header = e.response.headers.get("retry-after")
if retry_after_header is not None:
retry_after = timedelta(seconds=float(retry_after_header))

should_retry_header = e.response.headers.get("x-should-retry")
if should_retry_header == "true":
raise e
if should_retry_header == "false":
raise ApplicationError(
"Non retryable OpenAI error",
non_retryable=True,
next_retry_delay=retry_after,
) from e

# Specifically retryable status codes
if e.response.status_code in [408, 409, 429, 500]:
raise ApplicationError(
"Retryable OpenAI status code",
non_retryable=False,
next_retry_delay=retry_after,
) from e

raise ApplicationError(
"Non retryable OpenAI status code",
non_retryable=True,
next_retry_delay=retry_after,
) from e
return await model.get_response(
system_instructions=input.system_instructions,
input=input_input,
model_settings=input.model_settings,
tools=tools,
output_schema=input.output_schema,
handoffs=handoffs,
tracing=ModelTracing(input.tracing),
previous_response_id=input.previous_response_id,
prompt=input.prompt,
)


class DurableActivityModel(Model):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .usage_telemetry import UsageTelemetry


async def durable_openai_agent_activity(input: str, model_provider: ModelProvider):
async def durable_openai_agent_activity(input: str, model_provider: ModelProvider) -> str:
"""Activity logic that handles OpenAI model invocations."""
activity_input = ActivityModelInput.from_json(input)

Expand Down