Skip to content

Commit aefe6b2

Browse files
committed
Add metadata conversion for Parts
1 parent 112abd1 commit aefe6b2

File tree

1 file changed

+35
-6
lines changed

1 file changed

+35
-6
lines changed

src/a2a/utils/proto_utils.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,19 @@ def metadata(
7373
@classmethod
7474
def part(cls, part: types.Part) -> a2a_pb2.Part:
7575
if isinstance(part.root, types.TextPart):
76-
return a2a_pb2.Part(text=part.root.text)
76+
return a2a_pb2.Part(
77+
text=part.root.text, metadata=cls.metadata(part.root.metadata)
78+
)
7779
if isinstance(part.root, types.FilePart):
78-
return a2a_pb2.Part(file=cls.file(part.root.file))
80+
return a2a_pb2.Part(
81+
file=cls.file(part.root.file),
82+
metadata=cls.metadata(part.root.metadata),
83+
)
7984
if isinstance(part.root, types.DataPart):
80-
return a2a_pb2.Part(data=cls.data(part.root.data))
85+
return a2a_pb2.Part(
86+
data=cls.data(part.root.data),
87+
metadata=cls.metadata(part.root.metadata),
88+
)
8189
raise ValueError(f'Unsupported part type: {part.root}')
8290

8391
@classmethod
@@ -502,11 +510,32 @@ def metadata(cls, metadata: struct_pb2.Struct) -> dict[str, Any]:
502510
@classmethod
503511
def part(cls, part: a2a_pb2.Part) -> types.Part:
504512
if part.HasField('text'):
505-
return types.Part(root=types.TextPart(text=part.text))
513+
return types.Part(
514+
root=types.TextPart(
515+
text=part.text,
516+
metadata=cls.metadata(part.metadata)
517+
if part.metadata
518+
else None,
519+
),
520+
)
506521
if part.HasField('file'):
507-
return types.Part(root=types.FilePart(file=cls.file(part.file)))
522+
return types.Part(
523+
root=types.FilePart(
524+
file=cls.file(part.file),
525+
metadata=cls.metadata(part.metadata)
526+
if part.metadata
527+
else None,
528+
),
529+
)
508530
if part.HasField('data'):
509-
return types.Part(root=types.DataPart(data=cls.data(part.data)))
531+
return types.Part(
532+
root=types.DataPart(
533+
data=cls.data(part.data),
534+
metadata=cls.metadata(part.metadata)
535+
if part.metadata
536+
else None,
537+
),
538+
)
510539
raise ValueError(f'Unsupported part type: {part}')
511540

512541
@classmethod

0 commit comments

Comments
 (0)