17 changes: 8 additions & 9 deletions src/examples/awsbedrock_examples/__init__.py
@@ -1,9 +1,7 @@
from examples.awsbedrock_examples.converse import (
    use_converse_stream,
    use_converse,
    use_invoke_model_anthropic,
    use_invoke_model_cohere,
    use_invoke_model_amazon,
    use_invoke_model_titan,
    use_invoke_model_llama,
)
from langtrace_python_sdk import langtrace, with_langtrace_root_span

@@ -12,8 +10,9 @@ class AWSBedrockRunner:
    @with_langtrace_root_span("AWS_Bedrock")
    def run(self):

        use_converse_stream()
        use_converse()
        use_invoke_model_anthropic()
        use_invoke_model_cohere()
        use_invoke_model_amazon()
        # use_converse_stream()
        # use_converse()
        # use_invoke_model_anthropic(stream=True)
        # use_invoke_model_cohere()
        # use_invoke_model_llama(stream=False)
        use_invoke_model_titan(stream=False)
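For orientation, here is a minimal sketch of how this example module is typically driven. It assumes the usual langtrace setup (langtrace.init() reading LANGTRACE_API_KEY from the environment); that setup is an assumption and is not part of this diff:

# Hypothetical driver for the Bedrock examples (not part of this PR).
# langtrace.init() must run before the boto3 client is created so that
# the Bedrock instrumentation patches the client methods.
from langtrace_python_sdk import langtrace
from examples.awsbedrock_examples import AWSBedrockRunner

langtrace.init()          # assumes LANGTRACE_API_KEY is set in the environment
AWSBedrockRunner().run()  # runs the enabled calls under an "AWS_Bedrock" root span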
26 changes: 20 additions & 6 deletions src/examples/awsbedrock_examples/converse.py
@@ -88,6 +88,12 @@ def use_invoke_model_titan(stream=False):
        response = brt.invoke_model_with_response_stream(
            body=body, modelId=modelId, accept=accept, contentType=contentType
        )
        # Extract and print the response text in real-time.
        for event in response["body"]:
            chunk = json.loads(event["chunk"]["bytes"])
            if "outputText" in chunk:
                print(chunk["outputText"], end="")

    else:
        response = brt.invoke_model(
            body=body, modelId=modelId, accept=accept, contentType=contentType
Expand Down Expand Up @@ -130,7 +136,8 @@ def use_invoke_model_anthropic(stream=False):
        for event in stream_response:
            chunk = event.get("chunk")
            if chunk:
                print(json.loads(chunk.get("bytes").decode()))
                # print(json.loads(chunk.get("bytes").decode()))
                pass

    else:
        response = brt.invoke_model(
@@ -141,7 +148,7 @@
        print(response_body.get("completion"))


def use_invoke_model_llama():
def use_invoke_model_llama(stream=False):
    model_id = "meta.llama3-8b-instruct-v1:0"
    prompt = "What is the capital of France?"
    max_gen_len = 128
@@ -157,11 +164,18 @@ def use_invoke_model_llama():
            "top_p": top_p,
        }
    )
    response = brt.invoke_model(body=body, modelId=model_id)

    response_body = json.loads(response.get("body").read())

    return response_body
    if stream:
        response = brt.invoke_model_with_response_stream(body=body, modelId=model_id)
        for event in response["body"]:
            chunk = json.loads(event["chunk"]["bytes"])
            if "generation" in chunk:
                # print(chunk["generation"], end="")
                pass
    else:
        response = brt.invoke_model(body=body, modelId=model_id)
        response_body = json.loads(response.get("body").read())
        return response_body


# print(get_foundation_models())
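For reference, a standalone sketch of the Titan streaming call these examples exercise. The region and model ID are illustrative placeholders; the chunk fields ("outputText", token counts) mirror the ones handled above:

import json
import boto3

brt = boto3.client("bedrock-runtime", region_name="us-east-1")  # region is an assumption
body = json.dumps(
    {
        "inputText": "What is the capital of France?",
        "textGenerationConfig": {"maxTokenCount": 128, "temperature": 0.1, "topP": 0.9},
    }
)
response = brt.invoke_model_with_response_stream(
    body=body,
    modelId="amazon.titan-text-express-v1",
    accept="application/json",
    contentType="application/json",
)
for event in response["body"]:
    chunk = json.loads(event["chunk"]["bytes"])
    # Partial text arrives under "outputText"; the final chunk also carries
    # the token counts that the instrumentation records.
    if "outputText" in chunk:
        print(chunk["outputText"], end="")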
169 changes: 150 additions & 19 deletions src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py
@@ -16,9 +16,7 @@

import json

from langtrace_python_sdk.instrumentation.aws_bedrock.bedrock_streaming_wrapper import (
    StreamingWrapper,
)
from wrapt import ObjectProxy
from .stream_body_wrapper import BufferedStreamBody
from functools import wraps
from langtrace.trace_attributes import (
@@ -87,6 +85,11 @@ def traced_method(wrapped, instance, args, kwargs):

        client = wrapped(*args, **kwargs)
        client.invoke_model = patch_invoke_model(client.invoke_model, tracer, version)
        client.invoke_model_with_response_stream = (
            patch_invoke_model_with_response_stream(
                client.invoke_model_with_response_stream, tracer, version
            )
        )

        client.converse = patch_converse(client.converse, tracer, version)
        client.converse_stream = patch_converse_stream(
@@ -186,6 +189,56 @@ def traced_method(*args, **kwargs):
    return traced_method


def patch_invoke_model_with_response_stream(original_method, tracer, version):
    @wraps(original_method)
    def traced_method(*args, **kwargs):
        modelId = kwargs.get("modelId")
        (vendor, _) = modelId.split(".")
        span_attributes = {
            **get_langtrace_attributes(version, vendor, vendor_type="framework"),
            **get_extra_attributes(),
        }
        span = tracer.start_span(
            name=get_span_name("aws_bedrock.invoke_model_with_response_stream"),
            kind=SpanKind.CLIENT,
            context=set_span_in_context(trace.get_current_span()),
        )
        set_span_attributes(span, span_attributes)
        response = original_method(*args, **kwargs)
        if span.is_recording():
            handle_streaming_call(span, kwargs, response)
        return response

    return traced_method


def handle_streaming_call(span, kwargs, response):

    def stream_finished(response_body):
        request_body = json.loads(kwargs.get("body"))

        (vendor, model) = kwargs.get("modelId").split(".")

        set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, model)
        set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model)

        if vendor == "amazon":
            set_amazon_attributes(span, request_body, response_body)

        if vendor == "anthropic":
            if "prompt" in request_body:
                set_anthropic_completions_attributes(span, request_body, response_body)
            elif "messages" in request_body:
                set_anthropic_messages_attributes(span, request_body, response_body)

        if vendor == "meta":
            set_llama_meta_attributes(span, request_body, response_body)

        span.end()

    response["body"] = StreamingBedrockWrapper(response["body"], stream_finished)


def handle_call(span, kwargs, response):
    modelId = kwargs.get("modelId")
    (vendor, model_name) = modelId.split(".")
@@ -195,7 +248,6 @@ def handle_call(span, kwargs, response):
    request_body = json.loads(kwargs.get("body"))
    response_body = json.loads(response.get("body").read())

    set_span_attribute(span, SpanAttributes.LLM_SYSTEM, vendor)
    set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, modelId)
    set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, modelId)

@@ -222,12 +274,18 @@ def set_llama_meta_attributes(span, request_body, response_body):
    set_span_attribute(
        span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, request_body.get("max_gen_len")
    )
    if "invocation_metrics" in response_body:
        input_tokens = response_body.get("invocation_metrics").get("inputTokenCount")
        output_tokens = response_body.get("invocation_metrics").get("outputTokenCount")
    else:
        input_tokens = response_body.get("prompt_token_count")
        output_tokens = response_body.get("generation_token_count")

    set_usage_attributes(
        span,
        {
            "input_tokens": response_body.get("prompt_token_count"),
            "output_tokens": response_body.get("generation_token_count"),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        },
    )

@@ -245,7 +303,6 @@ def set_llama_meta_attributes(span, request_body, response_body):
        }
    ]
    set_span_attribute(span, SpanAttributes.LLM_PROMPTS, json.dumps(prompts))
    print(completions)
    set_event_completion(span, completions)


@@ -257,13 +314,22 @@ def set_amazon_attributes(span, request_body, response_body):
"content": request_body.get("inputText"),
}
]
completions = [
{
"role": "assistant",
"content": result.get("outputText"),
}
for result in response_body.get("results")
]
if "results" in response_body:
completions = [
{
"role": "assistant",
"content": result.get("outputText"),
}
for result in response_body.get("results")
]

else:
completions = [
{
"role": "assistant",
"content": response_body.get("outputText"),
}
]
set_span_attribute(
span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, config.get("maxTokenCount")
)
@@ -272,13 +338,19 @@ def set_amazon_attributes(span, request_body, response_body):
    )
    set_span_attribute(span, SpanAttributes.LLM_REQUEST_TOP_P, config.get("topP"))
    set_span_attribute(span, SpanAttributes.LLM_PROMPTS, json.dumps(prompts))
    input_tokens = response_body.get("inputTextTokenCount")
    if "results" in response_body:
        output_tokens = sum(
            int(result.get("tokenCount")) for result in response_body.get("results")
        )
    else:
        output_tokens = response_body.get("outputTextTokenCount")

    set_usage_attributes(
        span,
        {
            "input_tokens": response_body.get("inputTextTokenCount"),
            "output_tokens": sum(
                int(result.get("tokenCount")) for result in response_body.get("results")
            ),
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
        },
    )
    set_event_completion(span, completions)
@@ -320,7 +392,7 @@ def set_anthropic_messages_attributes(span, request_body, response_body):
    set_span_attribute(
        span,
        SpanAttributes.LLM_REQUEST_MAX_TOKENS,
        request_body.get("max_tokens_to_sample"),
        request_body.get("max_tokens_to_sample") or request_body.get("max_tokens"),
    )
    set_span_attribute(
        span,
@@ -394,3 +466,62 @@ def set_span_streaming_response(span, response):
    set_event_completion(
        span, [{"role": role or "assistant", "content": streaming_response}]
    )


class StreamingBedrockWrapper(ObjectProxy):
    def __init__(
        self,
        response,
        stream_done_callback=None,
    ):
        super().__init__(response)

        self._stream_done_callback = stream_done_callback
        self._accumulating_body = {"generation": ""}

    def __iter__(self):
        for event in self.__wrapped__:
            self._process_event(event)
            yield event

    def _process_event(self, event):
        chunk = event.get("chunk")
        if not chunk:
            return

        decoded_chunk = json.loads(chunk.get("bytes").decode())
        type = decoded_chunk.get("type")

        if type is None and "outputText" in decoded_chunk:
            self._stream_done_callback(decoded_chunk)
            return
        if "generation" in decoded_chunk:
            self._accumulating_body["generation"] += decoded_chunk.get("generation")

        if type == "message_start":
            self._accumulating_body = decoded_chunk.get("message")
        elif type == "content_block_start":
            self._accumulating_body["content"].append(
                decoded_chunk.get("content_block")
            )
        elif type == "content_block_delta":
            self._accumulating_body["content"][-1]["text"] += decoded_chunk.get(
                "delta"
            ).get("text")

        elif self.has_finished(type, decoded_chunk):
            self._accumulating_body["invocation_metrics"] = decoded_chunk.get(
                "amazon-bedrock-invocationMetrics"
            )
            self._stream_done_callback(self._accumulating_body)

    def has_finished(self, type, chunk):
        if type and type == "message_stop":
            return True

        if "completionReason" in chunk and chunk.get("completionReason") == "FINISH":
            return True

        if "stop_reason" in chunk and chunk.get("stop_reason") is not None:
            return True
        return False
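To illustrate the wrapper's contract in isolation, here is a small sketch that feeds StreamingBedrockWrapper a fake Titan-style event stream. The fake_stream contents and the on_done callback are illustrative only; in the patch the callback is stream_finished, which records the span attributes and ends the span:

import json

# Stand-in for the EventStream returned by invoke_model_with_response_stream.
fake_stream = [
    {
        "chunk": {
            "bytes": json.dumps(
                {
                    "outputText": "Paris is the capital of France.",
                    "inputTextTokenCount": 7,
                    "outputTextTokenCount": 9,
                    "completionReason": "FINISH",
                }
            ).encode()
        }
    }
]

def on_done(body):
    # In the real patch this is stream_finished(), which sets usage/prompt
    # attributes from the accumulated body and calls span.end().
    print("stream finished:", body.get("outputText"))

wrapped = StreamingBedrockWrapper(fake_stream, on_done)
for event in wrapped:  # events pass through unchanged while being inspected
    pass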
2 changes: 1 addition & 1 deletion src/langtrace_python_sdk/version.py
@@ -1 +1 @@
__version__ = "3.3.23"
__version__ = "3.3.24"
4 changes: 2 additions & 2 deletions src/run_example.py
@@ -20,9 +20,9 @@
    "vertexai": False,
    "gemini": False,
    "mistral": False,
    "awsbedrock": False,
    "awsbedrock": True,
    "cerebras": False,
    "google_genai": True,
    "google_genai": False,
}

if ENABLED_EXAMPLES["anthropic"]: