1515"""
1616
1717import json
18+ import io
1819
1920from wrapt import ObjectProxy
21+ from itertools import tee
2022from .stream_body_wrapper import BufferedStreamBody
2123from functools import wraps
2224from langtrace .trace_attributes import (
@@ -43,6 +45,7 @@
     set_span_attributes,
     set_usage_attributes,
 )
+from langtrace_python_sdk.utils import set_event_prompt
 
 
 def converse_stream(original_method, version, tracer):
@@ -104,7 +107,7 @@ def traced_method(wrapped, instance, args, kwargs):
 def patch_converse_stream(original_method, tracer, version):
     def traced_method(*args, **kwargs):
         modelId = kwargs.get("modelId")
-        (vendor, _) = modelId.split(".")
+        vendor, _ = parse_vendor_and_model_name_from_model_id(modelId)
         input_content = [
             {
                 "role": message.get("role", "user"),
@@ -128,7 +131,9 @@ def traced_method(*args, **kwargs):
             response = original_method(*args, **kwargs)
 
             if span.is_recording():
-                set_span_streaming_response(span, response)
+                stream1, stream2 = tee(response["stream"])
+                set_span_streaming_response(span, stream1)
+                response["stream"] = stream2
             return response
 
     return traced_method
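
The converse_stream change above consumes the event stream once for tracing while still handing an intact stream back to the caller, by splitting it with itertools.tee. A minimal sketch of that pattern, using a plain generator as a stand-in for Bedrock's response["stream"] (the event shapes below are illustrative only):

    from itertools import tee

    def fake_stream():
        # Stand-in for response["stream"]: yields converse_stream-style events.
        yield {"messageStart": {"role": "assistant"}}
        yield {"contentBlockDelta": {"delta": {"text": "Hello"}}}
        yield {"messageStop": {}}

    stream_for_tracing, stream_for_caller = tee(fake_stream())

    # The tracer drains its copy to build the span...
    traced_events = list(stream_for_tracing)

    # ...and the caller's copy still yields every event, because tee buffers
    # whatever one iterator has consumed ahead of the other.
    assert list(stream_for_caller) == traced_events

One trade-off of this approach: because the tracing copy is drained before the response is returned, tee buffers the whole stream in memory before the caller sees the first event.
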
@@ -137,7 +142,7 @@ def traced_method(*args, **kwargs):
 def patch_converse(original_method, tracer, version):
     def traced_method(*args, **kwargs):
         modelId = kwargs.get("modelId")
-        (vendor, _) = modelId.split(".")
+        vendor, _ = parse_vendor_and_model_name_from_model_id(modelId)
         input_content = [
             {
                 "role": message.get("role", "user"),
@@ -167,12 +172,29 @@ def traced_method(*args, **kwargs):
     return traced_method
 
 
+def parse_vendor_and_model_name_from_model_id(model_id):
+    if model_id.startswith("arn:aws:bedrock:"):
+        # This needs to be in one of the following forms:
+        # arn:aws:bedrock:region:account-id:foundation-model/vendor.model-name
+        # arn:aws:bedrock:region:account-id:custom-model/vendor.model-name/model-id
+        parts = model_id.split("/")
+        identifiers = parts[1].split(".")
+        return identifiers[0], identifiers[1]
+    parts = model_id.split(".")
+    if len(parts) == 1:
+        return parts[0], parts[0]
+    else:
+        return parts[-2], parts[-1]
+
+
 def patch_invoke_model(original_method, tracer, version):
     def traced_method(*args, **kwargs):
         modelId = kwargs.get("modelId")
-        (vendor, _) = modelId.split(".")
+        vendor, _ = parse_vendor_and_model_name_from_model_id(modelId)
         span_attributes = {
             **get_langtrace_attributes(version, vendor, vendor_type="framework"),
+            SpanAttributes.LLM_PATH: APIS["INVOKE_MODEL"]["ENDPOINT"],
+            SpanAttributes.LLM_IS_STREAMING: False,
             **get_extra_attributes(),
         }
         with tracer.start_as_current_span(
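
The new parse_vendor_and_model_name_from_model_id helper accepts either a bare Bedrock model identifier or a full model ARN. A few illustrative calls (the ARN below uses a made-up account id):

    # Plain model IDs: the vendor is the second-to-last dot-separated part,
    # so region-prefixed IDs such as "us.anthropic...." also resolve correctly.
    parse_vendor_and_model_name_from_model_id("amazon.titan-embed-text-v1")
    # -> ("amazon", "titan-embed-text-v1")

    parse_vendor_and_model_name_from_model_id("anthropic.claude-3-sonnet-20240229-v1:0")
    # -> ("anthropic", "claude-3-sonnet-20240229-v1:0")

    # Foundation-model ARN: "vendor.model-name" sits right after the first "/".
    parse_vendor_and_model_name_from_model_id(
        "arn:aws:bedrock:us-east-1:123456789012:foundation-model/amazon.titan-text-express-v1"
    )
    # -> ("amazon", "titan-text-express-v1")

    # A single token falls back to using the same value for vendor and model.
    parse_vendor_and_model_name_from_model_id("mistral")
    # -> ("mistral", "mistral")
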
@@ -193,9 +215,11 @@ def patch_invoke_model_with_response_stream(original_method, tracer, version):
     @wraps(original_method)
     def traced_method(*args, **kwargs):
         modelId = kwargs.get("modelId")
-        (vendor, _) = modelId.split(".")
+        vendor, _ = parse_vendor_and_model_name_from_model_id(modelId)
         span_attributes = {
             **get_langtrace_attributes(version, vendor, vendor_type="framework"),
+            SpanAttributes.LLM_PATH: APIS["INVOKE_MODEL_WITH_RESPONSE_STREAM"]["ENDPOINT"],
+            SpanAttributes.LLM_IS_STREAMING: True,
             **get_extra_attributes(),
         }
         span = tracer.start_span(
@@ -217,7 +241,7 @@ def handle_streaming_call(span, kwargs, response):
     def stream_finished(response_body):
         request_body = json.loads(kwargs.get("body"))
 
-        (vendor, model) = kwargs.get("modelId").split(".")
+        vendor, model = parse_vendor_and_model_name_from_model_id(kwargs.get("modelId"))
 
         set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, model)
         set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model)
@@ -241,18 +265,22 @@ def stream_finished(response_body):
 
 def handle_call(span, kwargs, response):
     modelId = kwargs.get("modelId")
-    (vendor, model_name) = modelId.split(".")
+    vendor, model_name = parse_vendor_and_model_name_from_model_id(modelId)
+    read_response_body = response.get("body").read()
+    request_body = json.loads(kwargs.get("body"))
+    response_body = json.loads(read_response_body)
     response["body"] = BufferedStreamBody(
-        response["body"]._raw_stream, response["body"]._content_length
+        io.BytesIO(read_response_body), len(read_response_body)
     )
-    request_body = json.loads(kwargs.get("body"))
-    response_body = json.loads(response.get("body").read())
 
     set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, modelId)
     set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, modelId)
 
     if vendor == "amazon":
-        set_amazon_attributes(span, request_body, response_body)
+        if model_name.startswith("titan-embed-text"):
+            set_amazon_embedding_attributes(span, request_body, response_body)
+        else:
+            set_amazon_attributes(span, request_body, response_body)
 
     if vendor == "anthropic":
         if "prompt" in request_body:
@@ -356,6 +384,27 @@ def set_amazon_attributes(span, request_body, response_body):
     set_event_completion(span, completions)
 
 
+def set_amazon_embedding_attributes(span, request_body, response_body):
+    input_text = request_body.get("inputText")
+    set_event_prompt(span, input_text)
+
+    embeddings = response_body.get("embedding", [])
+    input_tokens = response_body.get("inputTextTokenCount")
+    set_usage_attributes(
+        span,
+        {
+            "input_tokens": input_tokens,
+            "output": len(embeddings),
+        },
+    )
+    set_span_attribute(
+        span, SpanAttributes.LLM_REQUEST_MODEL, request_body.get("modelId")
+    )
+    set_span_attribute(
+        span, SpanAttributes.LLM_RESPONSE_MODEL, request_body.get("modelId")
+    )
+
+
 def set_anthropic_completions_attributes(span, request_body, response_body):
     set_span_attribute(
         span,
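
For context on the fields set_amazon_embedding_attributes reads: a Titan text-embedding invocation sends an "inputText" string and returns an "embedding" vector together with "inputTextTokenCount". A sketch of the request and response bodies the new branch expects (values are illustrative):

    request_body = {"inputText": "What is Langtrace?"}

    response_body = {
        # Embedding vector; only its length is recorded as "output" usage.
        "embedding": [0.013, -0.072, 0.044],  # truncated for illustration
        # Token count of the input text, recorded as "input_tokens".
        "inputTextTokenCount": 5,
    }
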
@@ -442,10 +491,10 @@ def _set_response_attributes(span, kwargs, result):
     )
 
 
-def set_span_streaming_response(span, response):
+def set_span_streaming_response(span, response_stream):
     streaming_response = ""
     role = None
-    for event in response["stream"]:
+    for event in response_stream:
         if "messageStart" in event:
             role = event["messageStart"]["role"]
         elif "contentBlockDelta" in event:
@@ -475,13 +524,15 @@ def __init__(
         stream_done_callback=None,
     ):
         super().__init__(response)
-
         self._stream_done_callback = stream_done_callback
         self._accumulating_body = {"generation": ""}
+        self.last_chunk = None
 
     def __iter__(self):
         for event in self.__wrapped__:
+            # Process the event
             self._process_event(event)
+            # Yield the original event immediately
             yield event
 
     def _process_event(self, event):
@@ -496,7 +547,11 @@ def _process_event(self, event):
             self._stream_done_callback(decoded_chunk)
             return
         if "generation" in decoded_chunk:
-            self._accumulating_body["generation"] += decoded_chunk.get("generation")
+            generation = decoded_chunk.get("generation")
+            if self.last_chunk == generation:
+                return
+            self.last_chunk = generation
+            self._accumulating_body["generation"] += generation
 
         if type == "message_start":
             self._accumulating_body = decoded_chunk.get("message")
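
The last_chunk guards added here (and in the content_block_delta branch below) skip a chunk whose text is identical to the immediately preceding one, so a delta delivered twice in a row is only accumulated once. The same idea as a standalone filter over any chunk iterator (a sketch, not the SDK's code):

    def drop_consecutive_duplicates(chunks):
        # Yield each chunk unless it is identical to the one just seen.
        last = None
        for chunk in chunks:
            if chunk == last:
                continue
            last = chunk
            yield chunk

    assert list(drop_consecutive_duplicates(["He", "He", "llo", "llo", "!"])) == ["He", "llo", "!"]
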
@@ -505,9 +560,11 @@ def _process_event(self, event):
                 decoded_chunk.get("content_block")
             )
         elif type == "content_block_delta":
-            self._accumulating_body["content"][-1]["text"] += decoded_chunk.get(
-                "delta"
-            ).get("text")
+            text = decoded_chunk.get("delta").get("text")
+            if self.last_chunk == text:
+                return
+            self.last_chunk = text
+            self._accumulating_body["content"][-1]["text"] += text
 
         elif self.has_finished(type, decoded_chunk):
             self._accumulating_body["invocation_metrics"] = decoded_chunk.get(