From 545859bf92ef97dbec322a2fd62ebbedb26383b2 Mon Sep 17 00:00:00 2001 From: Ali Waleed Date: Wed, 8 Jan 2025 14:45:04 +0200 Subject: [PATCH 1/2] kickstart refactoring --- src/examples/awsbedrock_examples/__init__.py | 5 +- src/examples/awsbedrock_examples/converse.py | 83 +++++++++-- .../constants/instrumentation/aws_bedrock.py | 4 + .../aws_bedrock/instrumentation.py | 28 ++-- .../instrumentation/aws_bedrock/patch.py | 132 ++++++++++++++---- src/run_example.py | 4 +- 6 files changed, 192 insertions(+), 64 deletions(-) diff --git a/src/examples/awsbedrock_examples/__init__.py b/src/examples/awsbedrock_examples/__init__.py index da9a6508..f5eb9229 100644 --- a/src/examples/awsbedrock_examples/__init__.py +++ b/src/examples/awsbedrock_examples/__init__.py @@ -1,4 +1,4 @@ -from examples.awsbedrock_examples.converse import use_converse +from examples.awsbedrock_examples.converse import use_converse, use_invoke_model_titan from langtrace_python_sdk import langtrace, with_langtrace_root_span langtrace.init() @@ -7,4 +7,5 @@ class AWSBedrockRunner: @with_langtrace_root_span("AWS_Bedrock") def run(self): - use_converse() + # use_converse() + use_invoke_model_titan() diff --git a/src/examples/awsbedrock_examples/converse.py b/src/examples/awsbedrock_examples/converse.py index 9619426b..4b9f95a2 100644 --- a/src/examples/awsbedrock_examples/converse.py +++ b/src/examples/awsbedrock_examples/converse.py @@ -1,17 +1,19 @@ -import os import boto3 +import botocore +import json from langtrace_python_sdk import langtrace +from dotenv import load_dotenv + + +load_dotenv() +langtrace.init() + +brt = boto3.client("bedrock-runtime", region_name="us-east-1") +brc = boto3.client("bedrock", region_name="us-east-1") -langtrace.init(api_key=os.environ["LANGTRACE_API_KEY"]) def use_converse(): model_id = "anthropic.claude-3-haiku-20240307-v1:0" - client = boto3.client( - "bedrock-runtime", - region_name="us-east-1", - aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], - aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], - ) conversation = [ { "role": "user", @@ -20,15 +22,70 @@ def use_converse(): ] try: - response = client.converse( + response = brt.converse( modelId=model_id, messages=conversation, - inferenceConfig={"maxTokens":4096,"temperature":0}, - additionalModelRequestFields={"top_k":250} + inferenceConfig={"maxTokens": 4096, "temperature": 0}, + additionalModelRequestFields={"top_k": 250}, ) response_text = response["output"]["message"]["content"][0]["text"] print(response_text) - except (Exception) as e: + except Exception as e: print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}") - exit(1) \ No newline at end of file + exit(1) + + +def get_foundation_models(): + models = [] + for model in brc.list_foundation_models()["modelSummaries"]: + models.append(model["modelId"]) + return models + + +# Invoke Model API +# Amazon Titan Models +def use_invoke_model_titan(): + try: + prompt_data = "what's 1+1?" 
+        body = json.dumps(
+            {
+                "inputText": prompt_data,
+                "textGenerationConfig": {
+                    "maxTokenCount": 1024,
+                    "topP": 0.95,
+                    "temperature": 0.2,
+                },
+            }
+        )
+        modelId = "amazon.titan-text-express-v1"  # "amazon.titan-tg1-large"
+        accept = "application/json"
+        contentType = "application/json"
+
+        response = brt.invoke_model(
+            body=body, modelId=modelId, accept=accept, contentType=contentType
+        )
+        response_body = json.loads(response.get("body").read())
+
+        # print(response_body.get("results"))
+
+    except botocore.exceptions.ClientError as error:
+
+        if error.response["Error"]["Code"] == "AccessDeniedException":
+            print(
+                f"\x1b[41m{error.response['Error']['Message']}\
+                \nTo troubleshoot this issue please refer to the following resources.\
+                \nhttps://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_access-denied.html\
+                \nhttps://docs.aws.amazon.com/bedrock/latest/userguide/security-iam.html\x1b[0m\n"
+            )
+
+        else:
+            raise error
+
+
+# Anthropic Models
+def use_invoke_model_anthropic():
+    pass
+
+
+# print(get_foundation_models())
diff --git a/src/langtrace_python_sdk/constants/instrumentation/aws_bedrock.py b/src/langtrace_python_sdk/constants/instrumentation/aws_bedrock.py
index 6a7cedef..bdbe8fdb 100644
--- a/src/langtrace_python_sdk/constants/instrumentation/aws_bedrock.py
+++ b/src/langtrace_python_sdk/constants/instrumentation/aws_bedrock.py
@@ -1,6 +1,10 @@
 from langtrace.trace_attributes import AWSBedrockMethods
 
 APIS = {
+    "INVOKE_MODEL": {
+        "METHOD": "aws_bedrock.invoke_model",
+        "ENDPOINT": "/invoke-model",
+    },
     "CONVERSE": {
         "METHOD": AWSBedrockMethods.CONVERSE.value,
         "ENDPOINT": "/converse",
diff --git a/src/langtrace_python_sdk/instrumentation/aws_bedrock/instrumentation.py b/src/langtrace_python_sdk/instrumentation/aws_bedrock/instrumentation.py
index 601ed9ff..53aaaba3 100644
--- a/src/langtrace_python_sdk/instrumentation/aws_bedrock/instrumentation.py
+++ b/src/langtrace_python_sdk/instrumentation/aws_bedrock/instrumentation.py
@@ -23,21 +23,17 @@
 from wrapt import wrap_function_wrapper as _W
 
 from langtrace_python_sdk.instrumentation.aws_bedrock.patch import (
-    converse, converse_stream
+    converse,
+    invoke_model,
+    converse_stream,
+    patch_aws_bedrock,
 )
 
 logging.basicConfig(level=logging.FATAL)
 
-def _patch_client(client, version: str, tracer) -> None:
-
-    # Store original methods
-    original_converse = client.converse
-
-    # Replace with wrapped versions
-    client.converse = converse("aws_bedrock.converse", version, tracer)(original_converse)
 
 class AWSBedrockInstrumentation(BaseInstrumentor):
-
+
     def instrumentation_dependencies(self) -> Collection[str]:
         return ["boto3 >= 1.35.31"]
 
@@ -46,13 +42,11 @@ def _instrument(self, **kwargs):
         tracer = get_tracer(__name__, "", tracer_provider)
         version = importlib.metadata.version("boto3")
 
-        def wrap_create_client(wrapped, instance, args, kwargs):
-            result = wrapped(*args, **kwargs)
-            if args and args[0] == 'bedrock-runtime':
-                _patch_client(result, version, tracer)
-            return result
-
-        _W("boto3", "client", wrap_create_client)
+        _W(
+            module="boto3",
+            name="client",
+            wrapper=patch_aws_bedrock(tracer, version),
+        )
 
     def _uninstrument(self, **kwargs):
-        pass
\ No newline at end of file
+        pass
diff --git a/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py b/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py
index 3acf0e0e..397f3f82 100644
--- a/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py
+++ b/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py
@@ -48,30 +48,42 @@ def wrapper(original_method):
         @wraps(original_method)
         def wrapped_method(*args, **kwargs):
             service_provider = SERVICE_PROVIDERS["AWS_BEDROCK"]
-
+            print("Here's the kwargs: ", kwargs)
             input_content = [
                 {
-                    'role': message.get('role', 'user'),
-                    'content': message.get('content', [])[0].get('text', "")
+                    "role": message.get("role", "user"),
+                    "content": message.get("content", [])[0].get("text", ""),
                 }
-                for message in kwargs.get('messages', [])
+                for message in kwargs.get("messages", [])
             ]
-
+
             span_attributes = {
-                **get_langtrace_attributes(version, service_provider, vendor_type="framework"),
-                **get_llm_request_attributes(kwargs, operation_name=operation_name, prompts=input_content),
+                **get_langtrace_attributes(
+                    version, service_provider, vendor_type="framework"
+                ),
+                **get_llm_request_attributes(
+                    kwargs, operation_name=operation_name, prompts=input_content
+                ),
                 **get_llm_url(args[0] if args else None),
                 SpanAttributes.LLM_PATH: APIS[api_name]["ENDPOINT"],
                 **get_extra_attributes(),
             }
 
             if api_name == "CONVERSE":
-                span_attributes.update({
-                    SpanAttributes.LLM_REQUEST_MODEL: kwargs.get('modelId'),
-                    SpanAttributes.LLM_REQUEST_MAX_TOKENS: kwargs.get('inferenceConfig', {}).get('maxTokens'),
-                    SpanAttributes.LLM_REQUEST_TEMPERATURE: kwargs.get('inferenceConfig', {}).get('temperature'),
-                    SpanAttributes.LLM_REQUEST_TOP_P: kwargs.get('inferenceConfig', {}).get('top_p'),
-                })
+                span_attributes.update(
+                    {
+                        SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("modelId"),
+                        SpanAttributes.LLM_REQUEST_MAX_TOKENS: kwargs.get(
+                            "inferenceConfig", {}
+                        ).get("maxTokens"),
+                        SpanAttributes.LLM_REQUEST_TEMPERATURE: kwargs.get(
+                            "inferenceConfig", {}
+                        ).get("temperature"),
+                        SpanAttributes.LLM_REQUEST_TOP_P: kwargs.get(
+                            "inferenceConfig", {}
+                        ).get("top_p"),
+                    }
+                )
 
             attributes = LLMSpanAttributes(**span_attributes)
 
@@ -92,20 +104,22 @@ def wrapped_method(*args, **kwargs):
                     raise err
 
             return wrapped_method
+
         return wrapper
+
     return decorator
 
 
 converse = traced_aws_bedrock_call("CONVERSE", "converse")
+invoke_model = traced_aws_bedrock_call("INVOKE_MODEL", "invoke_model")
 
 
 def converse_stream(original_method, version, tracer):
     def traced_method(wrapped, instance, args, kwargs):
         service_provider = SERVICE_PROVIDERS["AWS_BEDROCK"]
-
+
         span_attributes = {
-            **get_langtrace_attributes
-            (version, service_provider, vendor_type="llm"),
+            **get_langtrace_attributes(version, service_provider, vendor_type="llm"),
             **get_llm_request_attributes(kwargs),
             **get_llm_url(instance),
             SpanAttributes.LLM_PATH: APIS["CONVERSE_STREAM"]["ENDPOINT"],
@@ -129,29 +143,87 @@ def traced_method(wrapped, instance, args, kwargs):
             span.record_exception(err)
             span.set_status(Status(StatusCode.ERROR, str(err)))
             raise err
-
+
     return traced_method
 
 
 @silently_fail
 def _set_response_attributes(span, kwargs, result):
-    set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, kwargs.get('modelId'))
-    set_span_attribute(span, SpanAttributes.LLM_TOP_K, kwargs.get('additionalModelRequestFields', {}).get('top_k'))
-    content = result.get('output', {}).get('message', {}).get('content', [])
+    set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, kwargs.get("modelId"))
+    set_span_attribute(
+        span,
+        SpanAttributes.LLM_TOP_K,
+        kwargs.get("additionalModelRequestFields", {}).get("top_k"),
+    )
+    content = result.get("output", {}).get("message", {}).get("content", [])
     if len(content) > 0:
-        role = result.get('output', {}).get('message', {}).get('role', "assistant")
-        responses = [
-            {"role": role, "content": c.get('text', "")}
-            for c in content
-        ]
+        role = result.get("output", {}).get("message", {}).get("role", "assistant")
= result.get("output", {}).get("message", {}).get("role", "assistant") + responses = [{"role": role, "content": c.get("text", "")} for c in content] set_event_completion(span, responses) - if 'usage' in result: + if "usage" in result: set_span_attributes( span, { - SpanAttributes.LLM_USAGE_COMPLETION_TOKENS: result['usage'].get('outputTokens'), - SpanAttributes.LLM_USAGE_PROMPT_TOKENS: result['usage'].get('inputTokens'), - SpanAttributes.LLM_USAGE_TOTAL_TOKENS: result['usage'].get('totalTokens'), - } + SpanAttributes.LLM_USAGE_COMPLETION_TOKENS: result["usage"].get( + "outputTokens" + ), + SpanAttributes.LLM_USAGE_PROMPT_TOKENS: result["usage"].get( + "inputTokens" + ), + SpanAttributes.LLM_USAGE_TOTAL_TOKENS: result["usage"].get( + "totalTokens" + ), + }, + ) + + +def patch_aws_bedrock(tracer, version): + def traced_method(wrapped, instance, args, kwargs): + if args and args[0] != "bedrock-runtime": + return + + client = wrapped(*args, **kwargs) + print("Here's the client: ", client) + client.invoke_model = patch_invoke_model(client.invoke_model, tracer, version) + client.invoke_model_with_response_stream = patch_invoke_model( + client.invoke_model_with_response_stream, tracer, version + ) + client.converse = patch_invoke_model(client.converse, tracer, version) + client.converse_stream = patch_invoke_model( + client.converse_stream, tracer, version ) + return client + + return traced_method + + +def patch_invoke_model(original_method, tracer, version): + def traced_method(*args, **kwargs): + service_provider = SERVICE_PROVIDERS["AWS_BEDROCK"] + span_attributes = { + **get_langtrace_attributes( + version, service_provider, vendor_type="framework" + ), + **get_extra_attributes(), + } + with tracer.start_as_current_span( + name=get_span_name("aws_bedrock.invoke_model"), + kind=SpanKind.CLIENT, + context=set_span_in_context(trace.get_current_span()), + ) as span: + set_span_attributes(span, span_attributes) + set_invoke_model_attributes(span, kwargs) + response = original_method(*args, **kwargs) + return response + + return traced_method + + +def set_invoke_model_attributes(span, kwargs): + modelId = kwargs.get("modelId") + (vendor, model_name) = modelId.split(".") + + print("Here's the vendor: ", vendor) + print("Here's the model_name: ", model_name) + print("Here's the kwargs: ", kwargs) diff --git a/src/run_example.py b/src/run_example.py index ab925665..fcdde43a 100644 --- a/src/run_example.py +++ b/src/run_example.py @@ -4,7 +4,7 @@ "anthropic": False, "azureopenai": False, "chroma": False, - "cohere": True, + "cohere": False, "fastapi": False, "langchain": False, "llamaindex": False, @@ -20,7 +20,7 @@ "vertexai": False, "gemini": False, "mistral": False, - "awsbedrock": False, + "awsbedrock": True, "cerebras": False, } From a0eb08b1c22c38145a989ce23c2d4c100708dc9f Mon Sep 17 00:00:00 2001 From: Ali Waleed Date: Mon, 13 Jan 2025 14:45:06 +0200 Subject: [PATCH 2/2] Add support for bedrock --- src/examples/awsbedrock_examples/__init__.py | 18 +- src/examples/awsbedrock_examples/converse.py | 115 ++++- .../aws_bedrock/bedrock_streaming_wrapper.py | 43 ++ .../aws_bedrock/instrumentation.py | 7 +- .../instrumentation/aws_bedrock/patch.py | 409 ++++++++++++------ .../aws_bedrock/stream_body_wrapper.py | 41 ++ src/langtrace_python_sdk/version.py | 2 +- 7 files changed, 486 insertions(+), 149 deletions(-) create mode 100644 src/langtrace_python_sdk/instrumentation/aws_bedrock/bedrock_streaming_wrapper.py create mode 100644 
src/langtrace_python_sdk/instrumentation/aws_bedrock/stream_body_wrapper.py diff --git a/src/examples/awsbedrock_examples/__init__.py b/src/examples/awsbedrock_examples/__init__.py index f5eb9229..8fa9eb6e 100644 --- a/src/examples/awsbedrock_examples/__init__.py +++ b/src/examples/awsbedrock_examples/__init__.py @@ -1,11 +1,19 @@ -from examples.awsbedrock_examples.converse import use_converse, use_invoke_model_titan +from examples.awsbedrock_examples.converse import ( + use_converse_stream, + use_converse, + use_invoke_model_anthropic, + use_invoke_model_cohere, + use_invoke_model_amazon, +) from langtrace_python_sdk import langtrace, with_langtrace_root_span -langtrace.init() - class AWSBedrockRunner: @with_langtrace_root_span("AWS_Bedrock") def run(self): - # use_converse() - use_invoke_model_titan() + + use_converse_stream() + use_converse() + use_invoke_model_anthropic() + use_invoke_model_cohere() + use_invoke_model_amazon() diff --git a/src/examples/awsbedrock_examples/converse.py b/src/examples/awsbedrock_examples/converse.py index 4b9f95a2..1401636e 100644 --- a/src/examples/awsbedrock_examples/converse.py +++ b/src/examples/awsbedrock_examples/converse.py @@ -1,23 +1,46 @@ import boto3 -import botocore import json from langtrace_python_sdk import langtrace from dotenv import load_dotenv - +import botocore load_dotenv() -langtrace.init() +langtrace.init(write_spans_to_console=False) brt = boto3.client("bedrock-runtime", region_name="us-east-1") brc = boto3.client("bedrock", region_name="us-east-1") +def use_converse_stream(): + model_id = "anthropic.claude-3-haiku-20240307-v1:0" + conversation = [ + { + "role": "user", + "content": [{"text": "what is the capital of France?"}], + } + ] + + try: + response = brt.converse_stream( + modelId=model_id, + messages=conversation, + inferenceConfig={"maxTokens": 4096, "temperature": 0}, + additionalModelRequestFields={"top_k": 250}, + ) + # response_text = response["output"]["message"]["content"][0]["text"] + print(response) + + except Exception as e: + print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}") + exit(1) + + def use_converse(): model_id = "anthropic.claude-3-haiku-20240307-v1:0" conversation = [ { "role": "user", - "content": [{"text": "Write a story about a magic backpack."}], + "content": [{"text": "what is the capital of France?"}], } ] @@ -37,17 +60,15 @@ def use_converse(): def get_foundation_models(): - models = [] for model in brc.list_foundation_models()["modelSummaries"]: - models.append(model["modelId"]) - return models + print(model["modelId"]) # Invoke Model API # Amazon Titan Models -def use_invoke_model_titan(): +def use_invoke_model_titan(stream=False): try: - prompt_data = "what's 1+1?" + prompt_data = "what's the capital of France?" 
body = json.dumps( { "inputText": prompt_data, @@ -62,12 +83,16 @@ def use_invoke_model_titan(): accept = "application/json" contentType = "application/json" - response = brt.invoke_model( - body=body, modelId=modelId, accept=accept, contentType=contentType - ) - response_body = json.loads(response.get("body").read()) + if stream: - # print(response_body.get("results")) + response = brt.invoke_model_with_response_stream( + body=body, modelId=modelId, accept=accept, contentType=contentType + ) + else: + response = brt.invoke_model( + body=body, modelId=modelId, accept=accept, contentType=contentType + ) + response_body = json.loads(response.get("body").read()) except botocore.exceptions.ClientError as error: @@ -84,8 +109,66 @@ def use_invoke_model_titan(): # Anthropic Models -def use_invoke_model_anthropic(): - pass +def use_invoke_model_anthropic(stream=False): + body = json.dumps( + { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": 1024, + "temperature": 0.1, + "top_p": 0.9, + "messages": [{"role": "user", "content": "Hello, Claude"}], + } + ) + modelId = "anthropic.claude-v2" + accept = "application/json" + contentType = "application/json" + + if stream: + response = brt.invoke_model_with_response_stream(body=body, modelId=modelId) + stream_response = response.get("body") + if stream_response: + for event in stream_response: + chunk = event.get("chunk") + if chunk: + print(json.loads(chunk.get("bytes").decode())) + + else: + response = brt.invoke_model( + body=body, modelId=modelId, accept=accept, contentType=contentType + ) + response_body = json.loads(response.get("body").read()) + # text + print(response_body.get("completion")) + + +def use_invoke_model_llama(): + model_id = "meta.llama3-8b-instruct-v1:0" + prompt = "What is the capital of France?" + max_gen_len = 128 + temperature = 0.1 + top_p = 0.9 + + # Create request body. + body = json.dumps( + { + "prompt": prompt, + "max_gen_len": max_gen_len, + "temperature": temperature, + "top_p": top_p, + } + ) + response = brt.invoke_model(body=body, modelId=model_id) + + response_body = json.loads(response.get("body").read()) + + return response_body # print(get_foundation_models()) +def use_invoke_model_cohere(): + model_id = "cohere.command-r-plus-v1" + prompt = "What is the capital of France?" 
+ body = json.dumps({"prompt": prompt, "max_tokens": 1024, "temperature": 0.1}) + response = brt.invoke_model(body=body, modelId=model_id) + response_body = json.loads(response.get("body").read()) + print(response_body) diff --git a/src/langtrace_python_sdk/instrumentation/aws_bedrock/bedrock_streaming_wrapper.py b/src/langtrace_python_sdk/instrumentation/aws_bedrock/bedrock_streaming_wrapper.py new file mode 100644 index 00000000..2792b023 --- /dev/null +++ b/src/langtrace_python_sdk/instrumentation/aws_bedrock/bedrock_streaming_wrapper.py @@ -0,0 +1,43 @@ +import json +from wrapt import ObjectProxy + + +class StreamingWrapper(ObjectProxy): + def __init__( + self, + response, + stream_done_callback=None, + ): + super().__init__(response) + + self._stream_done_callback = stream_done_callback + self._accumulating_body = {} + + def __iter__(self): + for event in self.__wrapped__: + self._process_event(event) + yield event + + def _process_event(self, event): + chunk = event.get("chunk") + if not chunk: + return + + decoded_chunk = json.loads(chunk.get("bytes").decode()) + type = decoded_chunk.get("type") + + if type == "message_start": + self._accumulating_body = decoded_chunk.get("message") + elif type == "content_block_start": + self._accumulating_body["content"].append( + decoded_chunk.get("content_block") + ) + elif type == "content_block_delta": + self._accumulating_body["content"][-1]["text"] += decoded_chunk.get( + "delta" + ).get("text") + elif type == "message_stop" and self._stream_done_callback: + self._accumulating_body["invocation_metrics"] = decoded_chunk.get( + "amazon-bedrock-invocationMetrics" + ) + self._stream_done_callback(self._accumulating_body) diff --git a/src/langtrace_python_sdk/instrumentation/aws_bedrock/instrumentation.py b/src/langtrace_python_sdk/instrumentation/aws_bedrock/instrumentation.py index 53aaaba3..7ae204e0 100644 --- a/src/langtrace_python_sdk/instrumentation/aws_bedrock/instrumentation.py +++ b/src/langtrace_python_sdk/instrumentation/aws_bedrock/instrumentation.py @@ -22,12 +22,7 @@ from opentelemetry.trace import get_tracer from wrapt import wrap_function_wrapper as _W -from langtrace_python_sdk.instrumentation.aws_bedrock.patch import ( - converse, - invoke_model, - converse_stream, - patch_aws_bedrock, -) +from langtrace_python_sdk.instrumentation.aws_bedrock.patch import patch_aws_bedrock logging.basicConfig(level=logging.FATAL) diff --git a/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py b/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py index 397f3f82..576df737 100644 --- a/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py +++ b/src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py @@ -15,8 +15,12 @@ """ import json -from functools import wraps +from langtrace_python_sdk.instrumentation.aws_bedrock.bedrock_streaming_wrapper import ( + StreamingWrapper, +) +from .stream_body_wrapper import BufferedStreamBody +from functools import wraps from langtrace.trace_attributes import ( LLMSpanAttributes, SpanAttributes, @@ -39,81 +43,10 @@ get_span_name, set_event_completion, set_span_attributes, + set_usage_attributes, ) -def traced_aws_bedrock_call(api_name: str, operation_name: str): - def decorator(method_name: str, version: str, tracer): - def wrapper(original_method): - @wraps(original_method) - def wrapped_method(*args, **kwargs): - service_provider = SERVICE_PROVIDERS["AWS_BEDROCK"] - print("Here's the kwargs: ", kwargs) - input_content = [ - { - "role": message.get("role", "user"), - "content": 
message.get("content", [])[0].get("text", ""), - } - for message in kwargs.get("messages", []) - ] - - span_attributes = { - **get_langtrace_attributes( - version, service_provider, vendor_type="framework" - ), - **get_llm_request_attributes( - kwargs, operation_name=operation_name, prompts=input_content - ), - **get_llm_url(args[0] if args else None), - SpanAttributes.LLM_PATH: APIS[api_name]["ENDPOINT"], - **get_extra_attributes(), - } - - if api_name == "CONVERSE": - span_attributes.update( - { - SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("modelId"), - SpanAttributes.LLM_REQUEST_MAX_TOKENS: kwargs.get( - "inferenceConfig", {} - ).get("maxTokens"), - SpanAttributes.LLM_REQUEST_TEMPERATURE: kwargs.get( - "inferenceConfig", {} - ).get("temperature"), - SpanAttributes.LLM_REQUEST_TOP_P: kwargs.get( - "inferenceConfig", {} - ).get("top_p"), - } - ) - - attributes = LLMSpanAttributes(**span_attributes) - - with tracer.start_as_current_span( - name=get_span_name(APIS[api_name]["METHOD"]), - kind=SpanKind.CLIENT, - context=set_span_in_context(trace.get_current_span()), - ) as span: - set_span_attributes(span, attributes) - try: - result = original_method(*args, **kwargs) - _set_response_attributes(span, kwargs, result) - span.set_status(StatusCode.OK) - return result - except Exception as err: - span.record_exception(err) - span.set_status(Status(StatusCode.ERROR, str(err))) - raise err - - return wrapped_method - - return wrapper - - return decorator - - -converse = traced_aws_bedrock_call("CONVERSE", "converse") -invoke_model = traced_aws_bedrock_call("INVOKE_MODEL", "invoke_model") - - def converse_stream(original_method, version, tracer): def traced_method(wrapped, instance, args, kwargs): service_provider = SERVICE_PROVIDERS["AWS_BEDROCK"] @@ -147,64 +80,96 @@ def traced_method(wrapped, instance, args, kwargs): return traced_method -@silently_fail -def _set_response_attributes(span, kwargs, result): - set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, kwargs.get("modelId")) - set_span_attribute( - span, - SpanAttributes.LLM_TOP_K, - kwargs.get("additionalModelRequestFields", {}).get("top_k"), - ) - content = result.get("output", {}).get("message", {}).get("content", []) - if len(content) > 0: - role = result.get("output", {}).get("message", {}).get("role", "assistant") - responses = [{"role": role, "content": c.get("text", "")} for c in content] - set_event_completion(span, responses) - - if "usage" in result: - set_span_attributes( - span, - { - SpanAttributes.LLM_USAGE_COMPLETION_TOKENS: result["usage"].get( - "outputTokens" - ), - SpanAttributes.LLM_USAGE_PROMPT_TOKENS: result["usage"].get( - "inputTokens" - ), - SpanAttributes.LLM_USAGE_TOTAL_TOKENS: result["usage"].get( - "totalTokens" - ), - }, - ) - - def patch_aws_bedrock(tracer, version): def traced_method(wrapped, instance, args, kwargs): if args and args[0] != "bedrock-runtime": - return + return wrapped(*args, **kwargs) client = wrapped(*args, **kwargs) - print("Here's the client: ", client) client.invoke_model = patch_invoke_model(client.invoke_model, tracer, version) - client.invoke_model_with_response_stream = patch_invoke_model( - client.invoke_model_with_response_stream, tracer, version - ) - client.converse = patch_invoke_model(client.converse, tracer, version) - client.converse_stream = patch_invoke_model( + + client.converse = patch_converse(client.converse, tracer, version) + client.converse_stream = patch_converse_stream( client.converse_stream, tracer, version ) + return client return traced_method +def 
patch_converse_stream(original_method, tracer, version):
+    def traced_method(*args, **kwargs):
+        modelId = kwargs.get("modelId")
+        (vendor, _) = modelId.split(".")
+        input_content = [
+            {
+                "role": message.get("role", "user"),
+                "content": message.get("content", [])[0].get("text", ""),
+            }
+            for message in kwargs.get("messages", [])
+        ]
+
+        span_attributes = {
+            **get_langtrace_attributes(version, vendor, vendor_type="framework"),
+            **get_llm_request_attributes(kwargs, model=modelId, prompts=input_content),
+            **get_llm_url(args[0] if args else None),
+            **get_extra_attributes(),
+        }
+        with tracer.start_as_current_span(
+            name=get_span_name("aws_bedrock.converse_stream"),
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        ) as span:
+            set_span_attributes(span, span_attributes)
+            response = original_method(*args, **kwargs)
+
+            if span.is_recording():
+                set_span_streaming_response(span, response)
+            return response
+
+    return traced_method
+
+
+def patch_converse(original_method, tracer, version):
+    def traced_method(*args, **kwargs):
+        modelId = kwargs.get("modelId")
+        (vendor, _) = modelId.split(".")
+        input_content = [
+            {
+                "role": message.get("role", "user"),
+                "content": message.get("content", [])[0].get("text", ""),
+            }
+            for message in kwargs.get("messages", [])
+        ]
+
+        span_attributes = {
+            **get_langtrace_attributes(version, vendor, vendor_type="framework"),
+            **get_llm_request_attributes(kwargs, model=modelId, prompts=input_content),
+            **get_llm_url(args[0] if args else None),
+            **get_extra_attributes(),
+        }
+        with tracer.start_as_current_span(
+            name=get_span_name("aws_bedrock.converse"),
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        ) as span:
+            set_span_attributes(span, span_attributes)
+            response = original_method(*args, **kwargs)
+
+            if span.is_recording():
+                _set_response_attributes(span, kwargs, response)
+            return response
+
+    return traced_method
+
+
 def patch_invoke_model(original_method, tracer, version):
     def traced_method(*args, **kwargs):
-        service_provider = SERVICE_PROVIDERS["AWS_BEDROCK"]
+        modelId = kwargs.get("modelId")
+        (vendor, _) = modelId.split(".")
         span_attributes = {
-            **get_langtrace_attributes(
-                version, service_provider, vendor_type="framework"
-            ),
+            **get_langtrace_attributes(version, vendor, vendor_type="framework"),
             **get_extra_attributes(),
         }
         with tracer.start_as_current_span(
@@ -213,17 +178,219 @@ def traced_method(*args, **kwargs):
             context=set_span_in_context(trace.get_current_span()),
         ) as span:
             set_span_attributes(span, span_attributes)
-            set_invoke_model_attributes(span, kwargs)
             response = original_method(*args, **kwargs)
+            if span.is_recording():
+                handle_call(span, kwargs, response)
             return response
 
     return traced_method
 
 
-def set_invoke_model_attributes(span, kwargs):
+def handle_call(span, kwargs, response):
     modelId = kwargs.get("modelId")
     (vendor, model_name) = modelId.split(".")
-
-    print("Here's the vendor: ", vendor)
-    print("Here's the model_name: ", model_name)
-    print("Here's the kwargs: ", kwargs)
+    response["body"] = BufferedStreamBody(
+        response["body"]._raw_stream, response["body"]._content_length
+    )
+    request_body = json.loads(kwargs.get("body"))
+    response_body = json.loads(response.get("body").read())
+
+    set_span_attribute(span, SpanAttributes.LLM_SYSTEM, vendor)
+    set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, modelId)
+    set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, modelId)
+
+    if vendor == "amazon":
+        set_amazon_attributes(span, request_body, response_body)
+
+    if vendor == "anthropic":
+        if "prompt" in request_body:
+            set_anthropic_completions_attributes(span, request_body, response_body)
+        elif "messages" in request_body:
+            set_anthropic_messages_attributes(span, request_body, response_body)
+
+    if vendor == "meta":
+        set_llama_meta_attributes(span, request_body, response_body)
+
+
+def set_llama_meta_attributes(span, request_body, response_body):
+    set_span_attribute(
+        span, SpanAttributes.LLM_REQUEST_TOP_P, request_body.get("top_p")
+    )
+    set_span_attribute(
+        span, SpanAttributes.LLM_REQUEST_TEMPERATURE, request_body.get("temperature")
+    )
+    set_span_attribute(
+        span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, request_body.get("max_gen_len")
+    )
+
+    set_usage_attributes(
+        span,
+        {
+            "input_tokens": response_body.get("prompt_token_count"),
+            "output_tokens": response_body.get("generation_token_count"),
+        },
+    )
+
+    prompts = [
+        {
+            "role": "user",
+            "content": request_body.get("prompt"),
+        }
+    ]
+
+    completions = [
+        {
+            "role": "assistant",
+            "content": response_body.get("generation"),
+        }
+    ]
+    set_span_attribute(span, SpanAttributes.LLM_PROMPTS, json.dumps(prompts))
+    set_event_completion(span, completions)
+
+
+def set_amazon_attributes(span, request_body, response_body):
+    config = request_body.get("textGenerationConfig", {})
+    prompts = [
+        {
+            "role": "user",
+            "content": request_body.get("inputText"),
+        }
+    ]
+    completions = [
+        {
+            "role": "assistant",
+            "content": result.get("outputText"),
+        }
+        for result in response_body.get("results")
+    ]
+    set_span_attribute(
+        span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, config.get("maxTokenCount")
+    )
+    set_span_attribute(
+        span, SpanAttributes.LLM_REQUEST_TEMPERATURE, config.get("temperature")
+    )
+    set_span_attribute(span, SpanAttributes.LLM_REQUEST_TOP_P, config.get("topP"))
+    set_span_attribute(span, SpanAttributes.LLM_PROMPTS, json.dumps(prompts))
+    set_usage_attributes(
+        span,
+        {
+            "input_tokens": response_body.get("inputTextTokenCount"),
+            "output_tokens": sum(
+                int(result.get("tokenCount")) for result in response_body.get("results")
+            ),
+        },
+    )
+    set_event_completion(span, completions)
+
+
+def set_anthropic_completions_attributes(span, request_body, response_body):
+    set_span_attribute(
+        span,
+        SpanAttributes.LLM_REQUEST_MAX_TOKENS,
+        request_body.get("max_tokens_to_sample"),
+    )
+    set_span_attribute(
+        span,
+        SpanAttributes.LLM_REQUEST_TEMPERATURE,
+        str(request_body.get("temperature")),
+    )
+    set_span_attribute(
+        span,
+        SpanAttributes.LLM_REQUEST_TOP_P,
+        str(request_body.get("top_p")),
+    )
+    prompts = [
+        {
+            "role": "user",
+            "content": request_body.get("prompt"),
+        }
+    ]
+    completions = [
+        {
+            "role": "assistant",
+            "content": response_body.get("completion"),
+        }
+    ]
+    set_span_attribute(span, SpanAttributes.LLM_PROMPTS, json.dumps(prompts))
+    set_event_completion(span, completions)
+
+
+def set_anthropic_messages_attributes(span, request_body, response_body):
+    set_span_attribute(
+        span,
+        SpanAttributes.LLM_REQUEST_MAX_TOKENS,
+        request_body.get("max_tokens"),
+    )
+    set_span_attribute(
+        span,
+        SpanAttributes.LLM_REQUEST_TEMPERATURE,
+        str(request_body.get("temperature")),
+    )
+    set_span_attribute(
+        span,
+        SpanAttributes.LLM_REQUEST_TOP_P,
+        str(request_body.get("top_p")),
+    )
+    set_span_attribute(
+        span, SpanAttributes.LLM_PROMPTS, json.dumps(request_body.get("messages"))
+    )
+    set_event_completion(span, response_body.get("content"))
+    set_usage_attributes(span, response_body.get("usage"))
+
+
+@silently_fail +def _set_response_attributes(span, kwargs, result): + set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, kwargs.get("modelId")) + set_span_attribute( + span, + SpanAttributes.LLM_TOP_K, + kwargs.get("additionalModelRequestFields", {}).get("top_k"), + ) + content = result.get("output", {}).get("message", {}).get("content", []) + if len(content) > 0: + role = result.get("output", {}).get("message", {}).get("role", "assistant") + responses = [{"role": role, "content": c.get("text", "")} for c in content] + set_event_completion(span, responses) + + if "usage" in result: + set_span_attributes( + span, + { + SpanAttributes.LLM_USAGE_COMPLETION_TOKENS: result["usage"].get( + "outputTokens" + ), + SpanAttributes.LLM_USAGE_PROMPT_TOKENS: result["usage"].get( + "inputTokens" + ), + SpanAttributes.LLM_USAGE_TOTAL_TOKENS: result["usage"].get( + "totalTokens" + ), + }, + ) + + +def set_span_streaming_response(span, response): + streaming_response = "" + role = None + for event in response["stream"]: + if "messageStart" in event: + role = event["messageStart"]["role"] + elif "contentBlockDelta" in event: + delta = event["contentBlockDelta"]["delta"] + if "text" in delta: + streaming_response += delta["text"] + elif "metadata" in event and "usage" in event["metadata"]: + usage = event["metadata"]["usage"] + set_usage_attributes( + span, + { + "input_tokens": usage.get("inputTokens"), + "output_tokens": usage.get("outputTokens"), + }, + ) + + if streaming_response: + set_event_completion( + span, [{"role": role or "assistant", "content": streaming_response}] + ) diff --git a/src/langtrace_python_sdk/instrumentation/aws_bedrock/stream_body_wrapper.py b/src/langtrace_python_sdk/instrumentation/aws_bedrock/stream_body_wrapper.py new file mode 100644 index 00000000..01fc4403 --- /dev/null +++ b/src/langtrace_python_sdk/instrumentation/aws_bedrock/stream_body_wrapper.py @@ -0,0 +1,41 @@ +from botocore.response import StreamingBody +from botocore.exceptions import ( + ReadTimeoutError, + ResponseStreamingError, +) +from urllib3.exceptions import ProtocolError as URLLib3ProtocolError +from urllib3.exceptions import ReadTimeoutError as URLLib3ReadTimeoutError + + +class BufferedStreamBody(StreamingBody): + def __init__(self, raw_stream, content_length): + super().__init__(raw_stream, content_length) + self._buffer = None + self._buffer_cursor = 0 + + def read(self, amt=None): + """Read at most amt bytes from the stream. + + If the amt argument is omitted, read all data. + """ + if self._buffer is None: + try: + self._buffer = self._raw_stream.read() + except URLLib3ReadTimeoutError as e: + # TODO: the url will be None as urllib3 isn't setting it yet + raise ReadTimeoutError(endpoint_url=e.url, error=e) + except URLLib3ProtocolError as e: + raise ResponseStreamingError(error=e) + + self._amount_read += len(self._buffer) + if amt is None or (not self._buffer and amt > 0): + # If the server sends empty contents or + # we ask to read all of the contents, then we know + # we need to verify the content length. + self._verify_content_length() + + if amt is None: + return self._buffer[self._buffer_cursor :] + else: + self._buffer_cursor += amt + return self._buffer[self._buffer_cursor - amt : self._buffer_cursor] diff --git a/src/langtrace_python_sdk/version.py b/src/langtrace_python_sdk/version.py index bf476885..4e499ef0 100644 --- a/src/langtrace_python_sdk/version.py +++ b/src/langtrace_python_sdk/version.py @@ -1 +1 @@ -__version__ = "3.3.21" +__version__ = "3.3.22"
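
Reviewer note (not part of the patch): the mechanism above -- wrapping the
boto3.client factory itself rather than any Bedrock class -- is easy to lose
in the diff noise. The following standalone sketch shows the same idea under
minimal assumptions (only boto3 and wrapt installed); the helper names
_patch_client_factory and _wrap_invoke_model are invented for illustration,
and io.BytesIO stands in for the PR's BufferedStreamBody.

import io
import json

import boto3
from wrapt import wrap_function_wrapper


def _wrap_invoke_model(original):
    """Return a drop-in replacement for client.invoke_model."""

    def wrapper(*args, **kwargs):
        response = original(*args, **kwargs)
        # botocore's StreamingBody can only be read once: if the tracer
        # consumed it, the caller would get nothing back. So read it fully,
        # inspect the copy, and hand the caller a fresh readable object --
        # the same problem BufferedStreamBody solves in this PR.
        raw = response["body"].read()
        print("traced invoke_model:", kwargs.get("modelId"), json.loads(raw))
        response["body"] = io.BytesIO(raw)
        return response

    return wrapper


def _patch_client_factory(wrapped, instance, args, kwargs):
    # wrapt hands us the original boto3.client plus the call's arguments;
    # only bedrock-runtime clients get their methods swapped out.
    client = wrapped(*args, **kwargs)
    if args and args[0] == "bedrock-runtime":
        client.invoke_model = _wrap_invoke_model(client.invoke_model)
    return client


wrap_function_wrapper("boto3", "client", _patch_client_factory)

# Any client created afterwards is instrumented transparently (running this
# needs AWS credentials):
# brt = boto3.client("bedrock-runtime", region_name="us-east-1")
# out = brt.invoke_model(
#     body=json.dumps({"inputText": "what's 1+1?"}),
#     modelId="amazon.titan-text-express-v1",
# )
# print(json.loads(out["body"].read()))  # the caller still sees the body

Patching the factory, instead of monkey-patching botocore classes, means the
instrumentation only ever touches clients it recognizes -- which is also why
patch_aws_bedrock must fall through to wrapped(*args, **kwargs) for every
other service, the bug the second commit fixes.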
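Reviewer note (not part of the patch): bedrock_streaming_wrapper.py leans on
wrapt.ObjectProxy so the instrumented event stream still behaves exactly like
the botocore original while a copy of the response is accumulated on the
side. Below is a reduced sketch of that pattern, assuming the Anthropic
messages stream format used above (the invocation-metrics handling is
omitted, and the class name AccumulatingStream is invented for the example):

import json

from wrapt import ObjectProxy


class AccumulatingStream(ObjectProxy):
    def __init__(self, stream, on_done):
        super().__init__(stream)
        # wrapt convention: "_self_"-prefixed attributes are stored on the
        # proxy itself instead of being forwarded to the wrapped stream.
        self._self_on_done = on_done
        self._self_message = {}

    def __iter__(self):
        for event in self.__wrapped__:
            chunk = event.get("chunk")
            if chunk:
                decoded = json.loads(chunk["bytes"].decode())
                kind = decoded.get("type")
                if kind == "message_start":
                    self._self_message = decoded["message"]
                elif kind == "content_block_start":
                    self._self_message["content"].append(decoded["content_block"])
                elif kind == "content_block_delta":
                    self._self_message["content"][-1]["text"] += decoded["delta"]["text"]
                elif kind == "message_stop":
                    self._self_on_done(self._self_message)
            yield event  # the caller's iteration is untouched

# Hypothetical usage with an invoke_model_with_response_stream response:
# response["body"] = AccumulatingStream(response["body"], on_done=print)

Because every event is re-yielded unmodified, the wrapper can be dropped into
the response dict without the SDK user noticing, and the callback only fires
once the model signals message_stop.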