Skip to content

Commit 545859b

Browse files
committed
kickstart refactoring
1 parent 4c04d1f commit 545859b

File tree

6 files changed

+192
-64
lines changed

6 files changed

+192
-64
lines changed
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from examples.awsbedrock_examples.converse import use_converse
1+
from examples.awsbedrock_examples.converse import use_converse, use_invoke_model_titan
22
from langtrace_python_sdk import langtrace, with_langtrace_root_span
33

44
langtrace.init()
@@ -7,4 +7,5 @@
77
class AWSBedrockRunner:
88
@with_langtrace_root_span("AWS_Bedrock")
99
def run(self):
10-
use_converse()
10+
# use_converse()
11+
use_invoke_model_titan()
Lines changed: 70 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
1-
import os
21
import boto3
2+
import botocore
3+
import json
34
from langtrace_python_sdk import langtrace
5+
from dotenv import load_dotenv
6+
7+
8+
load_dotenv()
9+
langtrace.init()
10+
11+
brt = boto3.client("bedrock-runtime", region_name="us-east-1")
12+
brc = boto3.client("bedrock", region_name="us-east-1")
413

5-
langtrace.init(api_key=os.environ["LANGTRACE_API_KEY"])
614

715
def use_converse():
816
model_id = "anthropic.claude-3-haiku-20240307-v1:0"
9-
client = boto3.client(
10-
"bedrock-runtime",
11-
region_name="us-east-1",
12-
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
13-
aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
14-
)
1517
conversation = [
1618
{
1719
"role": "user",
@@ -20,15 +22,70 @@ def use_converse():
2022
]
2123

2224
try:
23-
response = client.converse(
25+
response = brt.converse(
2426
modelId=model_id,
2527
messages=conversation,
26-
inferenceConfig={"maxTokens":4096,"temperature":0},
27-
additionalModelRequestFields={"top_k":250}
28+
inferenceConfig={"maxTokens": 4096, "temperature": 0},
29+
additionalModelRequestFields={"top_k": 250},
2830
)
2931
response_text = response["output"]["message"]["content"][0]["text"]
3032
print(response_text)
3133

32-
except (Exception) as e:
34+
except Exception as e:
3335
print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
34-
exit(1)
36+
exit(1)
37+
38+
39+
def get_foundation_models():
40+
models = []
41+
for model in brc.list_foundation_models()["modelSummaries"]:
42+
models.append(model["modelId"])
43+
return models
44+
45+
46+
# Invoke Model API
47+
# Amazon Titan Models
48+
def use_invoke_model_titan():
49+
try:
50+
prompt_data = "what's 1+1?"
51+
body = json.dumps(
52+
{
53+
"inputText": prompt_data,
54+
"textGenerationConfig": {
55+
"maxTokenCount": 1024,
56+
"topP": 0.95,
57+
"temperature": 0.2,
58+
},
59+
}
60+
)
61+
modelId = "amazon.titan-text-express-v1" # "amazon.titan-tg1-large"
62+
accept = "application/json"
63+
contentType = "application/json"
64+
65+
response = brt.invoke_model(
66+
body=body, modelId=modelId, accept=accept, contentType=contentType
67+
)
68+
response_body = json.loads(response.get("body").read())
69+
70+
# print(response_body.get("results"))
71+
72+
except botocore.exceptions.ClientError as error:
73+
74+
if error.response["Error"]["Code"] == "AccessDeniedException":
75+
print(
76+
f"\x1b[41m{error.response['Error']['Message']}\
77+
\nTo troubleshoot this issue please refer to the following resources.\
78+
\nhttps://docs.aws.amazon.com/IAM/latest/UserGuide/troubleshoot_access-denied.html\
79+
\nhttps://docs.aws.amazon.com/bedrock/latest/userguide/security-iam.html\x1b[0m\n"
80+
)
81+
82+
else:
83+
raise error
84+
85+
86+
# Anthropic Models
87+
def use_invoke_model_anthropic():
88+
pass
89+
90+
91+
# print(get_foundation_models())

src/langtrace_python_sdk/constants/instrumentation/aws_bedrock.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
from langtrace.trace_attributes import AWSBedrockMethods
22

33
APIS = {
4+
"INVOKE_MODEL": {
5+
"METHOD": "aws_bedrock.invoke_model",
6+
"ENDPOINT": "/invoke-model",
7+
},
48
"CONVERSE": {
59
"METHOD": AWSBedrockMethods.CONVERSE.value,
610
"ENDPOINT": "/converse",

src/langtrace_python_sdk/instrumentation/aws_bedrock/instrumentation.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,17 @@
2323
from wrapt import wrap_function_wrapper as _W
2424

2525
from langtrace_python_sdk.instrumentation.aws_bedrock.patch import (
26-
converse, converse_stream
26+
converse,
27+
invoke_model,
28+
converse_stream,
29+
patch_aws_bedrock,
2730
)
2831

2932
logging.basicConfig(level=logging.FATAL)
3033

31-
def _patch_client(client, version: str, tracer) -> None:
32-
33-
# Store original methods
34-
original_converse = client.converse
35-
36-
# Replace with wrapped versions
37-
client.converse = converse("aws_bedrock.converse", version, tracer)(original_converse)
3834

3935
class AWSBedrockInstrumentation(BaseInstrumentor):
40-
36+
4137
def instrumentation_dependencies(self) -> Collection[str]:
4238
return ["boto3 >= 1.35.31"]
4339

@@ -46,13 +42,11 @@ def _instrument(self, **kwargs):
4642
tracer = get_tracer(__name__, "", tracer_provider)
4743
version = importlib.metadata.version("boto3")
4844

49-
def wrap_create_client(wrapped, instance, args, kwargs):
50-
result = wrapped(*args, **kwargs)
51-
if args and args[0] == 'bedrock-runtime':
52-
_patch_client(result, version, tracer)
53-
return result
54-
55-
_W("boto3", "client", wrap_create_client)
45+
_W(
46+
module="boto3",
47+
name="client",
48+
wrapper=patch_aws_bedrock(tracer, version),
49+
)
5650

5751
def _uninstrument(self, **kwargs):
58-
pass
52+
pass

src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py

Lines changed: 102 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -48,30 +48,42 @@ def wrapper(original_method):
4848
@wraps(original_method)
4949
def wrapped_method(*args, **kwargs):
5050
service_provider = SERVICE_PROVIDERS["AWS_BEDROCK"]
51-
51+
print("Here's the kwargs: ", kwargs)
5252
input_content = [
5353
{
54-
'role': message.get('role', 'user'),
55-
'content': message.get('content', [])[0].get('text', "")
54+
"role": message.get("role", "user"),
55+
"content": message.get("content", [])[0].get("text", ""),
5656
}
57-
for message in kwargs.get('messages', [])
57+
for message in kwargs.get("messages", [])
5858
]
59-
59+
6060
span_attributes = {
61-
**get_langtrace_attributes(version, service_provider, vendor_type="framework"),
62-
**get_llm_request_attributes(kwargs, operation_name=operation_name, prompts=input_content),
61+
**get_langtrace_attributes(
62+
version, service_provider, vendor_type="framework"
63+
),
64+
**get_llm_request_attributes(
65+
kwargs, operation_name=operation_name, prompts=input_content
66+
),
6367
**get_llm_url(args[0] if args else None),
6468
SpanAttributes.LLM_PATH: APIS[api_name]["ENDPOINT"],
6569
**get_extra_attributes(),
6670
}
6771

6872
if api_name == "CONVERSE":
69-
span_attributes.update({
70-
SpanAttributes.LLM_REQUEST_MODEL: kwargs.get('modelId'),
71-
SpanAttributes.LLM_REQUEST_MAX_TOKENS: kwargs.get('inferenceConfig', {}).get('maxTokens'),
72-
SpanAttributes.LLM_REQUEST_TEMPERATURE: kwargs.get('inferenceConfig', {}).get('temperature'),
73-
SpanAttributes.LLM_REQUEST_TOP_P: kwargs.get('inferenceConfig', {}).get('top_p'),
74-
})
73+
span_attributes.update(
74+
{
75+
SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("modelId"),
76+
SpanAttributes.LLM_REQUEST_MAX_TOKENS: kwargs.get(
77+
"inferenceConfig", {}
78+
).get("maxTokens"),
79+
SpanAttributes.LLM_REQUEST_TEMPERATURE: kwargs.get(
80+
"inferenceConfig", {}
81+
).get("temperature"),
82+
SpanAttributes.LLM_REQUEST_TOP_P: kwargs.get(
83+
"inferenceConfig", {}
84+
).get("top_p"),
85+
}
86+
)
7587

7688
attributes = LLMSpanAttributes(**span_attributes)
7789

@@ -92,20 +104,22 @@ def wrapped_method(*args, **kwargs):
92104
raise err
93105

94106
return wrapped_method
107+
95108
return wrapper
109+
96110
return decorator
97111

98112

99113
converse = traced_aws_bedrock_call("CONVERSE", "converse")
114+
invoke_model = traced_aws_bedrock_call("INVOKE_MODEL", "invoke_model")
100115

101116

102117
def converse_stream(original_method, version, tracer):
103118
def traced_method(wrapped, instance, args, kwargs):
104119
service_provider = SERVICE_PROVIDERS["AWS_BEDROCK"]
105-
120+
106121
span_attributes = {
107-
**get_langtrace_attributes
108-
(version, service_provider, vendor_type="llm"),
122+
**get_langtrace_attributes(version, service_provider, vendor_type="llm"),
109123
**get_llm_request_attributes(kwargs),
110124
**get_llm_url(instance),
111125
SpanAttributes.LLM_PATH: APIS["CONVERSE_STREAM"]["ENDPOINT"],
@@ -129,29 +143,87 @@ def traced_method(wrapped, instance, args, kwargs):
129143
span.record_exception(err)
130144
span.set_status(Status(StatusCode.ERROR, str(err)))
131145
raise err
132-
146+
133147
return traced_method
134148

135149

136150
@silently_fail
137151
def _set_response_attributes(span, kwargs, result):
138-
set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, kwargs.get('modelId'))
139-
set_span_attribute(span, SpanAttributes.LLM_TOP_K, kwargs.get('additionalModelRequestFields', {}).get('top_k'))
140-
content = result.get('output', {}).get('message', {}).get('content', [])
152+
set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, kwargs.get("modelId"))
153+
set_span_attribute(
154+
span,
155+
SpanAttributes.LLM_TOP_K,
156+
kwargs.get("additionalModelRequestFields", {}).get("top_k"),
157+
)
158+
content = result.get("output", {}).get("message", {}).get("content", [])
141159
if len(content) > 0:
142-
role = result.get('output', {}).get('message', {}).get('role', "assistant")
143-
responses = [
144-
{"role": role, "content": c.get('text', "")}
145-
for c in content
146-
]
160+
role = result.get("output", {}).get("message", {}).get("role", "assistant")
161+
responses = [{"role": role, "content": c.get("text", "")} for c in content]
147162
set_event_completion(span, responses)
148163

149-
if 'usage' in result:
164+
if "usage" in result:
150165
set_span_attributes(
151166
span,
152167
{
153-
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS: result['usage'].get('outputTokens'),
154-
SpanAttributes.LLM_USAGE_PROMPT_TOKENS: result['usage'].get('inputTokens'),
155-
SpanAttributes.LLM_USAGE_TOTAL_TOKENS: result['usage'].get('totalTokens'),
156-
}
168+
SpanAttributes.LLM_USAGE_COMPLETION_TOKENS: result["usage"].get(
169+
"outputTokens"
170+
),
171+
SpanAttributes.LLM_USAGE_PROMPT_TOKENS: result["usage"].get(
172+
"inputTokens"
173+
),
174+
SpanAttributes.LLM_USAGE_TOTAL_TOKENS: result["usage"].get(
175+
"totalTokens"
176+
),
177+
},
178+
)
179+
180+
181+
def patch_aws_bedrock(tracer, version):
182+
def traced_method(wrapped, instance, args, kwargs):
183+
if args and args[0] != "bedrock-runtime":
184+
return
185+
186+
client = wrapped(*args, **kwargs)
187+
print("Here's the client: ", client)
188+
client.invoke_model = patch_invoke_model(client.invoke_model, tracer, version)
189+
client.invoke_model_with_response_stream = patch_invoke_model(
190+
client.invoke_model_with_response_stream, tracer, version
191+
)
192+
client.converse = patch_invoke_model(client.converse, tracer, version)
193+
client.converse_stream = patch_invoke_model(
194+
client.converse_stream, tracer, version
157195
)
196+
return client
197+
198+
return traced_method
199+
200+
201+
def patch_invoke_model(original_method, tracer, version):
202+
def traced_method(*args, **kwargs):
203+
service_provider = SERVICE_PROVIDERS["AWS_BEDROCK"]
204+
span_attributes = {
205+
**get_langtrace_attributes(
206+
version, service_provider, vendor_type="framework"
207+
),
208+
**get_extra_attributes(),
209+
}
210+
with tracer.start_as_current_span(
211+
name=get_span_name("aws_bedrock.invoke_model"),
212+
kind=SpanKind.CLIENT,
213+
context=set_span_in_context(trace.get_current_span()),
214+
) as span:
215+
set_span_attributes(span, span_attributes)
216+
set_invoke_model_attributes(span, kwargs)
217+
response = original_method(*args, **kwargs)
218+
return response
219+
220+
return traced_method
221+
222+
223+
def set_invoke_model_attributes(span, kwargs):
224+
modelId = kwargs.get("modelId")
225+
(vendor, model_name) = modelId.split(".")
226+
227+
print("Here's the vendor: ", vendor)
228+
print("Here's the model_name: ", model_name)
229+
print("Here's the kwargs: ", kwargs)

src/run_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"anthropic": False,
55
"azureopenai": False,
66
"chroma": False,
7-
"cohere": True,
7+
"cohere": False,
88
"fastapi": False,
99
"langchain": False,
1010
"llamaindex": False,
@@ -20,7 +20,7 @@
2020
"vertexai": False,
2121
"gemini": False,
2222
"mistral": False,
23-
"awsbedrock": False,
23+
"awsbedrock": True,
2424
"cerebras": False,
2525
}
2626

0 commit comments

Comments
 (0)