2424 get_span_name ,
2525 set_event_completion ,
2626 set_usage_attributes ,
27+ StreamWrapper
2728)
2829from langtrace .trace_attributes import Event , LLMSpanAttributes
2930from langtrace_python_sdk .utils import set_span_attribute
3839from langtrace .trace_attributes import SpanAttributes
3940
4041
41- def rerank (original_method , version , tracer ):
42+ def rerank (original_method , version , tracer , v2 = False ):
4243 """Wrap the `rerank` method."""
4344
4445 def traced_method (wrapped , instance , args , kwargs ):
@@ -49,8 +50,8 @@ def traced_method(wrapped, instance, args, kwargs):
4950 ** get_llm_request_attributes (kwargs , operation_name = "rerank" ),
5051 ** get_llm_url (instance ),
5152 SpanAttributes .LLM_REQUEST_MODEL : kwargs .get ("model" ) or "command-r-plus" ,
52- SpanAttributes .LLM_URL : APIS ["RERANK" ]["URL" ],
53- SpanAttributes .LLM_PATH : APIS ["RERANK" ]["ENDPOINT" ],
53+ SpanAttributes .LLM_URL : APIS ["RERANK" if not v2 else "RERANK_V2" ]["URL" ],
54+ SpanAttributes .LLM_PATH : APIS ["RERANK" if not v2 else "RERANK_V2" ]["ENDPOINT" ],
5455 SpanAttributes .LLM_REQUEST_DOCUMENTS : json .dumps (
5556 kwargs .get ("documents" ), cls = datetime_encoder
5657 ),
@@ -61,7 +62,7 @@ def traced_method(wrapped, instance, args, kwargs):
6162 attributes = LLMSpanAttributes (** span_attributes )
6263
6364 span = tracer .start_span (
64- name = get_span_name (APIS ["RERANK" ]["METHOD" ]), kind = SpanKind .CLIENT
65+ name = get_span_name (APIS ["RERANK" if not v2 else "RERANK_V2" ]["METHOD" ]), kind = SpanKind .CLIENT
6566 )
6667 for field , value in attributes .model_dump (by_alias = True ).items ():
6768 set_span_attribute (span , field , value )
@@ -119,7 +120,7 @@ def traced_method(wrapped, instance, args, kwargs):
119120 return traced_method
120121
121122
122- def embed (original_method , version , tracer ):
123+ def embed (original_method , version , tracer , v2 = False ):
123124 """Wrap the `embed` method."""
124125
125126 def traced_method (wrapped , instance , args , kwargs ):
@@ -129,8 +130,8 @@ def traced_method(wrapped, instance, args, kwargs):
129130 ** get_langtrace_attributes (version , service_provider ),
130131 ** get_llm_request_attributes (kwargs , operation_name = "embed" ),
131132 ** get_llm_url (instance ),
132- SpanAttributes .LLM_URL : APIS ["EMBED" ]["URL" ],
133- SpanAttributes .LLM_PATH : APIS ["EMBED" ]["ENDPOINT" ],
133+ SpanAttributes .LLM_URL : APIS ["EMBED" if not v2 else "EMBED_V2" ]["URL" ],
134+ SpanAttributes .LLM_PATH : APIS ["EMBED" if not v2 else "EMBED_V2" ]["ENDPOINT" ],
134135 SpanAttributes .LLM_REQUEST_EMBEDDING_INPUTS : json .dumps (
135136 kwargs .get ("texts" )
136137 ),
@@ -143,7 +144,7 @@ def traced_method(wrapped, instance, args, kwargs):
143144 attributes = LLMSpanAttributes (** span_attributes )
144145
145146 span = tracer .start_span (
146- name = get_span_name (APIS ["EMBED" ]["METHOD" ]),
147+ name = get_span_name (APIS ["EMBED" if not v2 else "EMBED_V2" ]["METHOD" ]),
147148 kind = SpanKind .CLIENT ,
148149 )
149150 for field , value in attributes .model_dump (by_alias = True ).items ():
@@ -343,6 +344,103 @@ def traced_method(wrapped, instance, args, kwargs):
343344 return traced_method
344345
345346
347+ def chat_create_v2 (original_method , version , tracer , stream = False ):
348+ """Wrap the `chat_create` method for Cohere API v2."""
349+
350+ def traced_method (wrapped , instance , args , kwargs ):
351+ service_provider = SERVICE_PROVIDERS ["COHERE" ]
352+
353+ messages = kwargs .get ("messages" , [])
354+ if kwargs .get ("preamble" ):
355+ messages = [{"role" : "system" , "content" : kwargs ["preamble" ]}] + messages
356+
357+ span_attributes = {
358+ ** get_langtrace_attributes (version , service_provider ),
359+ ** get_llm_request_attributes (kwargs , prompts = messages ),
360+ ** get_llm_url (instance ),
361+ SpanAttributes .LLM_REQUEST_MODEL : kwargs .get ("model" ) or "command-r-plus" ,
362+ SpanAttributes .LLM_URL : APIS ["CHAT_CREATE_V2" ]["URL" ],
363+ SpanAttributes .LLM_PATH : APIS ["CHAT_CREATE_V2" ]["ENDPOINT" ],
364+ ** get_extra_attributes (),
365+ }
366+
367+ attributes = LLMSpanAttributes (** span_attributes )
368+
369+ for attr_name in ["max_input_tokens" , "conversation_id" , "connectors" , "tools" , "tool_results" ]:
370+ value = kwargs .get (attr_name )
371+ if value is not None :
372+ if attr_name == "max_input_tokens" :
373+ attributes .llm_max_input_tokens = str (value )
374+ elif attr_name == "conversation_id" :
375+ attributes .conversation_id = value
376+ else :
377+ setattr (attributes , f"llm_{ attr_name } " , json .dumps (value ))
378+
379+ span = tracer .start_span (
380+ name = get_span_name (APIS ["CHAT_CREATE_V2" ]["METHOD" ]),
381+ kind = SpanKind .CLIENT
382+ )
383+
384+ for field , value in attributes .model_dump (by_alias = True ).items ():
385+ set_span_attribute (span , field , value )
386+
387+ try :
388+ result = wrapped (* args , ** kwargs )
389+
390+ if stream :
391+ return StreamWrapper (
392+ result ,
393+ span ,
394+ tool_calls = kwargs .get ("tools" ) is not None ,
395+ )
396+ else :
397+ if hasattr (result , "id" ) and result .id is not None :
398+ span .set_attribute (SpanAttributes .LLM_GENERATION_ID , result .id )
399+ span .set_attribute (SpanAttributes .LLM_RESPONSE_ID , result .id )
400+
401+ if (hasattr (result , "message" ) and
402+ hasattr (result .message , "content" ) and
403+ len (result .message .content ) > 0 and
404+ hasattr (result .message .content [0 ], "text" ) and
405+ result .message .content [0 ].text is not None and
406+ result .message .content [0 ].text != "" ):
407+ responses = [{
408+ "role" : result .message .role ,
409+ "content" : result .message .content [0 ].text
410+ }]
411+ set_event_completion (span , responses )
412+ if hasattr (result , "tool_calls" ) and result .tool_calls is not None :
413+ tool_calls = [tool_call .json () for tool_call in result .tool_calls ]
414+ span .set_attribute (
415+ SpanAttributes .LLM_TOOL_RESULTS ,
416+ json .dumps (tool_calls )
417+ )
418+ if hasattr (result , "usage" ) and result .usage is not None :
419+ if (hasattr (result .usage , "billed_units" ) and
420+ result .usage .billed_units is not None ):
421+ usage = result .usage .billed_units
422+ for metric , value in {
423+ "input" : usage .input_tokens or 0 ,
424+ "output" : usage .output_tokens or 0 ,
425+ "total" : (usage .input_tokens or 0 ) + (usage .output_tokens or 0 ),
426+ }.items ():
427+ span .set_attribute (
428+ f"gen_ai.usage.{ metric } _tokens" ,
429+ int (value )
430+ )
431+ span .set_status (StatusCode .OK )
432+ span .end ()
433+ return result
434+
435+ except Exception as error :
436+ span .record_exception (error )
437+ span .set_status (Status (StatusCode .ERROR , str (error )))
438+ span .end ()
439+ raise
440+
441+ return traced_method
442+
443+
346444def chat_stream (original_method , version , tracer ):
347445 """Wrap the `messages_stream` method."""
348446
0 commit comments