Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"QDRANT": "Qdrant",
"WEAVIATE": "Weaviate",
"OLLAMA": "Ollama",
"VERTEXAI": "VertexAI",
"VERTEXAI": "Vertex AI",
"GEMINI": "Gemini",
"MISTRAL": "Mistral",
"EMBEDCHAIN": "Embedchain",
Expand Down
20 changes: 20 additions & 0 deletions src/langtrace_python_sdk/constants/instrumentation/vertexai.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,24 @@
"method": "ChatSession",
"operation": "send_message_streaming",
},
"PREDICTION_SERVICE_BETA_GENERATE_CONTENT": {
"module": "google.cloud.aiplatform_v1beta1.services.prediction_service.client",
"method": "PredictionServiceClient",
"operation": "generate_content",
},
"PREDICTION_SERVICE_GENERATE_CONTENT": {
"module": "google.cloud.aiplatform_v1.services.prediction_service.client",
"method": "PredictionServiceClient",
"operation": "generate_content",
},
"PREDICTION_SERVICE_BETA_STREAM_GENERATE_CONTENT": {
"module": "google.cloud.aiplatform_v1beta1.services.prediction_service.client",
"method": "PredictionServiceClient",
"operation": "stream_generate_content",
},
"PREDICTION_SERVICE_STREAM_GENERATE_CONTENT": {
"module": "google.cloud.aiplatform_v1.services.prediction_service.client",
"method": "PredictionServiceClient",
"operation": "stream_generate_content",
},
}
42 changes: 36 additions & 6 deletions src/langtrace_python_sdk/instrumentation/vertexai/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,13 @@ def patch_vertexai(name, version, tracer: Tracer):
def traced_method(wrapped, instance, args, kwargs):
service_provider = SERVICE_PROVIDERS["VERTEXAI"]
prompts = serialize_prompts(args, kwargs)

span_attributes = {
**get_langtrace_attributes(version, service_provider),
**get_llm_request_attributes(
kwargs,
prompts=prompts,
model=get_llm_model(instance),
model=get_llm_model(instance, kwargs),
),
**get_llm_url(instance),
SpanAttributes.LLM_PATH: "",
Expand Down Expand Up @@ -77,6 +78,10 @@ def set_response_attributes(span: Span, result):
if hasattr(result, "text"):
set_event_completion(span, [{"role": "assistant", "content": result.text}])

if hasattr(result, "candidates"):
parts = result.candidates[0].content.parts
set_event_completion(span, [{"role": "assistant", "content": parts[0].text}])

if hasattr(result, "usage_metadata") and result.usage_metadata is not None:
usage = result.usage_metadata
input_tokens = usage.prompt_token_count
Expand All @@ -96,17 +101,23 @@ def set_response_attributes(span: Span, result):


def is_streaming_response(response):
    """Return True if *response* is a streaming result that must be consumed.

    Detects plain and async generators, plus the gRPC streaming iterator
    returned by google-cloud-aiplatform's ``stream_generate_content``.
    That iterator's class (``_StreamingResponseIterator``) is private and not
    importable, so it is matched by type name rather than ``isinstance``.
    """
    return (
        isinstance(response, types.GeneratorType)
        or isinstance(response, types.AsyncGeneratorType)
        or str(type(response).__name__) == "_StreamingResponseIterator"
    )


def get_llm_model(instance, kwargs):
    """Resolve the model name for a Vertex AI call.

    Resolution order:
    1. ``kwargs["request"].model`` — PredictionServiceClient requests carry a
       full resource path (``.../publishers/google/models/<model>``); only the
       trailing path segment is the model name.
    2. ``instance._model_name`` — SDK model objects, with the
       ``publishers/google/models/`` prefix stripped.
    3. ``instance._model_id`` — falling back to ``"unknown"`` when absent.
    """
    if "request" in kwargs:
        return kwargs.get("request").model.split("/")[-1]

    if hasattr(instance, "_model_name"):
        return instance._model_name.replace("publishers/google/models/", "")
    return getattr(instance, "_model_id", "unknown")


@silently_fail
def serialize_prompts(args, kwargs):
if args and len(args) > 0:
prompt_parts = []
Expand All @@ -122,5 +133,24 @@ def serialize_prompts(args, kwargs):

return [{"role": "user", "content": "\n".join(prompt_parts)}]
else:
content = kwargs.get("prompt") or kwargs.get("message")
return [{"role": "user", "content": content}] if content else []
# Handle PredictionServiceClient for google-cloud-aiplatform.
if "request" in kwargs:
prompt = []
prompt_body = kwargs.get("request")
if prompt_body.system_instruction:
for part in prompt_body.system_instruction.parts:
prompt.append({"role": "system", "content": part.text})

contents = prompt_body.contents

if not contents:
return []

for c in contents:
role = c.role if c.role else "user"
content = c.parts[0].text if c.parts else ""
prompt.append({"role": role, "content": content})
return prompt
else:
content = kwargs.get("prompt") or kwargs.get("message")
return [{"role": "user", "content": content}] if content else []
43 changes: 30 additions & 13 deletions src/langtrace_python_sdk/utils/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,18 +395,30 @@ def build_streaming_response(self, chunk):
content = [chunk.text]

# CohereV2
if (hasattr(chunk, "delta") and
chunk.delta is not None and
hasattr(chunk.delta, "message") and
chunk.delta.message is not None and
hasattr(chunk.delta.message, "content") and
chunk.delta.message.content is not None and
hasattr(chunk.delta.message.content, "text") and
chunk.delta.message.content.text is not None):
if (
hasattr(chunk, "delta")
and chunk.delta is not None
and hasattr(chunk.delta, "message")
and chunk.delta.message is not None
and hasattr(chunk.delta.message, "content")
and chunk.delta.message.content is not None
and hasattr(chunk.delta.message.content, "text")
and chunk.delta.message.content.text is not None
):
content = [chunk.delta.message.content.text]

# google-cloud-aiplatform
if hasattr(chunk, "candidates") and chunk.candidates is not None:
for candidate in chunk.candidates:
if hasattr(candidate, "content") and candidate.content is not None:
for part in candidate.content.parts:
if hasattr(part, "text") and part.text is not None:
content.append(part.text)
# Anthropic
if hasattr(chunk, "delta") and chunk.delta is not None and not hasattr(chunk.delta, "message"):
if (
hasattr(chunk, "delta")
and chunk.delta is not None
and not hasattr(chunk.delta, "message")
):
content = [chunk.delta.text] if hasattr(chunk.delta, "text") else []

if isinstance(chunk, dict):
Expand All @@ -425,9 +437,14 @@ def set_usage_attributes(self, chunk):

# CohereV2
if hasattr(chunk, "type") and chunk.type == "message-end":
if (hasattr(chunk, "delta") and chunk.delta is not None and
hasattr(chunk.delta, "usage") and chunk.delta.usage is not None and
hasattr(chunk.delta.usage, "billed_units") and chunk.delta.usage.billed_units is not None):
if (
hasattr(chunk, "delta")
and chunk.delta is not None
and hasattr(chunk.delta, "usage")
and chunk.delta.usage is not None
and hasattr(chunk.delta.usage, "billed_units")
and chunk.delta.usage.billed_units is not None
):
usage = chunk.delta.usage.billed_units
self.completion_tokens = int(usage.output_tokens)
self.prompt_tokens = int(usage.input_tokens)
Expand Down
2 changes: 1 addition & 1 deletion src/langtrace_python_sdk/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "3.4.0"
__version__ = "3.5.0"
Loading