|
32 | 32 | from azure.core.tracing.decorator_async import distributed_trace_async
|
33 | 33 |
|
34 | 34 | from ... import models as _models
|
35 |
| -from ...models._enums import FilePurpose, RunStatus |
| 35 | +from ...models._enums import FilePurpose |
36 | 36 | from ._operations import FilesOperations as FilesOperationsGenerated
|
37 | 37 | from ._operations import MessagesOperations as MessagesOperationsGenerated
|
38 | 38 | from ._operations import RunsOperations as RunsOperationsGenerated
|
@@ -442,6 +442,7 @@ async def create_and_process(
|
442 | 442 | response_format: Optional["_types.AgentsResponseFormatOption"] = None,
|
443 | 443 | parallel_tool_calls: Optional[bool] = None,
|
444 | 444 | metadata: Optional[Dict[str, str]] = None,
|
| 445 | + run_handler: Optional[_models.AsyncRunHandler] = None, |
445 | 446 | polling_interval: int = 1,
|
446 | 447 | **kwargs: Any,
|
447 | 448 | ) -> _models.ThreadRun:
|
@@ -522,6 +523,9 @@ async def create_and_process(
|
522 | 523 | 64 characters in length and values may be up to 512 characters in length. Default value is
|
523 | 524 | None.
|
524 | 525 | :paramtype metadata: dict[str, str]
|
| 526 | + :keyword run_handler: Optional handler to customize run processing and tool execution. |
| 527 | + Default value is None. |
| 528 | + :paramtype run_handler: ~azure.ai.agents.models.AsyncRunHandler |
525 | 529 | :keyword polling_interval: The time in seconds to wait between polling the service for run status.
|
526 | 530 | Default value is 1.
|
527 | 531 | :paramtype polling_interval: int
|
@@ -553,51 +557,9 @@ async def create_and_process(
|
553 | 557 | )
|
554 | 558 |
|
555 | 559 | # Monitor and process the run status
|
556 |
| - current_retry = 0 |
557 |
| - while run.status in [ |
558 |
| - RunStatus.QUEUED, |
559 |
| - RunStatus.IN_PROGRESS, |
560 |
| - RunStatus.REQUIRES_ACTION, |
561 |
| - ]: |
562 |
| - await asyncio.sleep(polling_interval) |
563 |
| - run = await self.get(thread_id=thread_id, run_id=run.id) |
564 |
| - |
565 |
| - if run.status == "requires_action" and isinstance(run.required_action, _models.SubmitToolOutputsAction): |
566 |
| - tool_calls = run.required_action.submit_tool_outputs.tool_calls |
567 |
| - if not tool_calls: |
568 |
| - logger.warning("No tool calls provided - cancelling run") |
569 |
| - await self.cancel(thread_id=thread_id, run_id=run.id) |
570 |
| - break |
571 |
| - # We need tool set only if we are executing local function. In case if |
572 |
| - # the tool is azure_function we just need to wait when it will be finished. |
573 |
| - if any(tool_call.type == "function" for tool_call in tool_calls): |
574 |
| - toolset = _models.AsyncToolSet() |
575 |
| - toolset.add(self._function_tool) |
576 |
| - tool_outputs = await toolset.execute_tool_calls(tool_calls) |
577 |
| - |
578 |
| - if _has_errors_in_toolcalls_output(tool_outputs): |
579 |
| - if current_retry >= self._function_tool_max_retry: # pylint:disable=no-else-return |
580 |
| - logger.warning( |
581 |
| - "Tool outputs contain errors - reaching max retry %s", self._function_tool_max_retry |
582 |
| - ) |
583 |
| - return await self.cancel(thread_id=thread_id, run_id=run.id) |
584 |
| - else: |
585 |
| - logger.warning("Tool outputs contain errors - retrying") |
586 |
| - current_retry += 1 |
587 |
| - |
588 |
| - logger.debug("Tool outputs: %s", tool_outputs) |
589 |
| - if tool_outputs: |
590 |
| - run2 = await self.submit_tool_outputs( |
591 |
| - thread_id=thread_id, run_id=run.id, tool_outputs=tool_outputs |
592 |
| - ) |
593 |
| - logger.debug("Tool outputs submitted to run: %s", run2.id) |
594 |
| - elif isinstance(run.required_action, _models.SubmitToolApprovalAction): |
595 |
| - logger.warning("Automatic MCP tool approval is not supported.") |
596 |
| - await self.cancel(thread_id=thread_id, run_id=run.id) |
597 |
| - |
598 |
| - logger.debug("Current run ID: %s with status: %s", run.id, run.status) |
599 |
| - |
600 |
| - return run |
| 560 | + run_handler_obj = run_handler or _models.AsyncRunHandler() |
| 561 | + |
| 562 | + return await run_handler_obj._start(self, run, polling_interval) # pylint: disable=protected-access |
601 | 563 |
|
602 | 564 | @overload
|
603 | 565 | async def stream(
|
|
0 commit comments