@@ -31,9 +31,9 @@
     ConverseStreamWrapper,
     InvokeModelWithResponseStreamWrapper,
     _Choice,
+    estimate_token_count,
     genai_capture_message_content,
     message_to_event,
-    estimate_token_count,
 )
 from opentelemetry.instrumentation.botocore.extensions.types import (
     _AttributeMapT,
@@ -106,6 +106,7 @@
 
 _MODEL_ID_KEY: str = "modelId"
 
+
 class _BedrockRuntimeExtension(_AwsSdkExtension):
     """
     This class is an extension for <a
@@ -255,7 +256,9 @@ def _extract_titan_attributes(self, attributes, request_body):
             attributes, GEN_AI_REQUEST_MAX_TOKENS, config.get("maxTokenCount")
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_STOP_SEQUENCES, config.get("stopSequences")
+            attributes,
+            GEN_AI_REQUEST_STOP_SEQUENCES,
+            config.get("stopSequences"),
         )
 
     def _extract_nova_attributes(self, attributes, request_body):
@@ -270,21 +273,29 @@ def _extract_nova_attributes(self, attributes, request_body):
             attributes, GEN_AI_REQUEST_MAX_TOKENS, config.get("max_new_tokens")
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_STOP_SEQUENCES, config.get("stopSequences")
+            attributes,
+            GEN_AI_REQUEST_STOP_SEQUENCES,
+            config.get("stopSequences"),
         )
 
     def _extract_claude_attributes(self, attributes, request_body):
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_MAX_TOKENS, request_body.get("max_tokens")
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_tokens"),
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_TEMPERATURE, request_body.get("temperature")
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
         )
         self._set_if_not_none(
             attributes, GEN_AI_REQUEST_TOP_P, request_body.get("top_p")
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_STOP_SEQUENCES, request_body.get("stop_sequences")
+            attributes,
+            GEN_AI_REQUEST_STOP_SEQUENCES,
+            request_body.get("stop_sequences"),
         )
 
     def _extract_command_r_attributes(self, attributes, request_body):
@@ -293,16 +304,22 @@ def _extract_command_r_attributes(self, attributes, request_body):
             attributes, GEN_AI_USAGE_INPUT_TOKENS, estimate_token_count(prompt)
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_MAX_TOKENS, request_body.get("max_tokens")
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_tokens"),
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_TEMPERATURE, request_body.get("temperature")
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
         )
         self._set_if_not_none(
             attributes, GEN_AI_REQUEST_TOP_P, request_body.get("p")
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_STOP_SEQUENCES, request_body.get("stop_sequences")
+            attributes,
+            GEN_AI_REQUEST_STOP_SEQUENCES,
+            request_body.get("stop_sequences"),
         )
 
     def _extract_command_attributes(self, attributes, request_body):
@@ -311,24 +328,34 @@ def _extract_command_attributes(self, attributes, request_body):
             attributes, GEN_AI_USAGE_INPUT_TOKENS, estimate_token_count(prompt)
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_MAX_TOKENS, request_body.get("max_tokens")
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_tokens"),
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_TEMPERATURE, request_body.get("temperature")
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
         )
         self._set_if_not_none(
             attributes, GEN_AI_REQUEST_TOP_P, request_body.get("p")
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_STOP_SEQUENCES, request_body.get("stop_sequences")
+            attributes,
+            GEN_AI_REQUEST_STOP_SEQUENCES,
+            request_body.get("stop_sequences"),
         )
 
     def _extract_llama_attributes(self, attributes, request_body):
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_MAX_TOKENS, request_body.get("max_gen_len")
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_gen_len"),
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_TEMPERATURE, request_body.get("temperature")
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
         )
         self._set_if_not_none(
             attributes, GEN_AI_REQUEST_TOP_P, request_body.get("top_p")
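
Reviewer note: the hunks above only reflow the existing extraction pattern, which copies optional request-body fields onto the span's attribute map and skips anything absent. A minimal self-contained sketch of that pattern (the helper body is inferred from its call sites here, and the request body is a hypothetical Claude-style payload, not taken from this PR; the attribute keys are the GenAI semconv strings behind the GEN_AI_* constants):

    def _set_if_not_none(attributes, key, value):
        # Record the attribute only when the request body actually supplied it.
        if value is not None:
            attributes[key] = value

    attributes = {}
    request_body = {"max_tokens": 512, "temperature": 0.7}  # hypothetical body
    _set_if_not_none(
        attributes, "gen_ai.request.max_tokens", request_body.get("max_tokens")
    )
    _set_if_not_none(
        attributes, "gen_ai.request.stop_sequences", request_body.get("stop_sequences")
    )
    # attributes == {"gen_ai.request.max_tokens": 512}; absent stop_sequences is skipped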
@@ -339,13 +366,19 @@ def _extract_mistral_attributes(self, attributes, request_body):
         prompt = request_body.get("prompt")
         if prompt:
             self._set_if_not_none(
-                attributes, GEN_AI_USAGE_INPUT_TOKENS, estimate_token_count(prompt)
+                attributes,
+                GEN_AI_USAGE_INPUT_TOKENS,
+                estimate_token_count(prompt),
             )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_MAX_TOKENS, request_body.get("max_tokens")
+            attributes,
+            GEN_AI_REQUEST_MAX_TOKENS,
+            request_body.get("max_tokens"),
         )
         self._set_if_not_none(
-            attributes, GEN_AI_REQUEST_TEMPERATURE, request_body.get("temperature")
+            attributes,
+            GEN_AI_REQUEST_TEMPERATURE,
+            request_body.get("temperature"),
         )
         self._set_if_not_none(
             attributes, GEN_AI_REQUEST_TOP_P, request_body.get("top_p")
@@ -361,7 +394,6 @@ def _set_if_not_none(attributes, key, value):
 
     def _get_request_messages(self):
         """Extracts and normalize system and user / assistant messages"""
-        input_text = None
         if system := self._call_context.params.get("system", []):
             system_messages = [{"role": "system", "content": system}]
         else:
@@ -390,20 +422,23 @@ def _get_request_messages(self):
                 messages = self._get_messages_from_input_text(
                     decoded_body, "message"
                 )
-            elif "cohere.command" in model_id or "meta.llama" in model_id or "mistral.mistral" in model_id:
+            elif (
+                "cohere.command" in model_id
+                or "meta.llama" in model_id
+                or "mistral.mistral" in model_id
+            ):
                 messages = self._get_messages_from_input_text(
                     decoded_body, "prompt"
                 )
 
         return system_messages + messages
 
+    # pylint: disable=no-self-use
     def _get_messages_from_input_text(
         self, decoded_body: dict[str, Any], input_name: str
     ):
         if input_text := decoded_body.get(input_name):
-            return [
-                {"role": "user", "content": [{"text": input_text}]}
-            ]
+            return [{"role": "user", "content": [{"text": input_text}]}]
         return []
 
     def before_service_call(
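
The normalized message shape these lines produce is easiest to see standalone. A quick sketch (the helper is restated without `self` for illustration; the decoded body is a hypothetical InvokeModel payload for one of the prompt-based models matched by the new elif):

    def get_messages_from_input_text(decoded_body, input_name):
        # Same logic as the method above: wrap the raw prompt as a user message.
        if input_text := decoded_body.get(input_name):
            return [{"role": "user", "content": [{"text": input_text}]}]
        return []

    body = {"prompt": "Say hello", "max_tokens": 10}  # hypothetical Llama/Mistral body
    print(get_messages_from_input_text(body, "prompt"))
    # [{'role': 'user', 'content': [{'text': 'Say hello'}]}]
    print(get_messages_from_input_text({}, "prompt"))  # [] when the field is absent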
@@ -843,11 +878,13 @@ def _handle_cohere_command_r_response(
     ):
         if "text" in response_body:
             span.set_attribute(
-                GEN_AI_USAGE_OUTPUT_TOKENS, estimate_token_count(response_body["text"])
+                GEN_AI_USAGE_OUTPUT_TOKENS,
+                estimate_token_count(response_body["text"]),
             )
         if "finish_reason" in response_body:
             span.set_attribute(
-                GEN_AI_RESPONSE_FINISH_REASONS, [response_body["finish_reason"]]
+                GEN_AI_RESPONSE_FINISH_REASONS,
+                [response_body["finish_reason"]],
             )
 
         event_logger = instrumentor_context.event_logger
@@ -867,11 +904,13 @@ def _handle_cohere_command_response(
             generations = response_body["generations"][0]
             if "text" in generations:
                 span.set_attribute(
-                    GEN_AI_USAGE_OUTPUT_TOKENS, estimate_token_count(generations["text"])
+                    GEN_AI_USAGE_OUTPUT_TOKENS,
+                    estimate_token_count(generations["text"]),
                 )
             if "finish_reason" in generations:
                 span.set_attribute(
-                    GEN_AI_RESPONSE_FINISH_REASONS, [generations["finish_reason"]]
+                    GEN_AI_RESPONSE_FINISH_REASONS,
+                    [generations["finish_reason"]],
                 )
 
         event_logger = instrumentor_context.event_logger
@@ -893,17 +932,16 @@ def _handle_meta_llama_response(
             )
         if "generation_token_count" in response_body:
             span.set_attribute(
-                GEN_AI_USAGE_OUTPUT_TOKENS, response_body["generation_token_count"],
+                GEN_AI_USAGE_OUTPUT_TOKENS,
+                response_body["generation_token_count"],
             )
         if "stop_reason" in response_body:
             span.set_attribute(
                 GEN_AI_RESPONSE_FINISH_REASONS, [response_body["stop_reason"]]
             )
 
         event_logger = instrumentor_context.event_logger
-        choice = _Choice.from_invoke_meta_llama(
-            response_body, capture_content
-        )
+        choice = _Choice.from_invoke_meta_llama(response_body, capture_content)
         event_logger.emit(choice.to_choice_event())
 
     def _handle_mistral_ai_response(
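
Where Cohere and Mistral response bodies carry no token counts, the `estimate_token_count` import added at the top of this diff fills the gap from text length. A sketch of a character-based heuristic of that kind (the ~6-characters-per-token ratio is an assumption for illustration, not a claim about the upstream helper's exact constant):

    def estimate_token_count(message: str) -> int:
        # Assumed heuristic: roughly six characters per token, at least one token.
        return max(1, len(message) // 6)

    print(estimate_token_count("The quick brown fox jumps over the lazy dog."))  # 7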
@@ -916,9 +954,14 @@ def _handle_mistral_ai_response(
         if "outputs" in response_body:
             outputs = response_body["outputs"][0]
             if "text" in outputs:
-                span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, estimate_token_count(outputs["text"]))
+                span.set_attribute(
+                    GEN_AI_USAGE_OUTPUT_TOKENS,
+                    estimate_token_count(outputs["text"]),
+                )
             if "stop_reason" in outputs:
-                span.set_attribute(GEN_AI_RESPONSE_FINISH_REASONS, [outputs["stop_reason"]])
+                span.set_attribute(
+                    GEN_AI_RESPONSE_FINISH_REASONS, [outputs["stop_reason"]]
+                )
 
         event_logger = instrumentor_context.event_logger
         choice = _Choice.from_invoke_mistral_mistral(