Skip to content

Commit 3481914

Browse files
authored
Capture file IDs from code interpreter in streaming responses (#2741)
1 parent 4c6a5d4 commit 3481914

File tree

10 files changed

+785
-13
lines changed

10 files changed

+785
-13
lines changed

python/packages/azure-ai/agent_framework_azure_ai/_chat_client.py

Lines changed: 47 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -63,6 +63,8 @@
6363
McpTool,
6464
MessageDeltaChunk,
6565
MessageDeltaTextContent,
66+
MessageDeltaTextFileCitationAnnotation,
67+
MessageDeltaTextFilePathAnnotation,
6668
MessageDeltaTextUrlCitationAnnotation,
6769
MessageImageUrlParam,
6870
MessageInputContentBlock,
@@ -471,6 +473,45 @@ def _extract_url_citations(
471473

472474
return url_citations
473475

476+
def _extract_file_path_contents(self, message_delta_chunk: MessageDeltaChunk) -> list[HostedFileContent]:
    """Extract file references from MessageDeltaChunk annotations.

    Code interpreter generates files that are referenced via file path or file citation
    annotations in the message content. This method extracts those file IDs and returns
    them as HostedFileContent objects, in the order the annotations appear.

    Handles two annotation types:
    - MessageDeltaTextFilePathAnnotation: Contains file_path.file_id
    - MessageDeltaTextFileCitationAnnotation: Contains file_citation.file_id

    Annotations of any other type, and annotations whose file reference or file ID is
    missing or empty, are skipped. A chunk without a delta or without delta content
    yields an empty list rather than raising.

    Args:
        message_delta_chunk: The message delta chunk to process

    Returns:
        List of HostedFileContent objects for any files referenced in annotations
    """
    # A chunk may arrive without a delta payload or with no content parts; treat both
    # as "nothing to extract" so a sparse streaming event cannot break the stream.
    delta = getattr(message_delta_chunk, "delta", None)
    delta_contents = getattr(delta, "content", None) or []

    file_contents: list[HostedFileContent] = []
    for content in delta_contents:
        # Only text parts carry annotations; skip everything else up front.
        if not (isinstance(content, MessageDeltaTextContent) and content.text and content.text.annotations):
            continue

        for annotation in content.text.annotations:
            # Both supported annotation kinds wrap the ID in a nested object; only the
            # name of the wrapping attribute differs.
            if isinstance(annotation, MessageDeltaTextFilePathAnnotation):
                file_ref = getattr(annotation, "file_path", None)
            elif isinstance(annotation, MessageDeltaTextFileCitationAnnotation):
                file_ref = getattr(annotation, "file_citation", None)
            else:
                continue

            # getattr on a None file_ref simply falls through to the default.
            file_id = getattr(file_ref, "file_id", None)
            if file_id:
                file_contents.append(HostedFileContent(file_id=file_id))

    return file_contents
514+
474515
def _get_real_url_from_citation_reference(
475516
self, citation_url: str, azure_search_tool_calls: list[dict[str, Any]]
476517
) -> str:
@@ -530,6 +571,9 @@ async def _process_stream(
530571
# Extract URL citations from the delta chunk
531572
url_citations = self._extract_url_citations(event_data, azure_search_tool_calls)
532573

574+
# Extract file path contents from code interpreter outputs
575+
file_contents = self._extract_file_path_contents(event_data)
576+
533577
# Create contents with citations if any exist
534578
citation_content: list[Contents] = []
535579
if event_data.text or url_citations:
@@ -538,6 +582,9 @@ async def _process_stream(
538582
text_content_obj.annotations = url_citations
539583
citation_content.append(text_content_obj)
540584

585+
# Add file contents from file path annotations
586+
citation_content.extend(file_contents)
587+
541588
yield ChatResponseUpdate(
542589
role=role,
543590
contents=citation_content if citation_content else None,

python/packages/azure-ai/tests/test_azure_ai_agent_client.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
FunctionCallContent,
2525
FunctionResultContent,
2626
HostedCodeInterpreterTool,
27+
HostedFileContent,
2728
HostedFileSearchTool,
2829
HostedMCPTool,
2930
HostedVectorStoreContent,
@@ -42,6 +43,8 @@
4243
FileInfo,
4344
MessageDeltaChunk,
4445
MessageDeltaTextContent,
46+
MessageDeltaTextFileCitationAnnotation,
47+
MessageDeltaTextFilePathAnnotation,
4548
MessageDeltaTextUrlCitationAnnotation,
4649
RequiredFunctionToolCall,
4750
RequiredMcpToolCall,
@@ -1362,6 +1365,108 @@ def test_azure_ai_chat_client_extract_url_citations_with_citations(mock_agents_c
13621365
assert citation.annotated_regions[0].end_index == 20
13631366

13641367

1368+
def test_azure_ai_chat_client_extract_file_path_contents_with_file_path_annotation(
    mock_agents_client: MagicMock,
) -> None:
    """A file path annotation in a delta chunk yields one HostedFileContent with its file ID."""
    expected_file_id = "assistant-test-file-123"

    # Assemble the chunk from the inside out: annotation -> text content -> delta -> chunk.
    path_annotation = MagicMock(spec=MessageDeltaTextFilePathAnnotation)
    path_annotation.file_path = MagicMock(file_id=expected_file_id)

    text_part = MagicMock(spec=MessageDeltaTextContent)
    text_part.text = MagicMock(annotations=[path_annotation])

    delta_chunk = MagicMock(spec=MessageDeltaChunk)
    delta_chunk.delta = MagicMock(content=[text_part])

    client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
    extracted = client._extract_file_path_contents(delta_chunk)

    # Exactly one hosted file, carrying the ID taken from file_path.file_id.
    assert len(extracted) == 1
    hosted_file = extracted[0]
    assert isinstance(hosted_file, HostedFileContent)
    assert hosted_file.file_id == expected_file_id
1403+
1404+
1405+
def test_azure_ai_chat_client_extract_file_path_contents_with_file_citation_annotation(
    mock_agents_client: MagicMock,
) -> None:
    """A file citation annotation in a delta chunk yields one HostedFileContent with its file ID."""
    expected_file_id = "cfile_test-citation-456"

    # Assemble the chunk from the inside out: annotation -> text content -> delta -> chunk.
    citation_annotation = MagicMock(spec=MessageDeltaTextFileCitationAnnotation)
    citation_annotation.file_citation = MagicMock(file_id=expected_file_id)

    text_part = MagicMock(spec=MessageDeltaTextContent)
    text_part.text = MagicMock(annotations=[citation_annotation])

    delta_chunk = MagicMock(spec=MessageDeltaChunk)
    delta_chunk.delta = MagicMock(content=[text_part])

    client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
    extracted = client._extract_file_path_contents(delta_chunk)

    # Exactly one hosted file, carrying the ID taken from file_citation.file_id.
    assert len(extracted) == 1
    hosted_file = extracted[0]
    assert isinstance(hosted_file, HostedFileContent)
    assert hosted_file.file_id == expected_file_id
1440+
1441+
1442+
def test_azure_ai_chat_client_extract_file_path_contents_empty_annotations(
    mock_agents_client: MagicMock,
) -> None:
    """A text part whose annotation list is empty produces no file contents."""
    # Text content is present, but it carries an empty annotation list.
    text_part = MagicMock(spec=MessageDeltaTextContent)
    text_part.text = MagicMock(annotations=[])

    delta_chunk = MagicMock(spec=MessageDeltaChunk)
    delta_chunk.delta = MagicMock(content=[text_part])

    client = create_test_azure_ai_chat_client(mock_agents_client, agent_id="test-agent")
    extracted = client._extract_file_path_contents(delta_chunk)

    assert len(extracted) == 0
1468+
1469+
13651470
def get_weather(
13661471
location: Annotated[str, Field(description="The location to get the weather for.")],
13671472
) -> str:

python/packages/core/agent_framework/openai/_responses_client.py

Lines changed: 71 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence, Sequence
44
from datetime import datetime, timezone
55
from itertools import chain
6-
from typing import Any, TypeVar
6+
from typing import Any, TypeVar, cast
77

88
from openai import AsyncOpenAI, BadRequestError
99
from openai.types.responses.file_search_tool_param import FileSearchToolParam
@@ -199,7 +199,7 @@ def _prepare_text_config(
199199
return response_format, prepared_text
200200

201201
if isinstance(response_format, Mapping):
202-
format_config = self._convert_response_format(response_format)
202+
format_config = self._convert_response_format(cast("Mapping[str, Any]", response_format))
203203
if prepared_text is None:
204204
prepared_text = {}
205205
elif "format" in prepared_text and prepared_text["format"] != format_config:
@@ -212,20 +212,21 @@ def _prepare_text_config(
212212
def _convert_response_format(self, response_format: Mapping[str, Any]) -> dict[str, Any]:
213213
"""Convert Chat style response_format into Responses text format config."""
214214
if "format" in response_format and isinstance(response_format["format"], Mapping):
215-
return dict(response_format["format"])
215+
return dict(cast("Mapping[str, Any]", response_format["format"]))
216216

217217
format_type = response_format.get("type")
218218
if format_type == "json_schema":
219219
schema_section = response_format.get("json_schema", response_format)
220220
if not isinstance(schema_section, Mapping):
221221
raise ServiceInvalidRequestError("json_schema response_format must be a mapping.")
222-
schema = schema_section.get("schema")
222+
schema_section_typed = cast("Mapping[str, Any]", schema_section)
223+
schema: Any = schema_section_typed.get("schema")
223224
if schema is None:
224225
raise ServiceInvalidRequestError("json_schema response_format requires a schema.")
225-
name = (
226-
schema_section.get("name")
227-
or schema_section.get("title")
228-
or (schema.get("title") if isinstance(schema, Mapping) else None)
226+
name: str = str(
227+
schema_section_typed.get("name")
228+
or schema_section_typed.get("title")
229+
or (cast("Mapping[str, Any]", schema).get("title") if isinstance(schema, Mapping) else None)
229230
or "response"
230231
)
231232
format_config: dict[str, Any] = {
@@ -532,12 +533,13 @@ def _openai_content_parser(
532533
"text": content.text,
533534
},
534535
}
535-
if content.additional_properties is not None:
536-
if status := content.additional_properties.get("status"):
536+
props: dict[str, Any] | None = getattr(content, "additional_properties", None)
537+
if props:
538+
if status := props.get("status"):
537539
ret["status"] = status
538-
if reasoning_text := content.additional_properties.get("reasoning_text"):
540+
if reasoning_text := props.get("reasoning_text"):
539541
ret["content"] = {"type": "reasoning_text", "text": reasoning_text}
540-
if encrypted_content := content.additional_properties.get("encrypted_content"):
542+
if encrypted_content := props.get("encrypted_content"):
541543
ret["encrypted_content"] = encrypted_content
542544
return ret
543545
case DataContent() | UriContent():
@@ -824,7 +826,7 @@ def _create_response_content(
824826
"raw_representation": response,
825827
}
826828

827-
conversation_id = self.get_conversation_id(response, chat_options.store)
829+
conversation_id = self.get_conversation_id(response, chat_options.store) # type: ignore[reportArgumentType]
828830

829831
if conversation_id:
830832
args["conversation_id"] = conversation_id
@@ -911,6 +913,8 @@ def _create_streaming_response_content(
911913
metadata.update(self._get_metadata_from_response(event_part))
912914
case "refusal":
913915
contents.append(TextContent(text=event_part.refusal, raw_representation=event))
916+
case _:
917+
pass
914918
case "response.output_text.delta":
915919
contents.append(TextContent(text=event.delta, raw_representation=event))
916920
metadata.update(self._get_metadata_from_response(event))
@@ -1032,6 +1036,60 @@ def _create_streaming_response_content(
10321036
raw_representation=event,
10331037
)
10341038
)
1039+
case "response.output_text.annotation.added":
1040+
# Handle streaming text annotations (file citations, file paths, etc.)
1041+
annotation: Any = event.annotation
1042+
1043+
def _get_ann_value(key: str) -> Any:
1044+
"""Extract value from annotation (dict or object)."""
1045+
if isinstance(annotation, dict):
1046+
return cast("dict[str, Any]", annotation).get(key)
1047+
return getattr(annotation, key, None)
1048+
1049+
ann_type = _get_ann_value("type")
1050+
ann_file_id = _get_ann_value("file_id")
1051+
if ann_type == "file_path":
1052+
if ann_file_id:
1053+
contents.append(
1054+
HostedFileContent(
1055+
file_id=str(ann_file_id),
1056+
additional_properties={
1057+
"annotation_index": event.annotation_index,
1058+
"index": _get_ann_value("index"),
1059+
},
1060+
raw_representation=event,
1061+
)
1062+
)
1063+
elif ann_type == "file_citation":
1064+
if ann_file_id:
1065+
contents.append(
1066+
HostedFileContent(
1067+
file_id=str(ann_file_id),
1068+
additional_properties={
1069+
"annotation_index": event.annotation_index,
1070+
"filename": _get_ann_value("filename"),
1071+
"index": _get_ann_value("index"),
1072+
},
1073+
raw_representation=event,
1074+
)
1075+
)
1076+
elif ann_type == "container_file_citation":
1077+
if ann_file_id:
1078+
contents.append(
1079+
HostedFileContent(
1080+
file_id=str(ann_file_id),
1081+
additional_properties={
1082+
"annotation_index": event.annotation_index,
1083+
"container_id": _get_ann_value("container_id"),
1084+
"filename": _get_ann_value("filename"),
1085+
"start_index": _get_ann_value("start_index"),
1086+
"end_index": _get_ann_value("end_index"),
1087+
},
1088+
raw_representation=event,
1089+
)
1090+
)
1091+
else:
1092+
logger.debug("Unparsed annotation type in streaming: %s", ann_type)
10351093
case _:
10361094
logger.debug("Unparsed event of type: %s: %s", event.type, event)
10371095

0 commit comments

Comments
 (0)