     ConverseStreamWrapper,
     InvokeModelWithResponseStreamWrapper,
     _Choice,
+    estimate_token_count,
     genai_capture_message_content,
     message_to_event,
 )
@@ -223,6 +224,23 @@ def extract_attributes(self, attributes: _AttributeMapT):
                         self._extract_claude_attributes(
                             attributes, request_body
                         )
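+                    # note: "cohere.command-r" must be matched before the
+                    # broader "cohere.command", which would also match it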
+                    elif "cohere.command-r" in model_id:
+                        self._extract_command_r_attributes(
+                            attributes, request_body
+                        )
+                    elif "cohere.command" in model_id:
+                        self._extract_command_attributes(
+                            attributes, request_body
+                        )
+                    elif "meta.llama" in model_id:
+                        self._extract_llama_attributes(
+                            attributes, request_body
+                        )
+                    elif "mistral" in model_id:
+                        self._extract_mistral_attributes(
+                            attributes, request_body
+                        )
+
                 except json.JSONDecodeError:
                     _logger.debug("Error: Unable to parse the body as JSON")
 
@@ -280,14 +298,102 @@ def _extract_claude_attributes(self, attributes, request_body):
             request_body.get("stop_sequences"),
         )
 
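+    # Cohere Command R request: the prompt is carried in "message" and top-p
+    # in "p"; the request has no usage data, so input tokens are estimated
+    # from the prompt text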
+    def _extract_command_r_attributes(self, attributes, request_body):
+        prompt = request_body.get("message")
+        self._set_if_not_none(
+            attributes, GEN_AI_USAGE_INPUT_TOKENS, estimate_token_count(prompt)
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_tokens"),
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
+        )
+        self._set_if_not_none(
+            attributes, GEN_AI_REQUEST_TOP_P, request_body.get("p")
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_STOP_SEQUENCES,
+            request_body.get("stop_sequences"),
+        )
+
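+    # legacy Cohere Command request: same fields as Command R, except the
+    # prompt is carried in "prompt"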
+    def _extract_command_attributes(self, attributes, request_body):
+        prompt = request_body.get("prompt")
+        self._set_if_not_none(
+            attributes, GEN_AI_USAGE_INPUT_TOKENS, estimate_token_count(prompt)
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_tokens"),
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
+        )
+        self._set_if_not_none(
+            attributes, GEN_AI_REQUEST_TOP_P, request_body.get("p")
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_STOP_SEQUENCES,
+            request_body.get("stop_sequences"),
+        )
+
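+    # Meta Llama request: the generation limit is "max_gen_len"; input tokens
+    # are not estimated here because the response reports exact counts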
+    def _extract_llama_attributes(self, attributes, request_body):
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_gen_len"),
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
+        )
+        self._set_if_not_none(
+            attributes, GEN_AI_REQUEST_TOP_P, request_body.get("top_p")
+        )
+        # request for meta llama models does not contain stop_sequences field
+
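+    # Mistral AI request: input tokens are estimated from "prompt" and stop
+    # sequences are carried in "stop" rather than "stop_sequences"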
+    def _extract_mistral_attributes(self, attributes, request_body):
+        prompt = request_body.get("prompt")
+        if prompt:
+            self._set_if_not_none(
+                attributes,
+                GEN_AI_USAGE_INPUT_TOKENS,
+                estimate_token_count(prompt),
+            )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_tokens"),
+        )
+        self._set_if_not_none(
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
+        )
+        self._set_if_not_none(
+            attributes, GEN_AI_REQUEST_TOP_P, request_body.get("top_p")
+        )
+        self._set_if_not_none(
+            attributes, GEN_AI_REQUEST_STOP_SEQUENCES, request_body.get("stop")
+        )
+
     @staticmethod
     def _set_if_not_none(attributes, key, value):
         if value is not None:
             attributes[key] = value
 
     def _get_request_messages(self):
         """Extracts and normalizes system and user / assistant messages"""
-        input_text = None
         if system := self._call_context.params.get("system", []):
             system_messages = [{"role": "system", "content": system}]
         else:
@@ -304,15 +410,37 @@ def _get_request_messages(self):
                 system_messages = [{"role": "system", "content": content}]
 
         messages = decoded_body.get("messages", [])
+        # if no messages interface, convert to messages format from generic API
         if not messages:
-            # transform old school amazon titan invokeModel api to messages
-            if input_text := decoded_body.get("inputText"):
-                messages = [
-                    {"role": "user", "content": [{"text": input_text}]}
-                ]
+            model_id = self._call_context.params.get(_MODEL_ID_KEY)
+            if "amazon.titan" in model_id:
+                messages = self._get_messages_from_input_text(
+                    decoded_body, "inputText"
+                )
+            elif "cohere.command-r" in model_id:
+                # chat_history can be converted to messages; for now, just use message
+                messages = self._get_messages_from_input_text(
+                    decoded_body, "message"
+                )
+            elif (
+                "cohere.command" in model_id
+                or "meta.llama" in model_id
+                or "mistral.mistral" in model_id
+            ):
+                messages = self._get_messages_from_input_text(
+                    decoded_body, "prompt"
+                )
 
         return system_messages + messages
 
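+    # helper to wrap a plain prompt string from a text-completion body into
+    # the chat-style messages structure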
+    # pylint: disable=no-self-use
+    def _get_messages_from_input_text(
+        self, decoded_body: dict[str, Any], input_name: str
+    ):
+        if input_text := decoded_body.get(input_name):
+            return [{"role": "user", "content": [{"text": input_text}]}]
+        return []
+
     def before_service_call(
         self, span: Span, instrumentor_context: _BotocoreInstrumentorContext
     ):
@@ -439,6 +567,22 @@ def _invoke_model_on_success(
                 self._handle_anthropic_claude_response(
                     span, response_body, instrumentor_context, capture_content
                 )
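+            # as in extract_attributes, "cohere.command-r" is matched before
+            # the broader "cohere.command" prefix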
+            elif "cohere.command-r" in model_id:
+                self._handle_cohere_command_r_response(
+                    span, response_body, instrumentor_context, capture_content
+                )
+            elif "cohere.command" in model_id:
+                self._handle_cohere_command_response(
+                    span, response_body, instrumentor_context, capture_content
+                )
+            elif "meta.llama" in model_id:
+                self._handle_meta_llama_response(
+                    span, response_body, instrumentor_context, capture_content
+                )
+            elif "mistral" in model_id:
+                self._handle_mistral_ai_response(
+                    span, response_body, instrumentor_context, capture_content
+                )
         except json.JSONDecodeError:
             _logger.debug("Error: Unable to parse the response body as JSON")
         except Exception as exc:  # pylint: disable=broad-exception-caught
@@ -725,6 +869,106 @@ def _handle_anthropic_claude_response(
                 output_tokens, output_attributes
             )
 
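+    # Command R responses carry no usage data, so output tokens are estimated
+    # from the generated text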
+    def _handle_cohere_command_r_response(
+        self,
+        span: Span,
+        response_body: dict[str, Any],
+        instrumentor_context: _BotocoreInstrumentorContext,
+        capture_content: bool,
+    ):
+        if "text" in response_body:
+            span.set_attribute(
+                GEN_AI_USAGE_OUTPUT_TOKENS,
+                estimate_token_count(response_body["text"]),
+            )
+        if "finish_reason" in response_body:
+            span.set_attribute(
+                GEN_AI_RESPONSE_FINISH_REASONS,
+                [response_body["finish_reason"]],
+            )
+
+        event_logger = instrumentor_context.event_logger
+        choice = _Choice.from_invoke_cohere_command_r(
+            response_body, capture_content
+        )
+        event_logger.emit(choice.to_choice_event())
+
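+    # legacy Command responses return a list of generations; only the first
+    # one is inspected for text and finish_reason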
+    def _handle_cohere_command_response(
+        self,
+        span: Span,
+        response_body: dict[str, Any],
+        instrumentor_context: _BotocoreInstrumentorContext,
+        capture_content: bool,
+    ):
+        if "generations" in response_body and response_body["generations"]:
+            generations = response_body["generations"][0]
+            if "text" in generations:
+                span.set_attribute(
+                    GEN_AI_USAGE_OUTPUT_TOKENS,
+                    estimate_token_count(generations["text"]),
+                )
+            if "finish_reason" in generations:
+                span.set_attribute(
+                    GEN_AI_RESPONSE_FINISH_REASONS,
+                    [generations["finish_reason"]],
+                )
+
+        event_logger = instrumentor_context.event_logger
+        choice = _Choice.from_invoke_cohere_command(
+            response_body, capture_content
+        )
+        event_logger.emit(choice.to_choice_event())
+
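+    # Llama responses report exact prompt and generation token counts, so no
+    # estimation is needed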
+    def _handle_meta_llama_response(
+        self,
+        span: Span,
+        response_body: dict[str, Any],
+        instrumentor_context: _BotocoreInstrumentorContext,
+        capture_content: bool,
+    ):
+        if "prompt_token_count" in response_body:
+            span.set_attribute(
+                GEN_AI_USAGE_INPUT_TOKENS, response_body["prompt_token_count"]
+            )
+        if "generation_token_count" in response_body:
+            span.set_attribute(
+                GEN_AI_USAGE_OUTPUT_TOKENS,
+                response_body["generation_token_count"],
+            )
+        if "stop_reason" in response_body:
+            span.set_attribute(
+                GEN_AI_RESPONSE_FINISH_REASONS, [response_body["stop_reason"]]
+            )
+
+        event_logger = instrumentor_context.event_logger
+        choice = _Choice.from_invoke_meta_llama(response_body, capture_content)
+        event_logger.emit(choice.to_choice_event())
+
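+    # Mistral responses nest text and stop_reason under "outputs"; output
+    # tokens are estimated from the generated text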
+    def _handle_mistral_ai_response(
+        self,
+        span: Span,
+        response_body: dict[str, Any],
+        instrumentor_context: _BotocoreInstrumentorContext,
+        capture_content: bool,
+    ):
+        if "outputs" in response_body and response_body["outputs"]:
+            outputs = response_body["outputs"][0]
+            if "text" in outputs:
+                span.set_attribute(
+                    GEN_AI_USAGE_OUTPUT_TOKENS,
+                    estimate_token_count(outputs["text"]),
+                )
+            if "stop_reason" in outputs:
+                span.set_attribute(
+                    GEN_AI_RESPONSE_FINISH_REASONS, [outputs["stop_reason"]]
+                )
+
+        event_logger = instrumentor_context.event_logger
+        choice = _Choice.from_invoke_mistral_mistral(
+            response_body, capture_content
+        )
+        event_logger.emit(choice.to_choice_event())
+
     def on_error(
         self,
         span: Span,