diff --git a/src/a2a/utils/proto_utils.py b/src/a2a/utils/proto_utils.py index 3f3db5783..408c47bf2 100644 --- a/src/a2a/utils/proto_utils.py +++ b/src/a2a/utils/proto_utils.py @@ -80,8 +80,14 @@ def file( cls, file: types.FileWithUri | types.FileWithBytes ) -> a2a_pb2.FilePart: if isinstance(file, types.FileWithUri): - return a2a_pb2.FilePart(file_with_uri=file.uri) - return a2a_pb2.FilePart(file_with_bytes=file.bytes.encode('utf-8')) + return a2a_pb2.FilePart( + file_with_uri=file.uri, mime_type=file.mime_type, name=file.name + ) + return a2a_pb2.FilePart( + file_with_bytes=file.bytes.encode('utf-8'), + mime_type=file.mime_type, + name=file.name, + ) @classmethod def task(cls, task: types.Task) -> a2a_pb2.Task: @@ -500,9 +506,19 @@ def data(cls, data: a2a_pb2.DataPart) -> dict[str, Any]: def file( cls, file: a2a_pb2.FilePart ) -> types.FileWithUri | types.FileWithBytes: + common_args = { + 'mime_type': file.mime_type or None, + 'name': file.name or None, + } if file.HasField('file_with_uri'): - return types.FileWithUri(uri=file.file_with_uri) - return types.FileWithBytes(bytes=file.file_with_bytes.decode('utf-8')) + return types.FileWithUri( + uri=file.file_with_uri, + **common_args, + ) + return types.FileWithBytes( + bytes=file.file_with_bytes.decode('utf-8'), + **common_args, + ) @classmethod def task_or_message( diff --git a/tests/utils/test_proto_utils.py b/tests/utils/test_proto_utils.py index 83848c248..c3f1b6a42 100644 --- a/tests/utils/test_proto_utils.py +++ b/tests/utils/test_proto_utils.py @@ -22,7 +22,11 @@ def sample_message() -> types.Message: types.Part(root=types.TextPart(text='Hello')), types.Part( root=types.FilePart( - file=types.FileWithUri(uri='file:///test.txt') + file=types.FileWithUri( + uri='file:///test.txt', + name='test.txt', + mime_type='text/plain', + ), ) ), types.Part(root=types.DataPart(data={'key': 'value'})), @@ -148,6 +152,8 @@ def test_roundtrip_message(self, sample_message: types.Message): # Test file part handling assert proto_msg.content[1].file.file_with_uri == 'file:///test.txt' + assert proto_msg.content[1].file.mime_type == 'text/plain' + assert proto_msg.content[1].file.name == 'test.txt' roundtrip_msg = proto_utils.FromProto.message(proto_msg) assert roundtrip_msg == sample_message