|
2 | 2 | # Copyright (c) Microsoft Corporation. |
3 | 3 | # Licensed under the MIT License. |
4 | 4 | # ------------------------------------ |
5 | | -from typing import Any, Iterator, List, MutableMapping, Optional, Dict |
| 5 | +from typing import Any, Callable, Iterator, List, MutableMapping, Optional, Dict |
6 | 6 |
|
7 | 7 | import json |
8 | 8 | import os |
|
24 | 24 | ToolOutput, |
25 | 25 | AgentEventHandler, |
26 | 26 | ThreadRun, |
| 27 | + RunHandler, |
27 | 28 | ) |
28 | 29 |
|
29 | 30 | from user_functions import user_functions |
@@ -88,7 +89,7 @@ def get_mock_client(self) -> AgentsClient: |
88 | 89 | client.runs.submit_tool_outputs = MagicMock() |
89 | 90 | return client |
90 | 91 |
|
91 | | - def get_toolset(self, file_id: Optional[str], function: Optional[str]) -> Optional[ToolSet]: |
| 92 | + def get_toolset(self, file_id: Optional[str], function: Optional[Callable[..., Any]]) -> Optional[ToolSet]: |
92 | 93 | """Get the tool set with given file id and function""" |
93 | 94 | if file_id is None or function is None: |
94 | 95 | return None |
@@ -507,6 +508,58 @@ def test_create_run_tools_override( |
507 | 508 | else: |
508 | 509 | self._assert_tool_call(agents_client.runs.submit_tool_outputs, "run123", toolset1) |
509 | 510 |
|
| 511 | + @patch("azure.ai.agents._client.PipelineClient") |
| 512 | + def test_create_and_process_with_manual_function_calls( |
| 513 | + self, |
| 514 | + mock_pipeline_client_gen: MagicMock, |
| 515 | + ) -> None: |
| 516 | + """Test that if user have set tool set in create create_and_process_run method, that tools are used.""" |
| 517 | + toolset1 = self.get_toolset("file_for_agent_1", function1) |
| 518 | + mock_response = MagicMock() |
| 519 | + mock_response.status_code = 200 |
| 520 | + side_effect = [self._get_agent_json("first", "123", toolset1)] |
| 521 | + side_effect.append(self._get_run("run123", toolset1)) # create_run |
| 522 | + side_effect.append(self._get_run("run123", toolset1)) # get_run |
| 523 | + side_effect.append( |
| 524 | + self._get_run("run123", toolset1, is_complete=True) |
| 525 | + ) # get_run after resubmitting with tool results |
| 526 | + |
| 527 | + class MyRunHandler(RunHandler): |
| 528 | + def submit_function_call_output( |
| 529 | + self, |
| 530 | + run: ThreadRun, |
| 531 | + tool_call: RequiredFunctionToolCall, |
| 532 | + tool_call_details: RequiredFunctionToolCallDetails, |
| 533 | + **kwargs: Any, |
| 534 | + ) -> Optional[Any]: |
| 535 | + if tool_call_details.name == function1.__name__: |
| 536 | + return function1() |
| 537 | + |
| 538 | + mock_response.json.side_effect = side_effect |
| 539 | + mock_pipeline_response = MagicMock() |
| 540 | + mock_pipeline_response.http_response = mock_response |
| 541 | + mock_pipeline = MagicMock() |
| 542 | + mock_pipeline._pipeline.run.return_value = mock_pipeline_response |
| 543 | + mock_pipeline_client_gen.return_value = mock_pipeline |
| 544 | + agents_client = self.get_mock_client() |
| 545 | + with agents_client: |
| 546 | + # Check that pipelines are created as expected. |
| 547 | + self._set_toolcalls(agents_client, toolset1, None) |
| 548 | + run_handler = MyRunHandler() |
| 549 | + agent1 = agents_client.create_agent( |
| 550 | + model="gpt-4-1106-preview", |
| 551 | + name="first", |
| 552 | + instructions="You are a helpful agent", |
| 553 | + toolset=toolset1, |
| 554 | + ) |
| 555 | + self._assert_pipeline_and_reset(mock_pipeline._pipeline.run, tool_set=toolset1) |
| 556 | + |
| 557 | + # Create run with new tool set, which also can be none. |
| 558 | + agents_client.runs.create_and_process( |
| 559 | + thread_id="some_thread_id", agent_id=agent1.id, polling_interval=0, run_handler=run_handler |
| 560 | + ) |
| 561 | + self._assert_tool_call(agents_client.runs.submit_tool_outputs, "run123", toolset1) |
| 562 | + |
510 | 563 | @patch("azure.ai.agents._client.PipelineClient") |
511 | 564 | @pytest.mark.parametrize( |
512 | 565 | "file_agent_1,add_azure_fn", |
|
0 commit comments