Skip to content

Commit a0eb08b

Browse files
committed
Add support for bedrock
1 parent 98a90f7 commit a0eb08b

File tree

7 files changed

+486
-149
lines changed

7 files changed

+486
-149
lines changed
Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
from examples.awsbedrock_examples.converse import use_converse, use_invoke_model_titan
1+
from examples.awsbedrock_examples.converse import (
2+
use_converse_stream,
3+
use_converse,
4+
use_invoke_model_anthropic,
5+
use_invoke_model_cohere,
6+
use_invoke_model_amazon,
7+
)
28
from langtrace_python_sdk import langtrace, with_langtrace_root_span
39

4-
langtrace.init()
5-
610

711
class AWSBedrockRunner:
    """Runs every AWS Bedrock example under a single langtrace root span."""

    @with_langtrace_root_span("AWS_Bedrock")
    def run(self):
        """Execute each Bedrock example, in order, inside the root span."""
        examples = (
            use_converse_stream,
            use_converse,
            use_invoke_model_anthropic,
            use_invoke_model_cohere,
            use_invoke_model_amazon,
        )
        for example in examples:
            example()

src/examples/awsbedrock_examples/converse.py

Lines changed: 99 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,46 @@
11
import boto3
2-
import botocore
32
import json
43
from langtrace_python_sdk import langtrace
54
from dotenv import load_dotenv
6-
5+
import botocore
76

87
load_dotenv()
9-
langtrace.init()
8+
langtrace.init(write_spans_to_console=False)
109

1110
brt = boto3.client("bedrock-runtime", region_name="us-east-1")
1211
brc = boto3.client("bedrock", region_name="us-east-1")
1312

1413

14+
def use_converse_stream():
    """Call Bedrock's ConverseStream API on Claude 3 Haiku and print the raw
    streaming response object.

    Exits the process with status 1 if the invocation fails.
    """
    model_id = "anthropic.claude-3-haiku-20240307-v1:0"
    user_turn = {
        "role": "user",
        "content": [{"text": "what is the capital of France?"}],
    }

    try:
        result = brt.converse_stream(
            modelId=model_id,
            messages=[user_turn],
            inferenceConfig={"maxTokens": 4096, "temperature": 0},
            additionalModelRequestFields={"top_k": 250},
        )
        # The streamed text lives inside the returned event stream; this
        # example just prints the raw response object.
        print(result)
    except Exception as e:
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        exit(1)
36+
37+
1538
def use_converse():
1639
model_id = "anthropic.claude-3-haiku-20240307-v1:0"
1740
conversation = [
1841
{
1942
"role": "user",
20-
"content": [{"text": "Write a story about a magic backpack."}],
43+
"content": [{"text": "what is the capital of France?"}],
2144
}
2245
]
2346

@@ -37,17 +60,15 @@ def use_converse():
3760

3861

3962
def get_foundation_models():
    """Print and return the IDs of all foundation models available in Bedrock.

    Returns:
        list[str]: the ``modelId`` of every entry in the ``modelSummaries``
        list returned by ``ListFoundationModels``.
    """
    # A function named get_* should return its result, not only print it;
    # callers that ignore the return value are unaffected.
    model_ids = []
    for model in brc.list_foundation_models()["modelSummaries"]:
        print(model["modelId"])
        model_ids.append(model["modelId"])
    return model_ids
4465

4566

4667
# Invoke Model API
4768
# Amazon Titan Models
48-
def use_invoke_model_titan():
69+
def use_invoke_model_titan(stream=False):
4970
try:
50-
prompt_data = "what's 1+1?"
71+
prompt_data = "what's the capital of France?"
5172
body = json.dumps(
5273
{
5374
"inputText": prompt_data,
@@ -62,12 +83,16 @@ def use_invoke_model_titan():
6283
accept = "application/json"
6384
contentType = "application/json"
6485

65-
response = brt.invoke_model(
66-
body=body, modelId=modelId, accept=accept, contentType=contentType
67-
)
68-
response_body = json.loads(response.get("body").read())
86+
if stream:
6987

70-
# print(response_body.get("results"))
88+
response = brt.invoke_model_with_response_stream(
89+
body=body, modelId=modelId, accept=accept, contentType=contentType
90+
)
91+
else:
92+
response = brt.invoke_model(
93+
body=body, modelId=modelId, accept=accept, contentType=contentType
94+
)
95+
response_body = json.loads(response.get("body").read())
7196

7297
except botocore.exceptions.ClientError as error:
7398

@@ -84,8 +109,66 @@ def use_invoke_model_titan():
84109

85110

86111
# Anthropic Models
87-
def use_invoke_model_anthropic():
88-
pass
112+
def use_invoke_model_anthropic(stream=False):
    """Invoke an Anthropic Claude model through Bedrock's InvokeModel API.

    Args:
        stream: when True, use ``invoke_model_with_response_stream`` and
            print each decoded chunk; otherwise do a blocking invocation
            and print the generated text.
    """
    body = json.dumps(
        {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": 1024,
            "temperature": 0.1,
            "top_p": 0.9,
            "messages": [{"role": "user", "content": "Hello, Claude"}],
        }
    )
    modelId = "anthropic.claude-v2"
    accept = "application/json"
    contentType = "application/json"

    if stream:
        response = brt.invoke_model_with_response_stream(body=body, modelId=modelId)
        stream_response = response.get("body")
        if stream_response:
            for event in stream_response:
                chunk = event.get("chunk")
                if chunk:
                    print(json.loads(chunk.get("bytes").decode()))

    else:
        response = brt.invoke_model(
            body=body, modelId=modelId, accept=accept, contentType=contentType
        )
        response_body = json.loads(response.get("body").read())
        # The request uses the Messages API (anthropic_version + messages),
        # whose response carries text in the "content" block list — the
        # legacy "completion" key is only set by the Text Completions API
        # and would print None here.
        for content_block in response_body.get("content", []):
            print(content_block.get("text"))
142+
143+
144+
def use_invoke_model_llama():
    """Invoke Meta Llama 3 8B Instruct via Bedrock and return the parsed
    JSON response body."""
    model_id = "meta.llama3-8b-instruct-v1:0"

    # Request payload following the Llama text-generation schema.
    request_payload = json.dumps(
        {
            "prompt": "What is the capital of France?",
            "max_gen_len": 128,
            "temperature": 0.1,
            "top_p": 0.9,
        }
    )

    response = brt.invoke_model(body=request_payload, modelId=model_id)
    return json.loads(response.get("body").read())
89165

90166

91167
# print(get_foundation_models())
168+
def use_invoke_model_cohere():
    """Invoke Cohere Command R+ through Bedrock and print the parsed JSON
    response body."""
    # Bedrock model IDs require a version suffix; the bare
    # "cohere.command-r-plus-v1" is rejected with a ValidationException.
    model_id = "cohere.command-r-plus-v1:0"
    prompt = "What is the capital of France?"
    # Command R models take the user input under "message"; "prompt" is the
    # request schema of the older cohere.command models and fails validation
    # here.
    body = json.dumps({"message": prompt, "max_tokens": 1024, "temperature": 0.1})
    response = brt.invoke_model(body=body, modelId=model_id)
    response_body = json.loads(response.get("body").read())
    print(response_body)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import json
2+
from wrapt import ObjectProxy
3+
4+
5+
class StreamingWrapper(ObjectProxy):
    """Transparent proxy around a Bedrock event-stream response body.

    Re-yields every stream event untouched while accumulating the streamed
    Anthropic message chunks; once a ``message_stop`` event arrives, the
    fully assembled message (plus invocation metrics) is handed to the
    callback.
    """

    def __init__(
        self,
        response,
        stream_done_callback=None,
    ):
        """Wrap *response* (an iterable Bedrock event stream).

        Args:
            response: the streaming body returned by Bedrock.
            stream_done_callback: optional callable invoked with the
                accumulated message dict when the stream completes.
        """
        super().__init__(response)

        self._stream_done_callback = stream_done_callback
        # Message assembled incrementally from the chunk events.
        self._accumulating_body = {}

    def __iter__(self):
        # Pass every event through unchanged; accumulation is a side effect.
        for event in self.__wrapped__:
            self._process_event(event)
            yield event

    def _process_event(self, event):
        """Fold a single stream event into the accumulated message body."""
        chunk = event.get("chunk")
        if not chunk:
            return

        decoded_chunk = json.loads(chunk.get("bytes").decode())
        # Renamed from "type" so the builtin is not shadowed.
        event_type = decoded_chunk.get("type")

        if event_type == "message_start":
            self._accumulating_body = decoded_chunk.get("message") or {}
            # Guard: later events index into "content" and would KeyError if
            # the message payload omitted it.
            self._accumulating_body.setdefault("content", [])
        elif event_type == "content_block_start":
            self._accumulating_body["content"].append(
                decoded_chunk.get("content_block")
            )
        elif event_type == "content_block_delta":
            delta = decoded_chunk.get("delta") or {}
            # Only text deltas carry a "text" field; other delta types
            # (e.g. tool-use input_json_delta) would raise TypeError on +=.
            if "text" in delta:
                self._accumulating_body["content"][-1]["text"] += delta["text"]
        elif event_type == "message_stop" and self._stream_done_callback:
            self._accumulating_body["invocation_metrics"] = decoded_chunk.get(
                "amazon-bedrock-invocationMetrics"
            )
            self._stream_done_callback(self._accumulating_body)

src/langtrace_python_sdk/instrumentation/aws_bedrock/instrumentation.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,7 @@
2222
from opentelemetry.trace import get_tracer
2323
from wrapt import wrap_function_wrapper as _W
2424

25-
from langtrace_python_sdk.instrumentation.aws_bedrock.patch import (
26-
converse,
27-
invoke_model,
28-
converse_stream,
29-
patch_aws_bedrock,
30-
)
25+
from langtrace_python_sdk.instrumentation.aws_bedrock.patch import patch_aws_bedrock
3126

3227
logging.basicConfig(level=logging.FATAL)
3328

0 commit comments

Comments
 (0)