Skip to content

Commit 5b517ac

Browse files
authored
Merge branch 'main' into auto-update-a2a-types-00cf76e7bbc752842ef254f3d4136ed1b5751f6e
2 parents 4693b71 + 72b2ee7 commit 5b517ac

File tree

2 files changed

+27
-5
lines changed

2 files changed

+27
-5
lines changed

src/a2a/utils/proto_utils.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,14 @@ def file(
8080
cls, file: types.FileWithUri | types.FileWithBytes
8181
) -> a2a_pb2.FilePart:
8282
if isinstance(file, types.FileWithUri):
83-
return a2a_pb2.FilePart(file_with_uri=file.uri)
84-
return a2a_pb2.FilePart(file_with_bytes=file.bytes.encode('utf-8'))
83+
return a2a_pb2.FilePart(
84+
file_with_uri=file.uri, mime_type=file.mime_type, name=file.name
85+
)
86+
return a2a_pb2.FilePart(
87+
file_with_bytes=file.bytes.encode('utf-8'),
88+
mime_type=file.mime_type,
89+
name=file.name,
90+
)
8591

8692
@classmethod
8793
def task(cls, task: types.Task) -> a2a_pb2.Task:
@@ -500,9 +506,19 @@ def data(cls, data: a2a_pb2.DataPart) -> dict[str, Any]:
500506
def file(
501507
cls, file: a2a_pb2.FilePart
502508
) -> types.FileWithUri | types.FileWithBytes:
509+
common_args = {
510+
'mime_type': file.mime_type or None,
511+
'name': file.name or None,
512+
}
503513
if file.HasField('file_with_uri'):
504-
return types.FileWithUri(uri=file.file_with_uri)
505-
return types.FileWithBytes(bytes=file.file_with_bytes.decode('utf-8'))
514+
return types.FileWithUri(
515+
uri=file.file_with_uri,
516+
**common_args,
517+
)
518+
return types.FileWithBytes(
519+
bytes=file.file_with_bytes.decode('utf-8'),
520+
**common_args,
521+
)
506522

507523
@classmethod
508524
def task_or_message(

tests/utils/test_proto_utils.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,11 @@ def sample_message() -> types.Message:
2222
types.Part(root=types.TextPart(text='Hello')),
2323
types.Part(
2424
root=types.FilePart(
25-
file=types.FileWithUri(uri='file:///test.txt')
25+
file=types.FileWithUri(
26+
uri='file:///test.txt',
27+
name='test.txt',
28+
mime_type='text/plain',
29+
),
2630
)
2731
),
2832
types.Part(root=types.DataPart(data={'key': 'value'})),
@@ -148,6 +152,8 @@ def test_roundtrip_message(self, sample_message: types.Message):
148152

149153
# Test file part handling
150154
assert proto_msg.content[1].file.file_with_uri == 'file:///test.txt'
155+
assert proto_msg.content[1].file.mime_type == 'text/plain'
156+
assert proto_msg.content[1].file.name == 'test.txt'
151157

152158
roundtrip_msg = proto_utils.FromProto.message(proto_msg)
153159
assert roundtrip_msg == sample_message

0 commit comments

Comments
 (0)