Skip to content

Commit e0bc5e8

Browse files
authored
Merge pull request #432 from Scale3-Labs/obinna/S3EN-2920-update-cohere-instrumentation
Obinna/s3 en 2920 update cohere instrumentation
2 parents 1bc18e7 + 82a9bc5 commit e0bc5e8

File tree

10 files changed

+245
-14
lines changed

10 files changed

+245
-14
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
from examples.cohere_example.chat import chat_comp
2+
from examples.cohere_example.chatv2 import chat_v2
3+
from examples.cohere_example.chat_streamv2 import chat_stream_v2
24
from examples.cohere_example.chat_stream import chat_stream
35
from examples.cohere_example.tools import tool_calling
46
from examples.cohere_example.embed import embed
57
from examples.cohere_example.rerank import rerank
8+
from examples.cohere_example.rerankv2 import rerank_v2
69
from langtrace_python_sdk import with_langtrace_root_span
710

811

912
class CohereRunner:
    """Driver that executes the Cohere example functions in a fixed order."""

    @with_langtrace_root_span("Cohere")
    def run(self):
        """Call each imported Cohere example once, v2 chat examples first.

        The method is decorated with ``with_langtrace_root_span("Cohere")``.
        An exception raised by any example propagates and stops the run.
        """
        # Order matters only in that it mirrors the sequence of the calls
        # this loop replaces; each example is invoked with no arguments.
        examples = (
            chat_v2,
            chat_stream_v2,
            chat_comp,
            chat_stream,
            tool_calling,
            embed,
            rerank,
            rerank_v2,
        )
        for example in examples:
            example()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import os
2+
from langtrace_python_sdk import langtrace
3+
import cohere
4+
5+
langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))
6+
co = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
7+
8+
def chat_stream_v2():
    """Stream a chat completion from the module-level v2 client.

    Prints the text of every ``content-delta`` event, one ``print`` call
    per event; all other event types are ignored.
    """
    prompt = "Write a title for a blog post about API design. Only output the title text"
    stream = co.chat_stream(
        model="command-r-plus-08-2024",
        messages=[{"role": "user", "content": prompt}],
    )

    for chunk in stream:
        # Guard clause: skip falsy events and anything that is not a text delta.
        if not chunk or chunk.type != "content-delta":
            continue
        print(chunk.delta.message.content.text)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import os
2+
from langtrace_python_sdk import langtrace
3+
import cohere
4+
5+
langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))
6+
7+
8+
def chat_v2():
    """Send a single user message through ``ClientV2.chat`` and print the reply.

    A fresh client is built on every call from the ``COHERE_API_KEY``
    environment variable. Only the text of the first content item of the
    response message is printed.
    """
    client = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))

    user_message = {
        "role": "user",
        "content": "Write a title for a blog post about API design. Only output the title text.",
    }
    reply = client.chat(
        model="command-r-plus-08-2024",
        messages=[user_message],
    )

    first_item = reply.message.content[0]
    print(first_item.text)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import os
2+
from langtrace_python_sdk import langtrace
3+
import cohere
4+
5+
langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))
6+
co = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
7+
8+
docs = [
9+
"Carson City is the capital city of the American state of Nevada.",
10+
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
11+
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
12+
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
13+
"Capital punishment has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
14+
]
15+
16+
def rerank_v2():
    """Rerank the module-level ``docs`` against a fixed query and print the response.

    Uses the module-level v2 client ``co`` and asks for the top three results.
    """
    request = {
        "model": "rerank-v3.5",
        "query": "What is the capital of the United States?",
        "documents": docs,
        "top_n": 3,
    }
    # Keyword expansion yields the same call as spelling the arguments out.
    print(co.rerank(**request))

src/langtrace_python_sdk/constants/instrumentation/cohere.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,39 @@
44
"METHOD": "cohere.client.chat",
55
"ENDPOINT": "/v1/chat",
66
},
7+
"CHAT_CREATE_V2": {
8+
"URL": "https://api.cohere.ai",
9+
"METHOD": "cohere.client_v2.chat",
10+
"ENDPOINT": "/v2/chat",
11+
},
712
"EMBED": {
813
"URL": "https://api.cohere.ai",
914
"METHOD": "cohere.client.embed",
1015
"ENDPOINT": "/v1/embed",
1116
},
17+
"EMBED_V2": {
18+
"URL": "https://api.cohere.ai",
19+
"METHOD": "cohere.client_v2.embed",
20+
"ENDPOINT": "/v2/embed",
21+
},
1222
"CHAT_STREAM": {
1323
"URL": "https://api.cohere.ai",
1424
"METHOD": "cohere.client.chat_stream",
15-
"ENDPOINT": "/v1/messages",
25+
"ENDPOINT": "/v1/chat",
26+
},
27+
"CHAT_STREAM_V2": {
28+
"URL": "https://api.cohere.ai",
29+
"METHOD": "cohere.client_v2.chat_stream",
30+
"ENDPOINT": "/v2/chat",
1631
},
1732
"RERANK": {
1833
"URL": "https://api.cohere.ai",
1934
"METHOD": "cohere.client.rerank",
2035
"ENDPOINT": "/v1/rerank",
2136
},
37+
"RERANK_V2": {
38+
"URL": "https://api.cohere.ai",
39+
"METHOD": "cohere.client_v2.rerank",
40+
"ENDPOINT": "/v2/rerank",
41+
},
2242
}

src/langtrace_python_sdk/instrumentation/cohere/instrumentation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from langtrace_python_sdk.instrumentation.cohere.patch import (
2525
chat_create,
26+
chat_create_v2,
2627
chat_stream,
2728
embed,
2829
rerank,
@@ -48,6 +49,18 @@ def _instrument(self, **kwargs):
4849
chat_create("cohere.client.chat", version, tracer),
4950
)
5051

52+
wrap_function_wrapper(
53+
"cohere.client_v2",
54+
"ClientV2.chat",
55+
chat_create_v2("cohere.client_v2.chat", version, tracer),
56+
)
57+
58+
wrap_function_wrapper(
59+
"cohere.client_v2",
60+
"ClientV2.chat_stream",
61+
chat_create_v2("cohere.client_v2.chat", version, tracer, stream=True),
62+
)
63+
5164
wrap_function_wrapper(
5265
"cohere.client",
5366
"Client.chat_stream",
@@ -60,12 +73,24 @@ def _instrument(self, **kwargs):
6073
embed("cohere.client.embed", version, tracer),
6174
)
6275

76+
wrap_function_wrapper(
77+
"cohere.client_v2",
78+
"ClientV2.embed",
79+
embed("cohere.client.embed", version, tracer, v2=True),
80+
)
81+
6382
wrap_function_wrapper(
6483
"cohere.client",
6584
"Client.rerank",
6685
rerank("cohere.client.rerank", version, tracer),
6786
)
6887

88+
wrap_function_wrapper(
89+
"cohere.client_v2",
90+
"ClientV2.rerank",
91+
rerank("cohere.client.rerank", version, tracer, v2=True),
92+
)
93+
6994
def _instrument_module(self, module_name):
7095
pass
7196

src/langtrace_python_sdk/instrumentation/cohere/patch.py

Lines changed: 106 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
get_span_name,
2525
set_event_completion,
2626
set_usage_attributes,
27+
StreamWrapper
2728
)
2829
from langtrace.trace_attributes import Event, LLMSpanAttributes
2930
from langtrace_python_sdk.utils import set_span_attribute
@@ -38,7 +39,7 @@
3839
from langtrace.trace_attributes import SpanAttributes
3940

4041

41-
def rerank(original_method, version, tracer):
42+
def rerank(original_method, version, tracer, v2=False):
4243
"""Wrap the `rerank` method."""
4344

4445
def traced_method(wrapped, instance, args, kwargs):
@@ -49,8 +50,8 @@ def traced_method(wrapped, instance, args, kwargs):
4950
**get_llm_request_attributes(kwargs, operation_name="rerank"),
5051
**get_llm_url(instance),
5152
SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
52-
SpanAttributes.LLM_URL: APIS["RERANK"]["URL"],
53-
SpanAttributes.LLM_PATH: APIS["RERANK"]["ENDPOINT"],
53+
SpanAttributes.LLM_URL: APIS["RERANK" if not v2 else "RERANK_V2"]["URL"],
54+
SpanAttributes.LLM_PATH: APIS["RERANK" if not v2 else "RERANK_V2"]["ENDPOINT"],
5455
SpanAttributes.LLM_REQUEST_DOCUMENTS: json.dumps(
5556
kwargs.get("documents"), cls=datetime_encoder
5657
),
@@ -61,7 +62,7 @@ def traced_method(wrapped, instance, args, kwargs):
6162
attributes = LLMSpanAttributes(**span_attributes)
6263

6364
span = tracer.start_span(
64-
name=get_span_name(APIS["RERANK"]["METHOD"]), kind=SpanKind.CLIENT
65+
name=get_span_name(APIS["RERANK" if not v2 else "RERANK_V2"]["METHOD"]), kind=SpanKind.CLIENT
6566
)
6667
for field, value in attributes.model_dump(by_alias=True).items():
6768
set_span_attribute(span, field, value)
@@ -119,7 +120,7 @@ def traced_method(wrapped, instance, args, kwargs):
119120
return traced_method
120121

121122

122-
def embed(original_method, version, tracer):
123+
def embed(original_method, version, tracer, v2=False):
123124
"""Wrap the `embed` method."""
124125

125126
def traced_method(wrapped, instance, args, kwargs):
@@ -129,8 +130,8 @@ def traced_method(wrapped, instance, args, kwargs):
129130
**get_langtrace_attributes(version, service_provider),
130131
**get_llm_request_attributes(kwargs, operation_name="embed"),
131132
**get_llm_url(instance),
132-
SpanAttributes.LLM_URL: APIS["EMBED"]["URL"],
133-
SpanAttributes.LLM_PATH: APIS["EMBED"]["ENDPOINT"],
133+
SpanAttributes.LLM_URL: APIS["EMBED" if not v2 else "EMBED_V2"]["URL"],
134+
SpanAttributes.LLM_PATH: APIS["EMBED" if not v2 else "EMBED_V2"]["ENDPOINT"],
134135
SpanAttributes.LLM_REQUEST_EMBEDDING_INPUTS: json.dumps(
135136
kwargs.get("texts")
136137
),
@@ -143,7 +144,7 @@ def traced_method(wrapped, instance, args, kwargs):
143144
attributes = LLMSpanAttributes(**span_attributes)
144145

145146
span = tracer.start_span(
146-
name=get_span_name(APIS["EMBED"]["METHOD"]),
147+
name=get_span_name(APIS["EMBED" if not v2 else "EMBED_V2"]["METHOD"]),
147148
kind=SpanKind.CLIENT,
148149
)
149150
for field, value in attributes.model_dump(by_alias=True).items():
@@ -343,6 +344,103 @@ def traced_method(wrapped, instance, args, kwargs):
343344
return traced_method
344345

345346

347+
def _record_chat_v2_response(span, result):
    """Copy id, completion text, tool calls and token usage from `result` onto `span`."""
    response_id = getattr(result, "id", None)
    if response_id is not None:
        span.set_attribute(SpanAttributes.LLM_GENERATION_ID, response_id)
        span.set_attribute(SpanAttributes.LLM_RESPONSE_ID, response_id)

    message = getattr(result, "message", None)

    # `content` can be None or empty (presumably for tool-call-only replies --
    # confirm against the Cohere SDK). Testing truthiness instead of calling
    # len() keeps a None from raising TypeError and failing the user's call.
    content = getattr(message, "content", None)
    if content:
        text = getattr(content[0], "text", None)
        if text:
            set_event_completion(span, [{"role": message.role, "content": text}])

    # Look for tool calls on the response message first; fall back to a
    # top-level `tool_calls` attribute so any response shape that exposes
    # it there is still recorded.
    tool_calls = getattr(message, "tool_calls", None)
    if tool_calls is None:
        tool_calls = getattr(result, "tool_calls", None)
    if tool_calls is not None:
        span.set_attribute(
            SpanAttributes.LLM_TOOL_RESULTS,
            json.dumps([tool_call.json() for tool_call in tool_calls]),
        )

    billed_units = getattr(getattr(result, "usage", None), "billed_units", None)
    if billed_units is not None:
        # Missing counts are treated as zero so the total is always defined.
        input_tokens = billed_units.input_tokens or 0
        output_tokens = billed_units.output_tokens or 0
        token_counts = {
            "input": input_tokens,
            "output": output_tokens,
            "total": input_tokens + output_tokens,
        }
        for metric, value in token_counts.items():
            span.set_attribute(f"gen_ai.usage.{metric}_tokens", int(value))


def chat_create_v2(original_method, version, tracer, stream=False):
    """Build the tracing wrapper for Cohere API v2 `chat` / `chat_stream`.

    Args:
        original_method: Dotted name of the wrapped method. Accepted for
            signature parity with the other wrappers; not read here.
        version: Forwarded to `get_langtrace_attributes`.
        tracer: Tracer used to start one CLIENT span per call.
        stream: When True, the wrapped call's result is returned inside a
            `StreamWrapper` that receives the open span; otherwise response
            attributes are recorded and the span is ended before returning.

    Returns:
        A `traced_method(wrapped, instance, args, kwargs)` callable suitable
        for `wrap_function_wrapper`.
    """
    # Streaming calls use their own registry entry so the span name reflects
    # the method that was actually invoked.
    api = APIS["CHAT_STREAM_V2" if stream else "CHAT_CREATE_V2"]

    def traced_method(wrapped, instance, args, kwargs):
        service_provider = SERVICE_PROVIDERS["COHERE"]

        messages = kwargs.get("messages", [])
        if kwargs.get("preamble"):
            # Record the preamble as a leading system prompt.
            messages = [{"role": "system", "content": kwargs["preamble"]}] + messages

        span_attributes = {
            **get_langtrace_attributes(version, service_provider),
            **get_llm_request_attributes(kwargs, prompts=messages),
            **get_llm_url(instance),
            SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
            SpanAttributes.LLM_URL: api["URL"],
            SpanAttributes.LLM_PATH: api["ENDPOINT"],
            **get_extra_attributes(),
        }

        attributes = LLMSpanAttributes(**span_attributes)

        max_input_tokens = kwargs.get("max_input_tokens")
        if max_input_tokens is not None:
            attributes.llm_max_input_tokens = str(max_input_tokens)

        conversation_id = kwargs.get("conversation_id")
        if conversation_id is not None:
            attributes.conversation_id = conversation_id

        # Structured request options are stored as JSON strings.
        for attr_name in ("connectors", "tools", "tool_results"):
            value = kwargs.get(attr_name)
            if value is not None:
                setattr(attributes, f"llm_{attr_name}", json.dumps(value))

        span = tracer.start_span(
            name=get_span_name(api["METHOD"]),
            kind=SpanKind.CLIENT,
        )

        for field, value in attributes.model_dump(by_alias=True).items():
            set_span_attribute(span, field, value)

        try:
            result = wrapped(*args, **kwargs)

            if stream:
                # The span is handed to the wrapper and is not ended here.
                return StreamWrapper(
                    result,
                    span,
                    tool_calls=kwargs.get("tools") is not None,
                )

            _record_chat_v2_response(span, result)
            span.set_status(StatusCode.OK)
            span.end()
            return result

        except Exception as error:
            span.record_exception(error)
            span.set_status(Status(StatusCode.ERROR, str(error)))
            span.end()
            raise

    return traced_method
442+
443+
346444
def chat_stream(original_method, version, tracer):
347445
"""Wrap the `messages_stream` method."""
348446

src/langtrace_python_sdk/utils/llm.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,8 +394,19 @@ def build_streaming_response(self, chunk):
394394
if hasattr(chunk, "text") and chunk.text is not None:
395395
content = [chunk.text]
396396

397+
# CohereV2
398+
if (hasattr(chunk, "delta") and
399+
chunk.delta is not None and
400+
hasattr(chunk.delta, "message") and
401+
chunk.delta.message is not None and
402+
hasattr(chunk.delta.message, "content") and
403+
chunk.delta.message.content is not None and
404+
hasattr(chunk.delta.message.content, "text") and
405+
chunk.delta.message.content.text is not None):
406+
content = [chunk.delta.message.content.text]
407+
397408
# Anthropic
398-
if hasattr(chunk, "delta") and chunk.delta is not None:
409+
if hasattr(chunk, "delta") and chunk.delta is not None and not hasattr(chunk.delta, "message"):
399410
content = [chunk.delta.text] if hasattr(chunk.delta, "text") else []
400411

401412
if isinstance(chunk, dict):
@@ -409,7 +420,17 @@ def set_usage_attributes(self, chunk):
409420

410421
# Anthropic & OpenAI
411422
if hasattr(chunk, "type") and chunk.type == "message_start":
412-
self.prompt_tokens = chunk.message.usage.input_tokens
423+
if hasattr(chunk.message, "usage") and chunk.message.usage is not None:
424+
self.prompt_tokens = chunk.message.usage.input_tokens
425+
426+
# CohereV2
427+
if hasattr(chunk, "type") and chunk.type == "message-end":
428+
if (hasattr(chunk, "delta") and chunk.delta is not None and
429+
hasattr(chunk.delta, "usage") and chunk.delta.usage is not None and
430+
hasattr(chunk.delta.usage, "billed_units") and chunk.delta.usage.billed_units is not None):
431+
usage = chunk.delta.usage.billed_units
432+
self.completion_tokens = int(usage.output_tokens)
433+
self.prompt_tokens = int(usage.input_tokens)
413434

414435
if hasattr(chunk, "usage") and chunk.usage is not None:
415436
if hasattr(chunk.usage, "output_tokens"):
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.3.14"
1+
__version__ = "3.3.15"

0 commit comments

Comments
 (0)