Skip to content

Commit 1ed6af1

Browse files
authored
Merge pull request #451 from Scale3-Labs/ali/bedrock-refactor
Bedrock Enhancement
2 parents 12807d6 + a0eb08b commit 1ed6af1

File tree

9 files changed

+579
-114
lines changed

9 files changed

+579
-114
lines changed
Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,19 @@
1-
from examples.awsbedrock_examples.converse import use_converse
1+
from examples.awsbedrock_examples.converse import (
2+
use_converse_stream,
3+
use_converse,
4+
use_invoke_model_anthropic,
5+
use_invoke_model_cohere,
6+
use_invoke_model_amazon,
7+
)
28
from langtrace_python_sdk import langtrace, with_langtrace_root_span
39

4-
langtrace.init()
5-
610

711
class AWSBedrockRunner:
    """Example driver that exercises each AWS Bedrock call under tracing."""

    @with_langtrace_root_span("AWS_Bedrock")
    def run(self):
        """Run every Bedrock example beneath a single langtrace root span."""
        # NOTE(review): `use_invoke_model_amazon` is imported and called here,
        # but the examples module shown in this change defines
        # `use_invoke_model_titan` — confirm the exported name matches.
        use_converse_stream()
        use_converse()
        use_invoke_model_anthropic()
        use_invoke_model_cohere()
        use_invoke_model_amazon()
Lines changed: 154 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,174 @@
1-
import os
21
import boto3
2+
import json
33
from langtrace_python_sdk import langtrace
4+
from dotenv import load_dotenv
5+
import botocore
6+
7+
# Load AWS credentials/region from .env and start tracing (console export off).
load_dotenv()
langtrace.init(write_spans_to_console=False)

# Shared clients: `brt` is the data-plane (model invocation) client,
# `brc` is the control-plane client (e.g. listing foundation models).
brt = boto3.client("bedrock-runtime", region_name="us-east-1")
brc = boto3.client("bedrock", region_name="us-east-1")
12+
13+
14+
def use_converse_stream():
    """Request a streaming chat completion from Claude 3 Haiku via Converse.

    Prints the raw streaming response object; exits the process on failure.
    """
    model_id = "anthropic.claude-3-haiku-20240307-v1:0"
    conversation = [
        {
            "role": "user",
            "content": [{"text": "what is the capital of France?"}],
        }
    ]

    try:
        response = brt.converse_stream(
            modelId=model_id,
            messages=conversation,
            inferenceConfig={"maxTokens": 4096, "temperature": 0},
            additionalModelRequestFields={"top_k": 250},
        )
        print(response)
    except Exception as e:
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        exit(1)
436

5-
langtrace.init(api_key=os.environ["LANGTRACE_API_KEY"])
637

738
def use_converse():
    """Request a single (non-streaming) chat completion from Claude 3 Haiku.

    Prints the assistant's reply text; exits the process on failure.
    """
    model_id = "anthropic.claude-3-haiku-20240307-v1:0"
    conversation = [
        {
            "role": "user",
            "content": [{"text": "what is the capital of France?"}],
        }
    ]

    try:
        response = brt.converse(
            modelId=model_id,
            messages=conversation,
            inferenceConfig={"maxTokens": 4096, "temperature": 0},
            additionalModelRequestFields={"top_k": 250},
        )
        # First content block of the assistant message holds the answer text.
        print(response["output"]["message"]["content"][0]["text"])
    except Exception as e:
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        exit(1)
60+
61+
62+
def get_foundation_models():
    """Print the model ID of every foundation model visible to this account."""
    summaries = brc.list_foundation_models()["modelSummaries"]
    for summary in summaries:
        print(summary["modelId"])
65+
66+
67+
# Invoke Model API
68+
# Amazon Titan Models
69+
def use_invoke_model_titan(stream=False):
    """Invoke an Amazon Titan text model through the InvokeModel API.

    Args:
        stream: when True, use the streaming response variant
            (``invoke_model_with_response_stream``).

    Prints the decoded JSON body for the non-streaming case. On an
    AccessDeniedException, prints troubleshooting guidance; other client
    errors are re-raised.
    """
    try:
        prompt_data = "what's the capital of France?"
        body = json.dumps(
            {
                "inputText": prompt_data,
                "textGenerationConfig": {
                    "maxTokenCount": 1024,
                    "topP": 0.95,
                    "temperature": 0.2,
                },
            }
        )
        model_id = "amazon.titan-text-express-v1"  # "amazon.titan-tg1-large"
        accept = "application/json"
        content_type = "application/json"

        if stream:
            # The streaming variant returns an event stream, not a readable
            # JSON body, so it must not be passed through json.loads below.
            response = brt.invoke_model_with_response_stream(
                body=body, modelId=model_id, accept=accept, contentType=content_type
            )
            print(response)
        else:
            response = brt.invoke_model(
                body=body, modelId=model_id, accept=accept, contentType=content_type
            )
            response_body = json.loads(response.get("body").read())
            print(response_body)

    except botocore.exceptions.ClientError as error:
        if error.response["Error"]["Code"] == "AccessDeniedException":
            # Fixed typo ("troubeshoot") and replaced fragile backslash line
            # continuations, which embedded source indentation in the output,
            # with explicit newline-joined string literals.
            print(
                f"\x1b[41m{error.response['Error']['Message']}\n"
                "To troubleshoot this issue please refer to the following resources.\n"
                "https://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_access-denied.html\n"
                "https://docs.aws.amazon.com/bedrock/latest/userguide/security-iam.html\x1b[0m\n"
            )
        else:
            raise error
109+
110+
111+
# Anthropic Models
112+
def use_invoke_model_anthropic(stream=False):
    """Invoke an Anthropic Claude model via the InvokeModel (Messages) API.

    Args:
        stream: when True, print each decoded chunk of the response stream;
            otherwise print the completed reply text.
    """
    body = json.dumps(
        {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 1024,
            "temperature": 0.1,
            "top_p": 0.9,
            "messages": [{"role": "user", "content": "Hello, Claude"}],
        }
    )
    model_id = "anthropic.claude-v2"
    accept = "application/json"
    content_type = "application/json"

    if stream:
        response = brt.invoke_model_with_response_stream(body=body, modelId=model_id)
        stream_body = response.get("body")
        if stream_body:
            for event in stream_body:
                chunk = event.get("chunk")
                if chunk:
                    print(json.loads(chunk.get("bytes").decode()))
    else:
        response = brt.invoke_model(
            body=body, modelId=model_id, accept=accept, contentType=content_type
        )
        response_body = json.loads(response.get("body").read())
        # The Messages API (anthropic_version bedrock-2023-05-31) returns the
        # reply under "content" as a list of content blocks; the legacy
        # "completion" field used previously does not exist in this response,
        # so it always printed None.
        content_blocks = response_body.get("content") or [{}]
        print(content_blocks[0].get("text"))
142+
143+
144+
def use_invoke_model_llama():
    """Invoke Meta Llama 3 8B Instruct and return the decoded JSON response."""
    model_id = "meta.llama3-8b-instruct-v1:0"

    # Build the request payload for the Llama text-generation interface.
    request = {
        "prompt": "What is the capital of France?",
        "max_gen_len": 128,
        "temperature": 0.1,
        "top_p": 0.9,
    }

    response = brt.invoke_model(body=json.dumps(request), modelId=model_id)
    return json.loads(response.get("body").read())
165+
166+
167+
# print(get_foundation_models())
168+
def use_invoke_model_cohere():
    """Invoke Cohere Command R+ via the InvokeModel API and print the response.

    Fixes vs. the previous version: Bedrock's Command R+ model ID requires the
    ":0" version suffix ("cohere.command-r-plus-v1:0"), and the Command R
    family expects the user input under the "message" field — "prompt" is the
    legacy Command (non-R) parameter and is rejected by these models.
    """
    model_id = "cohere.command-r-plus-v1:0"
    prompt = "What is the capital of France?"
    body = json.dumps({"message": prompt, "max_tokens": 1024, "temperature": 0.1})
    response = brt.invoke_model(body=body, modelId=model_id)
    response_body = json.loads(response.get("body").read())
    print(response_body)

src/langtrace_python_sdk/constants/instrumentation/aws_bedrock.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from langtrace.trace_attributes import AWSBedrockMethods
22

33
APIS = {
4+
"INVOKE_MODEL": {
5+
"METHOD": "aws_bedrock.invoke_model",
6+
"ENDPOINT": "/invoke-model",
7+
},
48
"CONVERSE": {
59
"METHOD": AWSBedrockMethods.CONVERSE.value,
610
"ENDPOINT": "/converse",
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import json
2+
from wrapt import ObjectProxy
3+
4+
5+
class StreamingWrapper(ObjectProxy):
    """Transparent proxy over a Bedrock response event stream.

    Iterating the wrapper yields the wrapped stream's events unchanged while
    accumulating the streamed Anthropic message (message_start /
    content_block_start / content_block_delta chunks). When a message_stop
    chunk arrives, the completed message — with its
    "amazon-bedrock-invocationMetrics" attached as "invocation_metrics" —
    is passed to ``stream_done_callback``.
    """

    def __init__(
        self,
        response,
        stream_done_callback=None,
    ):
        super().__init__(response)

        # Called once with the fully accumulated message body, if provided.
        self._stream_done_callback = stream_done_callback
        self._accumulating_body = {}

    def __iter__(self):
        # Pass every event through untouched; accumulation is a side effect.
        for event in self.__wrapped__:
            self._process_event(event)
            yield event

    def _process_event(self, event):
        """Fold a single stream event into the accumulating message body."""
        chunk = event.get("chunk")
        if not chunk:
            return

        decoded_chunk = json.loads(chunk.get("bytes").decode())
        # Renamed from `type`, which shadowed the builtin.
        event_type = decoded_chunk.get("type")

        if event_type == "message_start":
            self._accumulating_body = decoded_chunk.get("message")
        elif event_type == "content_block_start":
            self._accumulating_body["content"].append(
                decoded_chunk.get("content_block")
            )
        elif event_type == "content_block_delta":
            # Assumes text deltas only — TODO confirm for tool-use streams.
            self._accumulating_body["content"][-1]["text"] += decoded_chunk.get(
                "delta"
            ).get("text")
        elif event_type == "message_stop" and self._stream_done_callback:
            self._accumulating_body["invocation_metrics"] = decoded_chunk.get(
                "amazon-bedrock-invocationMetrics"
            )
            self._stream_done_callback(self._accumulating_body)

src/langtrace_python_sdk/instrumentation/aws_bedrock/instrumentation.py

Lines changed: 8 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,22 +22,13 @@
2222
from opentelemetry.trace import get_tracer
2323
from wrapt import wrap_function_wrapper as _W
2424

25-
from langtrace_python_sdk.instrumentation.aws_bedrock.patch import (
26-
converse, converse_stream
27-
)
25+
from langtrace_python_sdk.instrumentation.aws_bedrock.patch import patch_aws_bedrock
2826

2927
logging.basicConfig(level=logging.FATAL)
3028

31-
def _patch_client(client, version: str, tracer) -> None:
32-
33-
# Store original methods
34-
original_converse = client.converse
35-
36-
# Replace with wrapped versions
37-
client.converse = converse("aws_bedrock.converse", version, tracer)(original_converse)
3829

3930
class AWSBedrockInstrumentation(BaseInstrumentor):
40-
31+
4132
    def instrumentation_dependencies(self) -> Collection[str]:
        """Return the package specifiers this instrumentation requires."""
        return ["boto3 >= 1.35.31"]
4334

@@ -46,13 +37,11 @@ def _instrument(self, **kwargs):
4637
tracer = get_tracer(__name__, "", tracer_provider)
4738
version = importlib.metadata.version("boto3")
4839

49-
def wrap_create_client(wrapped, instance, args, kwargs):
50-
result = wrapped(*args, **kwargs)
51-
if args and args[0] == 'bedrock-runtime':
52-
_patch_client(result, version, tracer)
53-
return result
54-
55-
_W("boto3", "client", wrap_create_client)
40+
_W(
41+
module="boto3",
42+
name="client",
43+
wrapper=patch_aws_bedrock(tracer, version),
44+
)
5645

5746
    def _uninstrument(self, **kwargs):
        # No-op: unpatching of the wrapped boto3 client factory is not
        # implemented.
        pass

0 commit comments

Comments
 (0)