diff --git a/src/examples/awsbedrock_examples/__init__.py b/src/examples/awsbedrock_examples/__init__.py
index 8fa9eb6e..a69d7f71 100644
--- a/src/examples/awsbedrock_examples/__init__.py
+++ b/src/examples/awsbedrock_examples/__init__.py
@@ -1,9 +1,7 @@
 from examples.awsbedrock_examples.converse import (
-    use_converse_stream,
-    use_converse,
     use_invoke_model_anthropic,
-    use_invoke_model_cohere,
-    use_invoke_model_amazon,
+    use_invoke_model_titan,
+    use_invoke_model_llama,
 )
 from langtrace_python_sdk import langtrace, with_langtrace_root_span
 
@@ -12,8 +10,9 @@ class AWSBedrockRunner:
     @with_langtrace_root_span("AWS_Bedrock")
     def run(self):
 
-        use_converse_stream()
-        use_converse()
-        use_invoke_model_anthropic()
-        use_invoke_model_cohere()
-        use_invoke_model_amazon()
+        # use_converse_stream()
+        # use_converse()
+        # use_invoke_model_anthropic(stream=True)
+        # use_invoke_model_cohere()
+        # use_invoke_model_llama(stream=False)
+        use_invoke_model_titan(stream=False)
diff --git a/src/examples/awsbedrock_examples/converse.py b/src/examples/awsbedrock_examples/converse.py
index 1401636e..80f6eaf1 100644
--- a/src/examples/awsbedrock_examples/converse.py
+++ b/src/examples/awsbedrock_examples/converse.py
@@ -88,6 +88,12 @@ def use_invoke_model_titan(stream=False):
         response = brt.invoke_model_with_response_stream(
             body=body, modelId=modelId, accept=accept, contentType=contentType
         )
+        # Extract and print the response text in real time.
+        for event in response["body"]:
+            chunk = json.loads(event["chunk"]["bytes"])
+            if "outputText" in chunk:
+                print(chunk["outputText"], end="")
+
     else:
         response = brt.invoke_model(
             body=body, modelId=modelId, accept=accept, contentType=contentType
@@ -130,7 +136,8 @@ def use_invoke_model_anthropic(stream=False):
         for event in stream_response:
             chunk = event.get("chunk")
             if chunk:
-                print(json.loads(chunk.get("bytes").decode()))
+                # print(json.loads(chunk.get("bytes").decode()))
+                pass
 
     else:
         response = brt.invoke_model(
@@ -141,7 +148,7 @@ def use_invoke_model_anthropic(stream=False):
     print(response_body.get("completion"))
 
 
-def use_invoke_model_llama():
+def use_invoke_model_llama(stream=False):
     model_id = "meta.llama3-8b-instruct-v1:0"
     prompt = "What is the capital of France?"
     max_gen_len = 128
@@ -157,11 +164,18 @@ def use_invoke_model_llama():
             "top_p": top_p,
         }
     )
-    response = brt.invoke_model(body=body, modelId=model_id)
-
-    response_body = json.loads(response.get("body").read())
-    return response_body
+    if stream:
+        response = brt.invoke_model_with_response_stream(body=body, modelId=model_id)
+        for event in response["body"]:
+            chunk = json.loads(event["chunk"]["bytes"])
+            if "generation" in chunk:
+                # print(chunk["generation"], end="")
+                pass
+    else:
+        response = brt.invoke_model(body=body, modelId=model_id)
+        response_body = json.loads(response.get("body").read())
+        return response_body
 
 
 # print(get_foundation_models())
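For reviewers, a rough sketch of the wire format the Llama streaming branch above consumes. It is inferred only from the fields this diff reads ("generation", "stop_reason", "amazon-bedrock-invocationMetrics"), not from a full schema, and it assumes boto3 credentials and Bedrock model access are already configured:

import json

import boto3

brt = boto3.client("bedrock-runtime")
body = json.dumps({"prompt": "What is the capital of France?", "max_gen_len": 128})
response = brt.invoke_model_with_response_stream(
    body=body, modelId="meta.llama3-8b-instruct-v1:0"
)

completion = ""
for event in response["body"]:
    chunk = json.loads(event["chunk"]["bytes"])
    # Each chunk carries an incremental "generation" fragment.
    completion += chunk.get("generation", "")
    if chunk.get("stop_reason") is not None:
        # The final chunk carries token counts under this key.
        metrics = chunk.get("amazon-bedrock-invocationMetrics", {})
        print(completion)
        print("input tokens:", metrics.get("inputTokenCount"))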
diff --git a/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py b/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py
index 576df737..2f354d86 100644
--- a/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py
+++ b/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py
@@ -16,9 +16,7 @@
 
 import json
 
-from langtrace_python_sdk.instrumentation.aws_bedrock.bedrock_streaming_wrapper import (
-    StreamingWrapper,
-)
+from wrapt import ObjectProxy
 from .stream_body_wrapper import BufferedStreamBody
 from functools import wraps
 from langtrace.trace_attributes import (
@@ -87,6 +85,11 @@ def traced_method(wrapped, instance, args, kwargs):
 
         client = wrapped(*args, **kwargs)
         client.invoke_model = patch_invoke_model(client.invoke_model, tracer, version)
+        client.invoke_model_with_response_stream = (
+            patch_invoke_model_with_response_stream(
+                client.invoke_model_with_response_stream, tracer, version
+            )
+        )
 
         client.converse = patch_converse(client.converse, tracer, version)
         client.converse_stream = patch_converse_stream(
@@ -186,6 +189,56 @@ def traced_method(*args, **kwargs):
     return traced_method
 
 
+def patch_invoke_model_with_response_stream(original_method, tracer, version):
+    @wraps(original_method)
+    def traced_method(*args, **kwargs):
+        modelId = kwargs.get("modelId")
+        (vendor, _) = modelId.split(".")
+        span_attributes = {
+            **get_langtrace_attributes(version, vendor, vendor_type="framework"),
+            **get_extra_attributes(),
+        }
+        span = tracer.start_span(
+            name=get_span_name("aws_bedrock.invoke_model_with_response_stream"),
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        )
+        set_span_attributes(span, span_attributes)
+        response = original_method(*args, **kwargs)
+        if span.is_recording():
+            handle_streaming_call(span, kwargs, response)
+        return response
+
+    return traced_method
+
+
+def handle_streaming_call(span, kwargs, response):
+
+    def stream_finished(response_body):
+        request_body = json.loads(kwargs.get("body"))
+
+        (vendor, model) = kwargs.get("modelId").split(".")
+
+        set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, model)
+        set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model)
+
+        if vendor == "amazon":
+            set_amazon_attributes(span, request_body, response_body)
+
+        if vendor == "anthropic":
+            if "prompt" in request_body:
+                set_anthropic_completions_attributes(span, request_body, response_body)
+            elif "messages" in request_body:
+                set_anthropic_messages_attributes(span, request_body, response_body)
+
+        if vendor == "meta":
+            set_llama_meta_attributes(span, request_body, response_body)
+
+        span.end()
+
+    response["body"] = StreamingBedrockWrapper(response["body"], stream_finished)
+
+
 def handle_call(span, kwargs, response):
     modelId = kwargs.get("modelId")
     (vendor, model_name) = modelId.split(".")
@@ -195,7 +248,6 @@ def handle_call(span, kwargs, response):
     request_body = json.loads(kwargs.get("body"))
     response_body = json.loads(response.get("body").read())
 
-    set_span_attribute(span, SpanAttributes.LLM_SYSTEM, vendor)
     set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, modelId)
     set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, modelId)
 
@@ -222,12 +274,18 @@ def set_llama_meta_attributes(span, request_body, response_body):
     set_span_attribute(
         span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, request_body.get("max_gen_len")
     )
+    if "invocation_metrics" in response_body:
+        input_tokens = response_body.get("invocation_metrics").get("inputTokenCount")
+        output_tokens = response_body.get("invocation_metrics").get("outputTokenCount")
+    else:
+        input_tokens = response_body.get("prompt_token_count")
+        output_tokens = response_body.get("generation_token_count")
 
     set_usage_attributes(
         span,
         {
-            "input_tokens": response_body.get("prompt_token_count"),
-            "output_tokens": response_body.get("generation_token_count"),
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
         },
     )
 
@@ -245,7 +303,6 @@ def set_llama_meta_attributes(span, request_body, response_body):
         }
     ]
     set_span_attribute(span, SpanAttributes.LLM_PROMPTS, json.dumps(prompts))
-    print(completions)
     set_event_completion(span, completions)
 
 
@@ -257,13 +314,22 @@ def set_amazon_attributes(span, request_body, response_body):
             "content": request_body.get("inputText"),
         }
     ]
-    completions = [
-        {
-            "role": "assistant",
-            "content": result.get("outputText"),
-        }
-        for result in response_body.get("results")
-    ]
+    if "results" in response_body:
+        completions = [
+            {
+                "role": "assistant",
+                "content": result.get("outputText"),
+            }
+            for result in response_body.get("results")
+        ]
+
+    else:
+        completions = [
+            {
+                "role": "assistant",
+                "content": response_body.get("outputText"),
+            }
+        ]
     set_span_attribute(
         span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, config.get("maxTokenCount")
     )
@@ -272,13 +338,19 @@ def set_amazon_attributes(span, request_body, response_body):
     )
     set_span_attribute(span, SpanAttributes.LLM_REQUEST_TOP_P, config.get("topP"))
     set_span_attribute(span, SpanAttributes.LLM_PROMPTS, json.dumps(prompts))
+    input_tokens = response_body.get("inputTextTokenCount")
+    if "results" in response_body:
+        output_tokens = sum(
+            int(result.get("tokenCount")) for result in response_body.get("results")
+        )
+    else:
+        output_tokens = response_body.get("outputTextTokenCount")
+
     set_usage_attributes(
         span,
         {
-            "input_tokens": response_body.get("inputTextTokenCount"),
-            "output_tokens": sum(
-                int(result.get("tokenCount")) for result in response_body.get("results")
-            ),
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
         },
     )
     set_event_completion(span, completions)
@@ -320,7 +392,7 @@ def set_anthropic_messages_attributes(span, request_body, response_body):
     set_span_attribute(
         span,
         SpanAttributes.LLM_REQUEST_MAX_TOKENS,
-        request_body.get("max_tokens_to_sample"),
+        request_body.get("max_tokens_to_sample") or request_body.get("max_tokens"),
     )
     set_span_attribute(
         span,
@@ -394,3 +466,62 @@ def set_span_streaming_response(span, response):
     set_event_completion(
         span, [{"role": role or "assistant", "content": streaming_response}]
     )
+
+
+class StreamingBedrockWrapper(ObjectProxy):
+    def __init__(
+        self,
+        response,
+        stream_done_callback=None,
+    ):
+        super().__init__(response)
+
+        self._stream_done_callback = stream_done_callback
+        self._accumulating_body = {"generation": ""}
+
+    def __iter__(self):
+        for event in self.__wrapped__:
+            self._process_event(event)
+            yield event
+
+    def _process_event(self, event):
+        chunk = event.get("chunk")
+        if not chunk:
+            return
+
+        decoded_chunk = json.loads(chunk.get("bytes").decode())
+        type = decoded_chunk.get("type")
+
+        if type is None and "outputText" in decoded_chunk:
+            self._stream_done_callback(decoded_chunk)
+            return
+        if "generation" in decoded_chunk:
+            self._accumulating_body["generation"] += decoded_chunk.get("generation")
+
+        if type == "message_start":
+            self._accumulating_body = decoded_chunk.get("message")
+        elif type == "content_block_start":
+            self._accumulating_body["content"].append(
+                decoded_chunk.get("content_block")
+            )
+        elif type == "content_block_delta":
+            self._accumulating_body["content"][-1]["text"] += decoded_chunk.get(
+                "delta"
+            ).get("text")
+
+        elif self.has_finished(type, decoded_chunk):
+            self._accumulating_body["invocation_metrics"] = decoded_chunk.get(
+                "amazon-bedrock-invocationMetrics"
+            )
+            self._stream_done_callback(self._accumulating_body)
+
+    def has_finished(self, type, chunk):
+        if type and type == "message_stop":
+            return True
+
+        if "completionReason" in chunk and chunk.get("completionReason") == "FINISH":
+            return True
+
+        if "stop_reason" in chunk and chunk.get("stop_reason") is not None:
+            return True
+        return False
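The core of this change is deferring span.end(): traced_method returns before any chunks exist, so the span can only close once the caller drains response["body"]. A minimal standalone sketch of that pattern, using the same wrapt.ObjectProxy the patch imports (the two-event input is hypothetical test data, not a Bedrock API):

import json

from wrapt import ObjectProxy


class CallbackOnExhaustion(ObjectProxy):
    """Proxy an iterable; invoke a callback after the last item is yielded."""

    def __init__(self, wrapped, on_done):
        super().__init__(wrapped)
        # wrapt convention: "_self_"-prefixed attributes live on the proxy
        # itself, not on the wrapped object (plain iterators reject writes).
        self._self_on_done = on_done

    def __iter__(self):
        for item in self.__wrapped__:
            yield item
        # Only reached once the consumer has drained the stream.
        self._self_on_done()


# Hypothetical two-event stream standing in for Bedrock's EventStream.
events = iter(
    [
        {"chunk": {"bytes": b'{"outputText": "Hello"}'}},
        {"chunk": {"bytes": b'{"outputText": ", world", "completionReason": "FINISH"}'}},
    ]
)
stream = CallbackOnExhaustion(events, lambda: print("\n-- stream drained; span may end --"))
for event in stream:
    print(json.loads(event["chunk"]["bytes"])["outputText"], end="")

The "_self_" prefix matters here because a bare iterator rejects attribute writes; Bedrock's EventStream tolerates plain attribute assignment, which is presumably why StreamingBedrockWrapper above can store its callback directly.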
diff --git a/src/langtrace_python_sdk/version.py b/src/langtrace_python_sdk/version.py
index 086040fa..943ca0bc 100644
--- a/src/langtrace_python_sdk/version.py
+++ b/src/langtrace_python_sdk/version.py
@@ -1 +1 @@
-__version__ = "3.3.23"
+__version__ = "3.3.24"
diff --git a/src/run_example.py b/src/run_example.py
index 80a8cf42..c04c3545 100644
--- a/src/run_example.py
+++ b/src/run_example.py
@@ -20,9 +20,9 @@
     "vertexai": False,
     "gemini": False,
     "mistral": False,
-    "awsbedrock": False,
+    "awsbedrock": True,
     "cerebras": False,
-    "google_genai": True,
+    "google_genai": False,
 }
 
 if ENABLED_EXAMPLES["anthropic"]:
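End to end, a hedged sketch of what the new instrumentation enables. The model id, prompt, and generation config are illustrative assumptions, and AWS credentials plus a LANGTRACE_API_KEY are assumed to be configured in the environment:

import json

import boto3

from langtrace_python_sdk import langtrace

langtrace.init()  # must run before the client is created, since the
                  # instrumentation patches methods at client construction

brt = boto3.client("bedrock-runtime")
body = json.dumps(
    {"inputText": "Say hello.", "textGenerationConfig": {"maxTokenCount": 64}}
)
response = brt.invoke_model_with_response_stream(
    body=body, modelId="amazon.titan-text-express-v1"  # illustrative model id
)
for event in response["body"]:  # iterating drives StreamingBedrockWrapper
    chunk = json.loads(event["chunk"]["bytes"])
    print(chunk.get("outputText", ""), end="")
# The "aws_bedrock.invoke_model_with_response_stream" span ends, with model,
# prompt, completion, and token-usage attributes, once the final chunk
# (completionReason == "FINISH") has been consumed.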