1919from os import environ
2020from typing import (
2121 TYPE_CHECKING ,
22+ Any ,
2223 Iterable ,
24+ Literal ,
2325 Mapping ,
26+ Optional ,
2427 Sequence ,
28+ Union ,
2529 cast ,
2630)
2731from urllib .parse import urlparse
6569_MODEL = "model"
6670
6771
72+ @dataclass (frozen = True )
73+ class ToolCall :
74+ type : Literal ["tool_call" ]
75+ arguments : Any
76+ name : str
77+ id : Optional [str ]
78+
79+
80+ @dataclass (frozen = True )
81+ class ToolCallResponse :
82+ type : Literal ["tool_call_response" ]
83+ response : Any
84+ id : Optional [str ]
85+
86+
87+ @dataclass (frozen = True )
88+ class TextPart :
89+ type : Literal ["text" ]
90+ content : str
91+
92+
93+ MessagePart = Union [TextPart , ToolCall , ToolCallResponse , Any ]
94+
95+
96+ @dataclass ()
97+ class InputMessage (Any ):
98+ role : str
99+ parts : list [MessagePart ]
100+
101+
102+ @dataclass ()
103+ class OutputMessage (Any ):
104+ role : str
105+ parts : list [MessagePart ]
106+ finish_reason : Union [str , FinishReason ]
107+
108+
68109@dataclass (frozen = True )
69110class GenerateContentParams :
70111 model : str
@@ -256,7 +297,7 @@ def request_to_events(
256297 id_ = f"{ function_response .name } _{ idx } " ,
257298 role = content .role ,
258299 content = json_format .MessageToDict (
259- function_response ._pb .response
300+ function_response ._pb .response # type: ignore[reportUnknownMemberType]
260301 )
261302 if capture_content
262303 else None ,
@@ -290,15 +331,15 @@ def create_operation_details_event(
290331 event .attributes = attributes
291332 if not capture_content :
292333 return event
293-
294- attributes ["gen_ai.system_instructions" ] = [
295- {
296- "type" : "text" ,
297- "content" : "\n " .join (
298- part .text for part in params .system_instruction .parts
299- ),
300- }
301- ]
334+ if params . system_instruction :
335+ attributes ["gen_ai.system_instructions" ] = [
336+ {
337+ "type" : "text" ,
338+ "content" : "\n " .join (
339+ part .text for part in params .system_instruction .parts
340+ ),
341+ }
342+ ]
302343 if params .contents :
303344 attributes ["gen_ai.input.messages" ] = [
304345 _convert_content_to_message (content ) for content in params .contents
@@ -313,47 +354,50 @@ def create_operation_details_event(
313354def _convert_response_to_output_messages (
314355 response : prediction_service .GenerateContentResponse
315356 | prediction_service_v1beta1 .GenerateContentResponse ,
316- ) -> list :
317- output_messages = []
357+ ) -> list [ OutputMessage ] :
358+ output_messages : list [ OutputMessage ] = []
318359 for candidate in response .candidates :
319360 message = _convert_content_to_message (candidate .content )
320- message [ " finish_reason" ] = _map_finish_reason (candidate .finish_reason )
361+ message . finish_reason = _map_finish_reason (candidate .finish_reason )
321362 output_messages .append (message )
322363 return output_messages
323364
324365
325- def _convert_content_to_message (content : content .Content ) -> dict :
326- message = {"role" : content .role , "parts" : []}
366+ def _convert_content_to_message (
367+ content : content .Content | content_v1beta1 .Content ,
368+ ) -> InputMessage :
369+ parts : MessagePart = []
370+ message = InputMessage (role = content .role , parts = parts )
327371 for idx , part in enumerate (content .parts ):
328372 if "function_response" in part :
329373 part = part .function_response
330- message [ " parts" ] .append (
331- {
332- " type" : "tool_call_response" ,
333- "id" : f"{ part .name } _{ idx } " ,
334- " response" : json_format .MessageToDict (part ._pb .response ),
335- }
374+ parts .append (
375+ ToolCallResponse (
376+ type = "tool_call_response" ,
377+ id = f"{ part .name } _{ idx } " ,
378+ response = json_format .MessageToDict (part ._pb .response ), # type: ignore[reportUnknownMemberType]
379+ )
336380 )
337381 elif "function_call" in part :
338382 part = part .function_call
339- message [ " parts" ] .append (
340- {
341- " type" : "tool_call" ,
342- "id" : f"{ part .name } _{ idx } " ,
343- " name" : part .name ,
344- "response" : json_format .MessageToDict (
345- part ._pb .args ,
383+ parts .append (
384+ ToolCall (
385+ type = "tool_call" ,
386+ id = f"{ part .name } _{ idx } " ,
387+ name = part .name ,
388+ arguments = json_format .MessageToDict (
389+ part ._pb .args , # type: ignore[reportUnknownMemberType]
346390 ),
347- }
391+ )
348392 )
349393 elif "text" in part :
350- message ["parts" ].append ({"type" : "text" , "content" : part .text })
351- part = part .text
394+ parts .append (TextPart (type = "text" , content = part .text ))
352395 else :
353- message [ "parts" ]. append (
354- type ( part ). to_dict ( part , always_print_fields_with_no_presence = False )
396+ dict_part = type ( part ). to_dict ( # type: ignore[reportUnknownMemberType]
397+ part , always_print_fields_with_no_presence = False
355398 )
356- message ["parts" ][- 1 ]["type" ] = type (part )
399+ dict_part ["type" ] = type (part )
400+ parts .append (dict_part )
357401 return message
358402
359403
@@ -401,7 +445,7 @@ def _extract_tool_calls(
401445 function = ChoiceToolCall .Function (
402446 name = part .function_call .name ,
403447 arguments = json_format .MessageToDict (
404- part .function_call ._pb .args
448+ part .function_call ._pb .args # type: ignore[reportUnknownMemberType]
405449 )
406450 if capture_content
407451 else None ,
@@ -420,7 +464,9 @@ def _parts_to_any_value(
420464 return [
421465 cast (
422466 "dict[str, AnyValue]" ,
423- type (part ).to_dict (part , always_print_fields_with_no_presence = False ), # type: ignore[reportUnknownMemberType]
467+ type (part ).to_dict ( # type: ignore[reportUnknownMemberType]
468+ part , always_print_fields_with_no_presence = False
469+ ),
424470 )
425471 for part in parts
426472 ]
0 commit comments