Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions azure/durable_functions/decorators/durable_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,21 +253,21 @@ def decorator():

return wrap

def _create_invoke_model_activity(self, model_provider):
def _create_invoke_model_activity(self, model_provider, activity_name):
"""Create and register the invoke_model_activity function with the provided FunctionApp."""

@self.activity_trigger(input_name="input")
async def invoke_model_activity(input: str):
@self.activity_trigger(input_name="input", activity=activity_name)
async def run_model_activity(input: str):
from azure.durable_functions.openai_agents.orchestrator_generator\
import durable_openai_agent_activity

return await durable_openai_agent_activity(input, model_provider)

return invoke_model_activity
return run_model_activity

def _setup_durable_openai_agent(self, model_provider):
def _setup_durable_openai_agent(self, model_provider, activity_name):
if not self._is_durable_openai_agent_setup:
self._create_invoke_model_activity(model_provider)
self._create_invoke_model_activity(model_provider, activity_name)
self._is_durable_openai_agent_setup = True

def durable_openai_agent_orchestrator(
Expand All @@ -294,14 +294,16 @@ def durable_openai_agent_orchestrator(
if model_provider is not None and type(model_provider) is not ModelProvider:
raise TypeError("Provided model provider must be of type ModelProvider")

self._setup_durable_openai_agent(model_provider)
activity_name = "run_model"

self._setup_durable_openai_agent(model_provider, activity_name)

def generator_wrapper_wrapper(func):

@wraps(func)
def generator_wrapper(context):
return durable_openai_agent_orchestrator_generator(
func, context, model_retry_options
func, context, model_retry_options, activity_name
)

return generator_wrapper
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,10 +291,12 @@ def __init__(
model_name: Optional[str],
task_tracker: TaskTracker,
retry_options: Optional[RetryOptions],
activity_name: str,
) -> None:
self.model_name = model_name
self.task_tracker = task_tracker
self.retry_options = retry_options
self.activity_name = activity_name

async def get_response(
self,
Expand Down Expand Up @@ -382,13 +384,13 @@ def make_tool_info(tool: Tool) -> ToolInput:

if self.retry_options:
response = self.task_tracker.get_activity_call_result_with_retry(
"invoke_model_activity",
self.activity_name,
self.retry_options,
activity_input_json,
)
else:
response = self.task_tracker.get_activity_call_result(
"invoke_model_activity", activity_input_json
self.activity_name, activity_input_json
)

json_response = json.loads(response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,16 @@ def durable_openai_agent_orchestrator_generator(
func,
durable_orchestration_context: DurableOrchestrationContext,
model_retry_options: Optional[RetryOptions],
activity_name: str,
):
"""Adapts the synchronous OpenAI Agents function to an Durable orchestrator generator."""
ensure_event_loop()
task_tracker = TaskTracker(durable_orchestration_context)
durable_ai_agent_context = DurableAIAgentContext(
durable_orchestration_context, task_tracker, model_retry_options
)
durable_openai_runner = DurableOpenAIRunner(context=durable_ai_agent_context)
durable_openai_runner = DurableOpenAIRunner(
context=durable_ai_agent_context, activity_name=activity_name)
set_default_agent_runner(durable_openai_runner)

func_with_context = partial(func, durable_ai_agent_context)
Expand Down
4 changes: 3 additions & 1 deletion azure/durable_functions/openai_agents/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@
class DurableOpenAIRunner:
"""Runner for OpenAI agents using Durable Functions orchestration."""

def __init__(self, context: DurableAIAgentContext) -> None:
def __init__(self, context: DurableAIAgentContext, activity_name: str) -> None:
self._runner = DEFAULT_AGENT_RUNNER or AgentRunner()
self.context = context
self.activity_name = activity_name

def run_sync(
self,
Expand Down Expand Up @@ -62,6 +63,7 @@ def run_sync(
model_name=model_name,
task_tracker=self.context._task_tracker,
retry_options=self.context._model_retry_options,
activity_name=self.activity_name,
),
)

Expand Down
2 changes: 1 addition & 1 deletion tests/orchestrator/openai_agents/test_openai_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def openai_agent_return_pydantic_model_type(context):

return model

model_activity_name = "invoke_model_activity"
model_activity_name = "run_model"

def base_expected_state(output=None, replay_schema: ReplaySchema = ReplaySchema.V1) -> OrchestratorState:
return OrchestratorState(is_done=False, actions=[], output=output, replay_schema=replay_schema)
Expand Down