Skip to content

Commit f217b66

Browse files
committed
Merge branch 'development' of github.com:Scale3-Labs/langtrace-python-sdk into development
2 parents 8249964 + 883d9d7 commit f217b66

File tree

19 files changed

+681
-27
lines changed

19 files changed

+681
-27
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,23 @@
11
from examples.cohere_example.chat import chat_comp
2+
from examples.cohere_example.chatv2 import chat_v2
3+
from examples.cohere_example.chat_streamv2 import chat_stream_v2
24
from examples.cohere_example.chat_stream import chat_stream
35
from examples.cohere_example.tools import tool_calling
46
from examples.cohere_example.embed import embed
57
from examples.cohere_example.rerank import rerank
8+
from examples.cohere_example.rerankv2 import rerank_v2
69
from langtrace_python_sdk import with_langtrace_root_span
710

811

912
class CohereRunner:
1013

1114
@with_langtrace_root_span("Cohere")
1215
def run(self):
16+
chat_v2()
17+
chat_stream_v2()
1318
chat_comp()
1419
chat_stream()
1520
tool_calling()
1621
embed()
1722
rerank()
23+
rerank_v2()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import os
2+
from langtrace_python_sdk import langtrace
3+
import cohere
4+
5+
langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))
6+
co = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
7+
8+
def chat_stream_v2():
9+
res = co.chat_stream(
10+
model="command-r-plus-08-2024",
11+
messages=[{"role": "user", "content": "Write a title for a blog post about API design. Only output the title text"}],
12+
)
13+
14+
for event in res:
15+
if event:
16+
if event.type == "content-delta":
17+
print(event.delta.message.content.text)
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import os
2+
from langtrace_python_sdk import langtrace
3+
import cohere
4+
5+
langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))
6+
7+
8+
def chat_v2():
9+
co = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
10+
11+
res = co.chat(
12+
model="command-r-plus-08-2024",
13+
messages=[
14+
{
15+
"role": "user",
16+
"content": "Write a title for a blog post about API design. Only output the title text.",
17+
}
18+
],
19+
)
20+
21+
print(res.message.content[0].text)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import os
2+
from langtrace_python_sdk import langtrace
3+
import cohere
4+
5+
langtrace.init(api_key=os.getenv("LANGTRACE_API_KEY"))
6+
co = cohere.ClientV2(api_key=os.getenv("COHERE_API_KEY"))
7+
8+
docs = [
9+
"Carson City is the capital city of the American state of Nevada.",
10+
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean. Its capital is Saipan.",
11+
"Capitalization or capitalisation in English grammar is the use of a capital letter at the start of a word. English usage varies from capitalization in other languages.",
12+
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district.",
13+
"Capital punishment has existed in the United States since beforethe United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states.",
14+
]
15+
16+
def rerank_v2():
17+
response = co.rerank(
18+
model="rerank-v3.5",
19+
query="What is the capital of the United States?",
20+
documents=docs,
21+
top_n=3,
22+
)
23+
print(response)
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
LANGTRACE_REMOTE_URL = "https://app.langtrace.ai"
1+
LANGTRACE_REMOTE_URL = "https://app.langtrace.ai"
2+
LANGTRACE_SESSION_ID_HEADER = "x-langtrace-session-id"

src/langtrace_python_sdk/constants/instrumentation/cohere.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,39 @@
44
"METHOD": "cohere.client.chat",
55
"ENDPOINT": "/v1/chat",
66
},
7+
"CHAT_CREATE_V2": {
8+
"URL": "https://api.cohere.ai",
9+
"METHOD": "cohere.client_v2.chat",
10+
"ENDPOINT": "/v2/chat",
11+
},
712
"EMBED": {
813
"URL": "https://api.cohere.ai",
914
"METHOD": "cohere.client.embed",
1015
"ENDPOINT": "/v1/embed",
1116
},
17+
"EMBED_V2": {
18+
"URL": "https://api.cohere.ai",
19+
"METHOD": "cohere.client_v2.embed",
20+
"ENDPOINT": "/v2/embed",
21+
},
1222
"CHAT_STREAM": {
1323
"URL": "https://api.cohere.ai",
1424
"METHOD": "cohere.client.chat_stream",
15-
"ENDPOINT": "/v1/messages",
25+
"ENDPOINT": "/v1/chat",
26+
},
27+
"CHAT_STREAM_V2": {
28+
"URL": "https://api.cohere.ai",
29+
"METHOD": "cohere.client_v2.chat_stream",
30+
"ENDPOINT": "/v2/chat",
1631
},
1732
"RERANK": {
1833
"URL": "https://api.cohere.ai",
1934
"METHOD": "cohere.client.rerank",
2035
"ENDPOINT": "/v1/rerank",
2136
},
37+
"RERANK_V2": {
38+
"URL": "https://api.cohere.ai",
39+
"METHOD": "cohere.client_v2.rerank",
40+
"ENDPOINT": "/v2/rerank",
41+
},
2242
}

src/langtrace_python_sdk/extensions/langtrace_exporter.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from langtrace_python_sdk.constants.exporter.langtrace_exporter import (
1111
LANGTRACE_REMOTE_URL,
12+
LANGTRACE_SESSION_ID_HEADER,
1213
)
1314
from colorama import Fore
1415
from requests.exceptions import RequestException
@@ -51,12 +52,14 @@ class LangTraceExporter(SpanExporter):
5152
api_key: str
5253
api_host: str
5354
disable_logging: bool
55+
session_id: str
5456

5557
def __init__(
5658
self,
5759
api_host,
5860
api_key: str = None,
5961
disable_logging: bool = False,
62+
session_id: str = None,
6063
) -> None:
6164
self.api_key = api_key or os.environ.get("LANGTRACE_API_KEY")
6265
self.api_host = (
@@ -65,6 +68,7 @@ def __init__(
6568
else api_host
6669
)
6770
self.disable_logging = disable_logging
71+
self.session_id = session_id or os.environ.get("LANGTRACE_SESSION_ID")
6872

6973
def export(self, spans: typing.Sequence[ReadableSpan]) -> SpanExportResult:
7074
"""
@@ -82,6 +86,10 @@ def export(self, spans: typing.Sequence[ReadableSpan]) -> SpanExportResult:
8286
"User-Agent": "LangtraceExporter",
8387
}
8488

89+
# Add session ID if available
90+
if self.session_id:
91+
headers[LANGTRACE_SESSION_ID_HEADER] = self.session_id
92+
8593
# Check if the OTEL_EXPORTER_OTLP_HEADERS environment variable is set
8694
otel_headers = os.getenv("OTEL_EXPORTER_OTLP_HEADERS", None)
8795
if otel_headers:

src/langtrace_python_sdk/instrumentation/cohere/instrumentation.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
from langtrace_python_sdk.instrumentation.cohere.patch import (
2525
chat_create,
26+
chat_create_v2,
2627
chat_stream,
2728
embed,
2829
rerank,
@@ -48,6 +49,18 @@ def _instrument(self, **kwargs):
4849
chat_create("cohere.client.chat", version, tracer),
4950
)
5051

52+
wrap_function_wrapper(
53+
"cohere.client_v2",
54+
"ClientV2.chat",
55+
chat_create_v2("cohere.client_v2.chat", version, tracer),
56+
)
57+
58+
wrap_function_wrapper(
59+
"cohere.client_v2",
60+
"ClientV2.chat_stream",
61+
chat_create_v2("cohere.client_v2.chat", version, tracer, stream=True),
62+
)
63+
5164
wrap_function_wrapper(
5265
"cohere.client",
5366
"Client.chat_stream",
@@ -60,12 +73,24 @@ def _instrument(self, **kwargs):
6073
embed("cohere.client.embed", version, tracer),
6174
)
6275

76+
wrap_function_wrapper(
77+
"cohere.client_v2",
78+
"ClientV2.embed",
79+
embed("cohere.client.embed", version, tracer, v2=True),
80+
)
81+
6382
wrap_function_wrapper(
6483
"cohere.client",
6584
"Client.rerank",
6685
rerank("cohere.client.rerank", version, tracer),
6786
)
6887

88+
wrap_function_wrapper(
89+
"cohere.client_v2",
90+
"ClientV2.rerank",
91+
rerank("cohere.client.rerank", version, tracer, v2=True),
92+
)
93+
6994
def _instrument_module(self, module_name):
7095
pass
7196

src/langtrace_python_sdk/instrumentation/cohere/patch.py

Lines changed: 106 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
get_span_name,
2525
set_event_completion,
2626
set_usage_attributes,
27+
StreamWrapper
2728
)
2829
from langtrace.trace_attributes import Event, LLMSpanAttributes
2930
from langtrace_python_sdk.utils import set_span_attribute
@@ -38,7 +39,7 @@
3839
from langtrace.trace_attributes import SpanAttributes
3940

4041

41-
def rerank(original_method, version, tracer):
42+
def rerank(original_method, version, tracer, v2=False):
4243
"""Wrap the `rerank` method."""
4344

4445
def traced_method(wrapped, instance, args, kwargs):
@@ -49,8 +50,8 @@ def traced_method(wrapped, instance, args, kwargs):
4950
**get_llm_request_attributes(kwargs, operation_name="rerank"),
5051
**get_llm_url(instance),
5152
SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
52-
SpanAttributes.LLM_URL: APIS["RERANK"]["URL"],
53-
SpanAttributes.LLM_PATH: APIS["RERANK"]["ENDPOINT"],
53+
SpanAttributes.LLM_URL: APIS["RERANK" if not v2 else "RERANK_V2"]["URL"],
54+
SpanAttributes.LLM_PATH: APIS["RERANK" if not v2 else "RERANK_V2"]["ENDPOINT"],
5455
SpanAttributes.LLM_REQUEST_DOCUMENTS: json.dumps(
5556
kwargs.get("documents"), cls=datetime_encoder
5657
),
@@ -61,7 +62,7 @@ def traced_method(wrapped, instance, args, kwargs):
6162
attributes = LLMSpanAttributes(**span_attributes)
6263

6364
span = tracer.start_span(
64-
name=get_span_name(APIS["RERANK"]["METHOD"]), kind=SpanKind.CLIENT
65+
name=get_span_name(APIS["RERANK" if not v2 else "RERANK_V2"]["METHOD"]), kind=SpanKind.CLIENT
6566
)
6667
for field, value in attributes.model_dump(by_alias=True).items():
6768
set_span_attribute(span, field, value)
@@ -119,7 +120,7 @@ def traced_method(wrapped, instance, args, kwargs):
119120
return traced_method
120121

121122

122-
def embed(original_method, version, tracer):
123+
def embed(original_method, version, tracer, v2=False):
123124
"""Wrap the `embed` method."""
124125

125126
def traced_method(wrapped, instance, args, kwargs):
@@ -129,8 +130,8 @@ def traced_method(wrapped, instance, args, kwargs):
129130
**get_langtrace_attributes(version, service_provider),
130131
**get_llm_request_attributes(kwargs, operation_name="embed"),
131132
**get_llm_url(instance),
132-
SpanAttributes.LLM_URL: APIS["EMBED"]["URL"],
133-
SpanAttributes.LLM_PATH: APIS["EMBED"]["ENDPOINT"],
133+
SpanAttributes.LLM_URL: APIS["EMBED" if not v2 else "EMBED_V2"]["URL"],
134+
SpanAttributes.LLM_PATH: APIS["EMBED" if not v2 else "EMBED_V2"]["ENDPOINT"],
134135
SpanAttributes.LLM_REQUEST_EMBEDDING_INPUTS: json.dumps(
135136
kwargs.get("texts")
136137
),
@@ -143,7 +144,7 @@ def traced_method(wrapped, instance, args, kwargs):
143144
attributes = LLMSpanAttributes(**span_attributes)
144145

145146
span = tracer.start_span(
146-
name=get_span_name(APIS["EMBED"]["METHOD"]),
147+
name=get_span_name(APIS["EMBED" if not v2 else "EMBED_V2"]["METHOD"]),
147148
kind=SpanKind.CLIENT,
148149
)
149150
for field, value in attributes.model_dump(by_alias=True).items():
@@ -343,6 +344,103 @@ def traced_method(wrapped, instance, args, kwargs):
343344
return traced_method
344345

345346

347+
def chat_create_v2(original_method, version, tracer, stream=False):
348+
"""Wrap the `chat_create` method for Cohere API v2."""
349+
350+
def traced_method(wrapped, instance, args, kwargs):
351+
service_provider = SERVICE_PROVIDERS["COHERE"]
352+
353+
messages = kwargs.get("messages", [])
354+
if kwargs.get("preamble"):
355+
messages = [{"role": "system", "content": kwargs["preamble"]}] + messages
356+
357+
span_attributes = {
358+
**get_langtrace_attributes(version, service_provider),
359+
**get_llm_request_attributes(kwargs, prompts=messages),
360+
**get_llm_url(instance),
361+
SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("model") or "command-r-plus",
362+
SpanAttributes.LLM_URL: APIS["CHAT_CREATE_V2"]["URL"],
363+
SpanAttributes.LLM_PATH: APIS["CHAT_CREATE_V2"]["ENDPOINT"],
364+
**get_extra_attributes(),
365+
}
366+
367+
attributes = LLMSpanAttributes(**span_attributes)
368+
369+
for attr_name in ["max_input_tokens", "conversation_id", "connectors", "tools", "tool_results"]:
370+
value = kwargs.get(attr_name)
371+
if value is not None:
372+
if attr_name == "max_input_tokens":
373+
attributes.llm_max_input_tokens = str(value)
374+
elif attr_name == "conversation_id":
375+
attributes.conversation_id = value
376+
else:
377+
setattr(attributes, f"llm_{attr_name}", json.dumps(value))
378+
379+
span = tracer.start_span(
380+
name=get_span_name(APIS["CHAT_CREATE_V2"]["METHOD"]),
381+
kind=SpanKind.CLIENT
382+
)
383+
384+
for field, value in attributes.model_dump(by_alias=True).items():
385+
set_span_attribute(span, field, value)
386+
387+
try:
388+
result = wrapped(*args, **kwargs)
389+
390+
if stream:
391+
return StreamWrapper(
392+
result,
393+
span,
394+
tool_calls=kwargs.get("tools") is not None,
395+
)
396+
else:
397+
if hasattr(result, "id") and result.id is not None:
398+
span.set_attribute(SpanAttributes.LLM_GENERATION_ID, result.id)
399+
span.set_attribute(SpanAttributes.LLM_RESPONSE_ID, result.id)
400+
401+
if (hasattr(result, "message") and
402+
hasattr(result.message, "content") and
403+
len(result.message.content) > 0 and
404+
hasattr(result.message.content[0], "text") and
405+
result.message.content[0].text is not None and
406+
result.message.content[0].text != ""):
407+
responses = [{
408+
"role": result.message.role,
409+
"content": result.message.content[0].text
410+
}]
411+
set_event_completion(span, responses)
412+
if hasattr(result, "tool_calls") and result.tool_calls is not None:
413+
tool_calls = [tool_call.json() for tool_call in result.tool_calls]
414+
span.set_attribute(
415+
SpanAttributes.LLM_TOOL_RESULTS,
416+
json.dumps(tool_calls)
417+
)
418+
if hasattr(result, "usage") and result.usage is not None:
419+
if (hasattr(result.usage, "billed_units") and
420+
result.usage.billed_units is not None):
421+
usage = result.usage.billed_units
422+
for metric, value in {
423+
"input": usage.input_tokens or 0,
424+
"output": usage.output_tokens or 0,
425+
"total": (usage.input_tokens or 0) + (usage.output_tokens or 0),
426+
}.items():
427+
span.set_attribute(
428+
f"gen_ai.usage.{metric}_tokens",
429+
int(value)
430+
)
431+
span.set_status(StatusCode.OK)
432+
span.end()
433+
return result
434+
435+
except Exception as error:
436+
span.record_exception(error)
437+
span.set_status(Status(StatusCode.ERROR, str(error)))
438+
span.end()
439+
raise
440+
441+
return traced_method
442+
443+
346444
def chat_stream(original_method, version, tracer):
347445
"""Wrap the `messages_stream` method."""
348446

0 commit comments

Comments
 (0)