@@ -27,12 +27,13 @@ def patch_vertexai(name, version, tracer: Tracer):
     def traced_method(wrapped, instance, args, kwargs):
         service_provider = SERVICE_PROVIDERS["VERTEXAI"]
         prompts = serialize_prompts(args, kwargs)
+
         span_attributes = {
             **get_langtrace_attributes(version, service_provider),
             **get_llm_request_attributes(
                 kwargs,
                 prompts=prompts,
-                model=get_llm_model(instance),
+                model=get_llm_model(instance, kwargs),
             ),
             **get_llm_url(instance),
             SpanAttributes.LLM_PATH: "",
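The call site now forwards `kwargs` because, when Vertex AI is reached through the low-level `PredictionServiceClient`, the model is carried inside a single `request` message rather than on the wrapped instance. A hedged sketch of the call shape this hunk is meant to cover (the request type, field names, and values are assumptions based on `google-cloud-aiplatform`, not taken from this diff):

```python
from google.cloud import aiplatform_v1

client = aiplatform_v1.PredictionServiceClient()
request = aiplatform_v1.GenerateContentRequest(
    model="projects/my-project/locations/us-central1/publishers/google/models/gemini-1.0-pro",
    contents=[{"role": "user", "parts": [{"text": "Hi"}]}],
)
# The wrapper sees this call as kwargs={"request": request}, so the model
# name has to come from request.model rather than from the client instance.
response = client.generate_content(request=request)
```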
@@ -77,6 +78,10 @@ def set_response_attributes(span: Span, result):
     if hasattr(result, "text"):
         set_event_completion(span, [{"role": "assistant", "content": result.text}])
 
+    if hasattr(result, "candidates"):
+        parts = result.candidates[0].content.parts
+        set_event_completion(span, [{"role": "assistant", "content": parts[0].text}])
+
     if hasattr(result, "usage_metadata") and result.usage_metadata is not None:
         usage = result.usage_metadata
         input_tokens = usage.prompt_token_count
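The new `candidates` branch covers responses in the Gemini `GenerationResponse` shape, where the text sits under `candidates[0].content.parts` instead of a top-level `text` attribute. A minimal stand-in illustrating the attribute path the branch reads (plain stubs, not the real SDK objects):

```python
from types import SimpleNamespace

# Stub mirroring the candidates -> content -> parts nesting.
result = SimpleNamespace(
    candidates=[
        SimpleNamespace(content=SimpleNamespace(parts=[SimpleNamespace(text="Hello!")]))
    ]
)
parts = result.candidates[0].content.parts
assert parts[0].text == "Hello!"  # recorded as the assistant completion
```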
@@ -96,17 +101,23 @@ def set_response_attributes(span: Span, result):
 
 
 def is_streaming_response(response):
-    return isinstance(response, types.GeneratorType) or isinstance(
-        response, types.AsyncGeneratorType
+    return (
+        isinstance(response, types.GeneratorType)
+        or isinstance(response, types.AsyncGeneratorType)
+        or str(type(response).__name__) == "_StreamingResponseIterator"
     )
 
 
-def get_llm_model(instance):
+def get_llm_model(instance, kwargs):
+    if "request" in kwargs:
+        return kwargs.get("request").model.split("/")[-1]
+
     if hasattr(instance, "_model_name"):
         return instance._model_name.replace("publishers/google/models/", "")
     return getattr(instance, "_model_id", "unknown")
 
 
+@silently_fail
 def serialize_prompts(args, kwargs):
     if args and len(args) > 0:
         prompt_parts = []
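Two changes land in this hunk: `is_streaming_response` now also matches the gRPC stream wrapper by type name (it is neither a sync nor an async generator, so the `isinstance` checks alone miss it), and `get_llm_model` prefers the `request` message, trimming the full model resource path down to its last segment. A quick sketch of the resolution order, exercising the `get_llm_model` defined above with stub objects:

```python
from types import SimpleNamespace

# 1. A request message wins: the resource path is trimmed to the bare name.
req = SimpleNamespace(model="projects/p/locations/l/publishers/google/models/gemini-1.0-pro")
assert get_llm_model(None, {"request": req}) == "gemini-1.0-pro"

# 2. Otherwise fall back to the instance's _model_name (or _model_id).
inst = SimpleNamespace(_model_name="publishers/google/models/gemini-1.0-pro")
assert get_llm_model(inst, {}) == "gemini-1.0-pro"
```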
@@ -122,5 +133,24 @@ def serialize_prompts(args, kwargs):
 
         return [{"role": "user", "content": "\n".join(prompt_parts)}]
     else:
-        content = kwargs.get("prompt") or kwargs.get("message")
-        return [{"role": "user", "content": content}] if content else []
+        # Handle PredictionServiceClient for google-cloud-aiplatform.
+        if "request" in kwargs:
+            prompt = []
+            prompt_body = kwargs.get("request")
+            if prompt_body.system_instruction:
+                for part in prompt_body.system_instruction.parts:
+                    prompt.append({"role": "system", "content": part.text})
+
+            contents = prompt_body.contents
+
+            if not contents:
+                return []
+
+            for c in contents:
+                role = c.role if c.role else "user"
+                content = c.parts[0].text if c.parts else ""
+                prompt.append({"role": role, "content": content})
+            return prompt
+        else:
+            content = kwargs.get("prompt") or kwargs.get("message")
+            return [{"role": "user", "content": content}] if content else []
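The new branch flattens a `GenerateContentRequest`-style message into role/content pairs: optional `system_instruction` parts become system messages, then each entry in `contents` contributes its role (defaulting to "user") and the text of its first part. The `@silently_fail` decorator added above presumably swallows any shape mismatch, so a malformed request drops the prompt rather than raising. A hedged end-to-end example with stand-in objects (the real messages come from `google-cloud-aiplatform`):

```python
from types import SimpleNamespace

request = SimpleNamespace(
    system_instruction=SimpleNamespace(parts=[SimpleNamespace(text="Be terse.")]),
    contents=[SimpleNamespace(role="user", parts=[SimpleNamespace(text="Hi")])],
)
print(serialize_prompts((), {"request": request}))
# [{'role': 'system', 'content': 'Be terse.'},
#  {'role': 'user', 'content': 'Hi'}]
```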