|
18 | 18 | from .._run_context import RunContext |
19 | 19 | from .._thinking_part import split_content_into_text_and_thinking |
20 | 20 | from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime |
21 | | -from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool |
| 21 | +from ..builtin_tools import CodeExecutionTool, FileSearchTool, ImageGenerationTool, MCPServerTool, WebSearchTool |
22 | 22 | from ..exceptions import UserError |
23 | 23 | from ..messages import ( |
24 | 24 | AudioUrl, |
@@ -1070,9 +1070,10 @@ def _process_response( # noqa: C901 |
1070 | 1070 | elif isinstance(item, responses.response_output_item.LocalShellCall): # pragma: no cover |
1071 | 1071 | # Pydantic AI doesn't yet support the `codex-mini-latest` LocalShell built-in tool |
1072 | 1072 | pass |
1073 | | - elif isinstance(item, responses.ResponseFileSearchToolCall): # pragma: no cover |
1074 | | - # Pydantic AI doesn't yet support the FileSearch built-in tool |
1075 | | - pass |
| 1073 | + elif isinstance(item, responses.ResponseFileSearchToolCall): |
| 1074 | + call_part, return_part = _map_file_search_tool_call(item, self.system) |
| 1075 | + items.append(call_part) |
| 1076 | + items.append(return_part) |
1076 | 1077 | elif isinstance(item, responses.response_output_item.McpCall): |
1077 | 1078 | call_part, return_part = _map_mcp_call(item, self.system) |
1078 | 1079 | items.append(call_part) |
@@ -1267,6 +1268,11 @@ def _get_builtin_tools(self, model_request_parameters: ModelRequestParameters) - |
1267 | 1268 | type='approximate', **tool.user_location |
1268 | 1269 | ) |
1269 | 1270 | tools.append(web_search_tool) |
| 1271 | + elif isinstance(tool, FileSearchTool): |
| 1272 | + file_search_tool = responses.FileSearchToolParam( |
| 1273 | + type='file_search', vector_store_ids=tool.vector_store_ids |
| 1274 | + ) |
| 1275 | + tools.append(file_search_tool) |
1270 | 1276 | elif isinstance(tool, CodeExecutionTool): |
1271 | 1277 | has_image_generating_tool = True |
1272 | 1278 | tools.append({'type': 'code_interpreter', 'container': {'type': 'auto'}}) |
@@ -1404,6 +1410,7 @@ async def _map_messages( # noqa: C901 |
1404 | 1410 | message_item: responses.ResponseOutputMessageParam | None = None |
1405 | 1411 | reasoning_item: responses.ResponseReasoningItemParam | None = None |
1406 | 1412 | web_search_item: responses.ResponseFunctionWebSearchParam | None = None |
| 1413 | + file_search_item: responses.ResponseFileSearchToolCallParam | None = None |
1407 | 1414 | code_interpreter_item: responses.ResponseCodeInterpreterToolCallParam | None = None |
1408 | 1415 | for item in message.parts: |
1409 | 1416 | if isinstance(item, TextPart): |
@@ -1473,6 +1480,18 @@ async def _map_messages( # noqa: C901 |
1473 | 1480 | type='web_search_call', |
1474 | 1481 | ) |
1475 | 1482 | openai_messages.append(web_search_item) |
| 1483 | + elif ( |
| 1484 | + item.tool_name == FileSearchTool.kind |
| 1485 | + and item.tool_call_id |
| 1486 | + and (args := item.args_as_dict()) |
| 1487 | + ): |
| 1488 | + file_search_item = responses.ResponseFileSearchToolCallParam( |
| 1489 | + id=item.tool_call_id, |
| 1490 | + action=cast(responses.response_file_search_tool_call_param.Action, args), |
| 1491 | + status='completed', |
| 1492 | + type='file_search_call', |
| 1493 | + ) |
| 1494 | + openai_messages.append(file_search_item) |
1476 | 1495 | elif item.tool_name == ImageGenerationTool.kind and item.tool_call_id: |
1477 | 1496 | # The cast is necessary because of https://github.com/openai/openai-python/issues/2648 |
1478 | 1497 | image_generation_item = cast( |
@@ -1532,6 +1551,14 @@ async def _map_messages( # noqa: C901 |
1532 | 1551 | and (status := content.get('status')) |
1533 | 1552 | ): |
1534 | 1553 | web_search_item['status'] = status |
| 1554 | + elif ( |
| 1555 | + item.tool_name == FileSearchTool.kind |
| 1556 | + and file_search_item is not None |
| 1557 | + and isinstance(item.content, dict) # pyright: ignore[reportUnknownMemberType] |
| 1558 | + and (content := cast(dict[str, Any], item.content)) # pyright: ignore[reportUnknownMemberType] |
| 1559 | + and (status := content.get('status')) |
| 1560 | + ): |
| 1561 | + file_search_item['status'] = status |
1535 | 1562 | elif item.tool_name == ImageGenerationTool.kind: |
1536 | 1563 | # Image generation result does not need to be sent back, just the `id` off of `BuiltinToolCallPart`. |
1537 | 1564 | pass |
@@ -1845,6 +1872,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: |
1845 | 1872 | yield self._parts_manager.handle_part( |
1846 | 1873 | vendor_part_id=f'{chunk.item.id}-call', part=replace(call_part, args=None) |
1847 | 1874 | ) |
| 1875 | + elif isinstance(chunk.item, responses.ResponseFileSearchToolCall): |
| 1876 | + call_part, _ = _map_file_search_tool_call(chunk.item, self.provider_name) |
| 1877 | + yield self._parts_manager.handle_part( |
| 1878 | + vendor_part_id=f'{chunk.item.id}-call', part=replace(call_part, args=None) |
| 1879 | + ) |
1848 | 1880 | elif isinstance(chunk.item, responses.ResponseCodeInterpreterToolCall): |
1849 | 1881 | call_part, _, _ = _map_code_interpreter_tool_call(chunk.item, self.provider_name) |
1850 | 1882 |
|
@@ -1913,6 +1945,17 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: |
1913 | 1945 | elif isinstance(chunk.item, responses.ResponseFunctionWebSearch): |
1914 | 1946 | call_part, return_part = _map_web_search_tool_call(chunk.item, self.provider_name) |
1915 | 1947 |
|
| 1948 | + maybe_event = self._parts_manager.handle_tool_call_delta( |
| 1949 | + vendor_part_id=f'{chunk.item.id}-call', |
| 1950 | + args=call_part.args, |
| 1951 | + ) |
| 1952 | + if maybe_event is not None: # pragma: no branch |
| 1953 | + yield maybe_event |
| 1954 | + |
| 1955 | + yield self._parts_manager.handle_part(vendor_part_id=f'{chunk.item.id}-return', part=return_part) |
| 1956 | + elif isinstance(chunk.item, responses.ResponseFileSearchToolCall): |
| 1957 | + call_part, return_part = _map_file_search_tool_call(chunk.item, self.provider_name) |
| 1958 | + |
1916 | 1959 | maybe_event = self._parts_manager.handle_tool_call_delta( |
1917 | 1960 | vendor_part_id=f'{chunk.item.id}-call', |
1918 | 1961 | args=call_part.args, |
@@ -2216,6 +2259,34 @@ def _map_web_search_tool_call( |
2216 | 2259 | ) |
2217 | 2260 |
|
2218 | 2261 |
|
| 2262 | +def _map_file_search_tool_call( |
| 2263 | + item: responses.ResponseFileSearchToolCall, provider_name: str |
| 2264 | +) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart]: |
| 2265 | + args: dict[str, Any] | None = None |
| 2266 | + |
| 2267 | + result = { |
| 2268 | + 'status': item.status, |
| 2269 | + } |
| 2270 | + |
| 2271 | + if action := item.action: |
| 2272 | + args = action.model_dump(mode='json') |
| 2273 | + |
| 2274 | + return ( |
| 2275 | + BuiltinToolCallPart( |
| 2276 | + tool_name=FileSearchTool.kind, |
| 2277 | + tool_call_id=item.id, |
| 2278 | + args=args, |
| 2279 | + provider_name=provider_name, |
| 2280 | + ), |
| 2281 | + BuiltinToolReturnPart( |
| 2282 | + tool_name=FileSearchTool.kind, |
| 2283 | + tool_call_id=item.id, |
| 2284 | + content=result, |
| 2285 | + provider_name=provider_name, |
| 2286 | + ), |
| 2287 | + ) |
| 2288 | + |
| 2289 | + |
2219 | 2290 | def _map_image_generation_tool_call( |
2220 | 2291 | item: responses.response_output_item.ImageGenerationCall, provider_name: str |
2221 | 2292 | ) -> tuple[BuiltinToolCallPart, BuiltinToolReturnPart, FilePart | None]: |
|
0 commit comments