Skip to content

Commit 45ae162

Browse files
committed
add support for cohere client v2
1 parent 8959561 commit 45ae162

File tree

4 files changed

+175
-11
lines changed

4 files changed

+175
-11
lines changed

src/langtrace_python_sdk/constants/instrumentation/cohere.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,39 @@
44
"METHOD": "cohere.client.chat",
55
"ENDPOINT": "/v1/chat",
66
},
7+
"CHAT_CREATE_V2": {
8+
"URL": "https://api.cohere.ai",
9+
"METHOD": "cohere.client_v2.chat",
10+
"ENDPOINT": "/v2/chat",
11+
},
712
"EMBED": {
813
"URL": "https://api.cohere.ai",
914
"METHOD": "cohere.client.embed",
1015
"ENDPOINT": "/v1/embed",
1116
},
17+
"EMBED_V2": {
18+
"URL": "https://api.cohere.ai",
19+
"METHOD": "cohere.client_v2.embed",
20+
"ENDPOINT": "/v2/embed",
21+
},
1222
"CHAT_STREAM": {
1323
"URL": "https://api.cohere.ai",
1424
"METHOD": "cohere.client.chat_stream",
15-
"ENDPOINT": "/v1/messages",
25+
"ENDPOINT": "/v1/chat",
26+
},
27+
"CHAT_STREAM_V2": {
28+
"URL": "https://api.cohere.ai",
29+
"METHOD": "cohere.client_v2.chat_stream",
30+
"ENDPOINT": "/v2/chat",
1631
},
1732
"RERANK": {
1833
"URL": "https://api.cohere.ai",
1934
"METHOD": "cohere.client.rerank",
2035
"ENDPOINT": "/v1/rerank",
2136
},
37+
"RERANK_V2": {
38+
"URL": "https://api.cohere.ai",
39+
"METHOD": "cohere.client_v2.rerank",
40+
"ENDPOINT": "/v2/rerank",
41+
},
2242
}

src/langtrace_python_sdk/instrumentation/cohere/instrumentation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from langtrace_python_sdk.instrumentation.cohere.patch import (
2525
chat_create,
26+
chat_create_v2,
2627
chat_stream,
2728
embed,
2829
rerank,
@@ -48,6 +49,18 @@ def _instrument(self, **kwargs):
4849
chat_create("cohere.client.chat", version, tracer),
4950
)
5051

52+
wrap_function_wrapper(
53+
"cohere.client_v2",
54+
"ClientV2.chat",
55+
chat_create_v2("cohere.client_v2.chat", version, tracer),
56+
)
57+
58+
wrap_function_wrapper(
59+
"cohere.client_v2",
60+
"ClientV2.chat_stream",
61+
chat_create_v2("cohere.client_v2.chat", version, tracer, stream=True),
62+
)
63+
5164
wrap_function_wrapper(
5265
"cohere.client",
5366
"Client.chat_stream",
@@ -60,12 +73,24 @@ def _instrument(self, **kwargs):
6073
embed("cohere.client.embed", version, tracer),
6174
)
6275

76+
wrap_function_wrapper(
77+
"cohere.client_v2",
78+
"ClientV2.embed",
79+
embed("cohere.client.embed", version, tracer, v2=True),
80+
)
81+
6382
wrap_function_wrapper(
6483
"cohere.client",
6584
"Client.rerank",
6685
rerank("cohere.client.rerank", version, tracer),
6786
)
6887

88+
wrap_function_wrapper(
89+
"cohere.client_v2",
90+
"ClientV2.rerank",
91+
rerank("cohere.client.rerank", version, tracer, v2=True),
92+
)
93+
6994
def _instrument_module(self, module_name):
7095
pass
7196

src/langtrace_python_sdk/instrumentation/cohere/patch.py

Lines changed: 106 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
get_span_name,
2525
set_event_completion,
2626
set_usage_attributes,
27+
StreamWrapper
2728
)
2829
from langtrace.trace_attributes import Event, LLMSpanAttributes
2930
from langtrace_python_sdk.utils import set_span_attribute
@@ -38,7 +39,7 @@
3839
from langtrace.trace_attributes import SpanAttributes
3940

4041

41-
def rerank(original_method, version, tracer):
42+
def rerank(original_method, version, tracer, v2=False):
4243
"""Wrap the `rerank` method."""
4344

4445
def traced_method(wrapped, instance, args, kwargs):
@@ -49,8 +50,8 @@ def traced_method(wrapped, instance, args, kwargs):
4950
**get_llm_request_attributes(kwargs, operation_name="rerank"),
5051
**get_llm_url(instance),
5152
SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
52-
SpanAttributes.LLM_URL: APIS["RERANK"]["URL"],
53-
SpanAttributes.LLM_PATH: APIS["RERANK"]["ENDPOINT"],
53+
SpanAttributes.LLM_URL: APIS["RERANK" if not v2 else "RERANK_V2"]["URL"],
54+
SpanAttributes.LLM_PATH: APIS["RERANK" if not v2 else "RERANK_V2"]["ENDPOINT"],
5455
SpanAttributes.LLM_REQUEST_DOCUMENTS: json.dumps(
5556
kwargs.get("documents"), cls=datetime_encoder
5657
),
@@ -61,7 +62,7 @@ def traced_method(wrapped, instance, args, kwargs):
6162
attributes = LLMSpanAttributes(**span_attributes)
6263

6364
span = tracer.start_span(
64-
name=get_span_name(APIS["RERANK"]["METHOD"]), kind=SpanKind.CLIENT
65+
name=get_span_name(APIS["RERANK" if not v2 else "RERANK_V2"]["METHOD"]), kind=SpanKind.CLIENT
6566
)
6667
for field, value in attributes.model_dump(by_alias=True).items():
6768
set_span_attribute(span, field, value)
@@ -119,7 +120,7 @@ def traced_method(wrapped, instance, args, kwargs):
119120
return traced_method
120121

121122

122-
def embed(original_method, version, tracer):
123+
def embed(original_method, version, tracer, v2=False):
123124
"""Wrap the `embed` method."""
124125

125126
def traced_method(wrapped, instance, args, kwargs):
@@ -129,8 +130,8 @@ def traced_method(wrapped, instance, args, kwargs):
129130
**get_langtrace_attributes(version, service_provider),
130131
**get_llm_request_attributes(kwargs, operation_name="embed"),
131132
**get_llm_url(instance),
132-
SpanAttributes.LLM_URL: APIS["EMBED"]["URL"],
133-
SpanAttributes.LLM_PATH: APIS["EMBED"]["ENDPOINT"],
133+
SpanAttributes.LLM_URL: APIS["EMBED" if not v2 else "EMBED_V2"]["URL"],
134+
SpanAttributes.LLM_PATH: APIS["EMBED" if not v2 else "EMBED_V2"]["ENDPOINT"],
134135
SpanAttributes.LLM_REQUEST_EMBEDDING_INPUTS: json.dumps(
135136
kwargs.get("texts")
136137
),
@@ -143,7 +144,7 @@ def traced_method(wrapped, instance, args, kwargs):
143144
attributes = LLMSpanAttributes(**span_attributes)
144145

145146
span = tracer.start_span(
146-
name=get_span_name(APIS["EMBED"]["METHOD"]),
147+
name=get_span_name(APIS["EMBED" if not v2 else "EMBED_V2"]["METHOD"]),
147148
kind=SpanKind.CLIENT,
148149
)
149150
for field, value in attributes.model_dump(by_alias=True).items():
@@ -343,6 +344,103 @@ def traced_method(wrapped, instance, args, kwargs):
343344
return traced_method
344345

345346

def chat_create_v2(original_method, version, tracer, stream=False):
    """Wrap Cohere ClientV2 chat / chat_stream calls in a tracing span."""

    def traced_method(wrapped, instance, args, kwargs):
        provider = SERVICE_PROVIDERS["COHERE"]

        # Fold an optional preamble into the message list as a system turn.
        prompt_messages = kwargs.get("messages", [])
        preamble = kwargs.get("preamble")
        if preamble:
            prompt_messages = [{"role": "system", "content": kwargs["preamble"]}] + prompt_messages

        attributes = LLMSpanAttributes(
            **{
                **get_langtrace_attributes(version, provider),
                **get_llm_request_attributes(kwargs, prompts=prompt_messages),
                **get_llm_url(instance),
                SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
                SpanAttributes.LLM_URL: APIS["CHAT_CREATE_V2"]["URL"],
                SpanAttributes.LLM_PATH: APIS["CHAT_CREATE_V2"]["ENDPOINT"],
                **get_extra_attributes(),
            }
        )

        # Copy selected optional request kwargs onto the span attributes.
        for key in ("max_input_tokens", "conversation_id", "connectors", "tools", "tool_results"):
            val = kwargs.get(key)
            if val is None:
                continue
            if key == "max_input_tokens":
                attributes.llm_max_input_tokens = str(val)
            elif key == "conversation_id":
                attributes.conversation_id = val
            else:
                setattr(attributes, f"llm_{key}", json.dumps(val))

        span = tracer.start_span(
            name=get_span_name(APIS["CHAT_CREATE_V2"]["METHOD"]),
            kind=SpanKind.CLIENT,
        )
        for attr_key, attr_val in attributes.model_dump(by_alias=True).items():
            set_span_attribute(span, attr_key, attr_val)

        try:
            result = wrapped(*args, **kwargs)

            if stream:
                # Streaming: hand the span off; the wrapper ends it when the
                # stream is exhausted.
                return StreamWrapper(
                    result,
                    span,
                    tool_calls=kwargs.get("tools") is not None,
                )

            # Non-streaming: record response id, completion, tool calls, usage.
            response_id = getattr(result, "id", None)
            if response_id is not None:
                span.set_attribute(SpanAttributes.LLM_GENERATION_ID, response_id)
                span.set_attribute(SpanAttributes.LLM_RESPONSE_ID, response_id)

            message = getattr(result, "message", None)
            content = getattr(message, "content", None) if message is not None else None
            if (
                content is not None
                and len(content) > 0
                and hasattr(content[0], "text")
                and content[0].text is not None
                and content[0].text != ""
            ):
                set_event_completion(
                    span,
                    [{"role": message.role, "content": content[0].text}],
                )

            calls = getattr(result, "tool_calls", None)
            if calls is not None:
                span.set_attribute(
                    SpanAttributes.LLM_TOOL_RESULTS,
                    json.dumps([call.json() for call in calls]),
                )

            usage = getattr(result, "usage", None)
            billed = getattr(usage, "billed_units", None) if usage is not None else None
            if billed is not None:
                input_tokens = billed.input_tokens or 0
                output_tokens = billed.output_tokens or 0
                span.set_attribute("gen_ai.usage.input_tokens", int(input_tokens))
                span.set_attribute("gen_ai.usage.output_tokens", int(output_tokens))
                span.set_attribute("gen_ai.usage.total_tokens", int(input_tokens + output_tokens))

            span.set_status(StatusCode.OK)
            span.end()
            return result

        except Exception as error:
            span.record_exception(error)
            span.set_status(Status(StatusCode.ERROR, str(error)))
            span.end()
            raise

    return traced_method
443+
346444
def chat_stream(original_method, version, tracer):
347445
"""Wrap the `messages_stream` method."""
348446

src/langtrace_python_sdk/utils/llm.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,8 +393,19 @@ def build_streaming_response(self, chunk):
393393
if hasattr(chunk, "text") and chunk.text is not None:
394394
content = [chunk.text]
395395

396+
# CohereV2
397+
if (hasattr(chunk, "delta") and
398+
chunk.delta is not None and
399+
hasattr(chunk.delta, "message") and
400+
chunk.delta.message is not None and
401+
hasattr(chunk.delta.message, "content") and
402+
chunk.delta.message.content is not None and
403+
hasattr(chunk.delta.message.content, "text") and
404+
chunk.delta.message.content.text is not None):
405+
content = [chunk.delta.message.content.text]
406+
396407
# Anthropic
397-
if hasattr(chunk, "delta") and chunk.delta is not None:
408+
if hasattr(chunk, "delta") and chunk.delta is not None and not hasattr(chunk.delta, "message"):
398409
content = [chunk.delta.text] if hasattr(chunk.delta, "text") else []
399410

400411
if isinstance(chunk, dict):
@@ -408,7 +419,17 @@ def set_usage_attributes(self, chunk):
408419

409420
# Anthropic & OpenAI
410421
if hasattr(chunk, "type") and chunk.type == "message_start":
411-
self.prompt_tokens = chunk.message.usage.input_tokens
422+
if hasattr(chunk.message, "usage") and chunk.message.usage is not None:
423+
self.prompt_tokens = chunk.message.usage.input_tokens
424+
425+
# CohereV2
426+
if hasattr(chunk, "type") and chunk.type == "message-end":
427+
if (hasattr(chunk, "delta") and chunk.delta is not None and
428+
hasattr(chunk.delta, "usage") and chunk.delta.usage is not None and
429+
hasattr(chunk.delta.usage, "billed_units") and chunk.delta.usage.billed_units is not None):
430+
usage = chunk.delta.usage.billed_units
431+
self.completion_tokens = int(usage.output_tokens)
432+
self.prompt_tokens = int(usage.input_tokens)
412433

413434
if hasattr(chunk, "usage") and chunk.usage is not None:
414435
if hasattr(chunk.usage, "output_tokens"):

0 commit comments

Comments
 (0)