 
 import json
 
-from langtrace_python_sdk.instrumentation.aws_bedrock.bedrock_streaming_wrapper import (
-    StreamingWrapper,
-)
+from wrapt import ObjectProxy
 from .stream_body_wrapper import BufferedStreamBody
 from functools import wraps
 from langtrace.trace_attributes import (
@@ -87,6 +85,11 @@ def traced_method(wrapped, instance, args, kwargs):
 
     client = wrapped(*args, **kwargs)
     client.invoke_model = patch_invoke_model(client.invoke_model, tracer, version)
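+    # Also wrap the streaming variant so streamed invocations are traced.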
+    client.invoke_model_with_response_stream = (
+        patch_invoke_model_with_response_stream(
+            client.invoke_model_with_response_stream, tracer, version
+        )
+    )
 
     client.converse = patch_converse(client.converse, tracer, version)
     client.converse_stream = patch_converse_stream(
@@ -186,6 +189,56 @@ def traced_method(*args, **kwargs):
     return traced_method
 
 
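+# Same shape as patch_invoke_model, except the span is not ended here: it is
+# closed by the stream wrapper once the response stream has been consumed.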
+def patch_invoke_model_with_response_stream(original_method, tracer, version):
+    @wraps(original_method)
+    def traced_method(*args, **kwargs):
+        modelId = kwargs.get("modelId")
+        (vendor, _) = modelId.split(".")
+        span_attributes = {
+            **get_langtrace_attributes(version, vendor, vendor_type="framework"),
+            **get_extra_attributes(),
+        }
+        span = tracer.start_span(
+            name=get_span_name("aws_bedrock.invoke_model_with_response_stream"),
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        )
+        set_span_attributes(span, span_attributes)
+        response = original_method(*args, **kwargs)
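+        # The span stays open here; stream_finished ends it after the last chunk.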
+        if span.is_recording():
+            handle_streaming_call(span, kwargs, response)
+        return response
+
+    return traced_method
+
+
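+# Swaps the response body for a proxy that accumulates chunks and invokes
+# stream_finished with the reassembled body once iteration completes.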
+def handle_streaming_call(span, kwargs, response):
+
+    def stream_finished(response_body):
+        request_body = json.loads(kwargs.get("body"))
+
+        (vendor, model) = kwargs.get("modelId").split(".")
+
+        set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, model)
+        set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model)
+
+        if vendor == "amazon":
+            set_amazon_attributes(span, request_body, response_body)
+
+        if vendor == "anthropic":
+            if "prompt" in request_body:
+                set_anthropic_completions_attributes(span, request_body, response_body)
+            elif "messages" in request_body:
+                set_anthropic_messages_attributes(span, request_body, response_body)
+
+        if vendor == "meta":
+            set_llama_meta_attributes(span, request_body, response_body)
+
+        span.end()
+
+    response["body"] = StreamingBedrockWrapper(response["body"], stream_finished)
+
+
 def handle_call(span, kwargs, response):
     modelId = kwargs.get("modelId")
     (vendor, model_name) = modelId.split(".")
@@ -195,7 +248,6 @@ def handle_call(span, kwargs, response):
     request_body = json.loads(kwargs.get("body"))
     response_body = json.loads(response.get("body").read())
 
-    set_span_attribute(span, SpanAttributes.LLM_SYSTEM, vendor)
     set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, modelId)
     set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, modelId)
 
@@ -222,12 +274,18 @@ def set_llama_meta_attributes(span, request_body, response_body):
     set_span_attribute(
         span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, request_body.get("max_gen_len")
     )
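+    # Streamed responses carry token counts under invocation_metrics; the
+    # non-streamed body exposes prompt_token_count / generation_token_count.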
+    if "invocation_metrics" in response_body:
+        input_tokens = response_body.get("invocation_metrics").get("inputTokenCount")
+        output_tokens = response_body.get("invocation_metrics").get("outputTokenCount")
+    else:
+        input_tokens = response_body.get("prompt_token_count")
+        output_tokens = response_body.get("generation_token_count")
 
     set_usage_attributes(
         span,
         {
-            "input_tokens": response_body.get("prompt_token_count"),
-            "output_tokens": response_body.get("generation_token_count"),
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
         },
     )
 
@@ -245,7 +303,6 @@ def set_llama_meta_attributes(span, request_body, response_body):
         }
     ]
     set_span_attribute(span, SpanAttributes.LLM_PROMPTS, json.dumps(prompts))
-    print(completions)
     set_event_completion(span, completions)
 
 
@@ -257,13 +314,22 @@ def set_amazon_attributes(span, request_body, response_body):
             "content": request_body.get("inputText"),
         }
     ]
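+    # Non-streamed Amazon (Titan) responses return a "results" list; streamed
+    # responses are reassembled into one body with a top-level outputText.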
-    completions = [
-        {
-            "role": "assistant",
-            "content": result.get("outputText"),
-        }
-        for result in response_body.get("results")
-    ]
+    if "results" in response_body:
+        completions = [
+            {
+                "role": "assistant",
+                "content": result.get("outputText"),
+            }
+            for result in response_body.get("results")
+        ]
+
+    else:
+        completions = [
+            {
+                "role": "assistant",
+                "content": response_body.get("outputText"),
+            }
+        ]
     set_span_attribute(
         span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, config.get("maxTokenCount")
     )
@@ -272,13 +338,19 @@ def set_amazon_attributes(span, request_body, response_body):
     )
     set_span_attribute(span, SpanAttributes.LLM_REQUEST_TOP_P, config.get("topP"))
     set_span_attribute(span, SpanAttributes.LLM_PROMPTS, json.dumps(prompts))
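+    # Batch responses sum the per-result tokenCount values; streamed responses
+    # report a single outputTextTokenCount.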
+    input_tokens = response_body.get("inputTextTokenCount")
+    if "results" in response_body:
+        output_tokens = sum(
+            int(result.get("tokenCount")) for result in response_body.get("results")
+        )
+    else:
+        output_tokens = response_body.get("outputTextTokenCount")
+
     set_usage_attributes(
         span,
         {
-            "input_tokens": response_body.get("inputTextTokenCount"),
-            "output_tokens": sum(
-                int(result.get("tokenCount")) for result in response_body.get("results")
-            ),
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
         },
     )
     set_event_completion(span, completions)
@@ -320,7 +392,7 @@ def set_anthropic_messages_attributes(span, request_body, response_body):
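+    # The messages API names this parameter "max_tokens"; the older text
+    # completions API used "max_tokens_to_sample".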
     set_span_attribute(
         span,
         SpanAttributes.LLM_REQUEST_MAX_TOKENS,
-        request_body.get("max_tokens_to_sample"),
+        request_body.get("max_tokens_to_sample") or request_body.get("max_tokens"),
     )
     set_span_attribute(
         span,
@@ -394,3 +466,62 @@ def set_span_streaming_response(span, response):
     set_event_completion(
         span, [{"role": role or "assistant", "content": streaming_response}]
     )
+
+
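+# Transparent wrapt proxy over the Bedrock EventStream: events pass through to
+# the caller unchanged while the full response body is accumulated, and the
+# done callback fires once a terminal chunk is observed.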
+class StreamingBedrockWrapper(ObjectProxy):
+    def __init__(
+        self,
+        response,
+        stream_done_callback=None,
+    ):
+        super().__init__(response)
+
+        self._stream_done_callback = stream_done_callback
+        self._accumulating_body = {"generation": ""}
+
+    def __iter__(self):
+        for event in self.__wrapped__:
+            self._process_event(event)
+            yield event
+
+    def _process_event(self, event):
+        chunk = event.get("chunk")
+        if not chunk:
+            return
+
+        decoded_chunk = json.loads(chunk.get("bytes").decode())
+        type = decoded_chunk.get("type")
+
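+        # Amazon chunks carry no "type" field; a chunk with outputText is
+        # already the complete response body.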
+        if type is None and "outputText" in decoded_chunk:
+            self._stream_done_callback(decoded_chunk)
+            return
+        if "generation" in decoded_chunk:
+            self._accumulating_body["generation"] += decoded_chunk.get("generation")
+
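+        # Anthropic messages stream: rebuild the message from the
+        # message_start / content_block_start / content_block_delta events.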
+        if type == "message_start":
+            self._accumulating_body = decoded_chunk.get("message")
+        elif type == "content_block_start":
+            self._accumulating_body["content"].append(
+                decoded_chunk.get("content_block")
+            )
+        elif type == "content_block_delta":
+            self._accumulating_body["content"][-1]["text"] += decoded_chunk.get(
+                "delta"
+            ).get("text")
+
+        elif self.has_finished(type, decoded_chunk):
+            self._accumulating_body["invocation_metrics"] = decoded_chunk.get(
+                "amazon-bedrock-invocationMetrics"
+            )
+            self._stream_done_callback(self._accumulating_body)
+
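+    # A chunk is terminal when it carries a "message_stop" type, a
+    # completionReason of "FINISH", or a non-null stop_reason.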
+    def has_finished(self, type, chunk):
+        if type and type == "message_stop":
+            return True
+
+        if "completionReason" in chunk and chunk.get("completionReason") == "FINISH":
+            return True
+
+        if "stop_reason" in chunk and chunk.get("stop_reason") is not None:
+            return True
+        return False