Skip to content

Commit 954e4ae

Browse files
authored
Merge pull request #537 from m1kl0sh/aws-embeddings-support
Added support for embeddings via AWS Bedrock
2 parents cc95a5a + 1871128 commit 954e4ae

File tree

7 files changed

+359
-8
lines changed

7 files changed

+359
-8
lines changed

src/langtrace_python_sdk/constants/instrumentation/aws_bedrock.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@
55
"METHOD": "aws_bedrock.invoke_model",
66
"ENDPOINT": "/invoke-model",
77
},
8+
"INVOKE_MODEL_WITH_RESPONSE_STREAM": {
9+
"METHOD": "aws_bedrock.invoke_model_with_response_stream",
10+
"ENDPOINT": "/invoke-model-with-response-stream",
11+
},
812
"CONVERSE": {
913
"METHOD": AWSBedrockMethods.CONVERSE.value,
1014
"ENDPOINT": "/converse",

src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""
1616

1717
import json
18+
import io
1819

1920
from wrapt import ObjectProxy
2021
from itertools import tee
@@ -44,6 +45,7 @@
4445
set_span_attributes,
4546
set_usage_attributes,
4647
)
48+
from langtrace_python_sdk.utils import set_event_prompt
4749

4850

4951
def converse_stream(original_method, version, tracer):
@@ -170,12 +172,29 @@ def traced_method(*args, **kwargs):
170172
return traced_method
171173

172174

175+
def parse_vendor_and_model_name_from_model_id(model_id):
    """Split a Bedrock model identifier into ``(vendor, model_name)``.

    Supported identifier shapes:
      * ARN forms:
          ``arn:aws:bedrock:region:account-id:foundation-model/vendor.model-name``
          ``arn:aws:bedrock:region:account-id:custom-model/vendor.model-name/model-id``
      * Cross-region / plain ids: ``us.anthropic.claude-...`` or
        ``anthropic.claude-...``
      * Bare ids with no dot (e.g. ``mistral``): vendor and model name are
        the same string.

    Returns:
        tuple[str, str]: the ``(vendor, model_name)`` pair.
    """
    if model_id.startswith("arn:aws:bedrock:"):
        # The resource segment after the first "/" is "vendor.model-name".
        resource = model_id.split("/")[1]
        # maxsplit=1 keeps dots that belong to the model name itself;
        # the original split(".") truncated such names at the second dot.
        identifiers = resource.split(".", 1)
        if len(identifiers) == 1:
            # No vendor prefix in the ARN resource. The original raised
            # IndexError here; mirror the bare-id behavior below instead.
            return identifiers[0], identifiers[0]
        return identifiers[0], identifiers[1]
    parts = model_id.split(".")
    if len(parts) == 1:
        # e.g. "mistral" — no vendor prefix at all.
        return parts[0], parts[0]
    # e.g. "us.anthropic.claude-..." -> ("anthropic", "claude-...");
    # the region prefix (if any) is intentionally discarded.
    return parts[-2], parts[-1]
188+
189+
173190
def patch_invoke_model(original_method, tracer, version):
174191
def traced_method(*args, **kwargs):
175192
modelId = kwargs.get("modelId")
176-
(vendor, _) = modelId.split(".")
193+
vendor, _ = parse_vendor_and_model_name_from_model_id(modelId)
177194
span_attributes = {
178195
**get_langtrace_attributes(version, vendor, vendor_type="framework"),
196+
SpanAttributes.LLM_PATH: APIS["INVOKE_MODEL"]["ENDPOINT"],
197+
SpanAttributes.LLM_IS_STREAMING: False,
179198
**get_extra_attributes(),
180199
}
181200
with tracer.start_as_current_span(
@@ -196,9 +215,11 @@ def patch_invoke_model_with_response_stream(original_method, tracer, version):
196215
@wraps(original_method)
197216
def traced_method(*args, **kwargs):
198217
modelId = kwargs.get("modelId")
199-
(vendor, _) = modelId.split(".")
218+
vendor, _ = parse_vendor_and_model_name_from_model_id(modelId)
200219
span_attributes = {
201220
**get_langtrace_attributes(version, vendor, vendor_type="framework"),
221+
SpanAttributes.LLM_PATH: APIS["INVOKE_MODEL_WITH_RESPONSE_STREAM"]["ENDPOINT"],
222+
SpanAttributes.LLM_IS_STREAMING: True,
202223
**get_extra_attributes(),
203224
}
204225
span = tracer.start_span(
@@ -220,7 +241,7 @@ def handle_streaming_call(span, kwargs, response):
220241
def stream_finished(response_body):
221242
request_body = json.loads(kwargs.get("body"))
222243

223-
(vendor, model) = kwargs.get("modelId").split(".")
244+
vendor, model = parse_vendor_and_model_name_from_model_id(kwargs.get("modelId"))
224245

225246
set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, model)
226247
set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model)
@@ -244,18 +265,22 @@ def stream_finished(response_body):
244265

245266
def handle_call(span, kwargs, response):
246267
modelId = kwargs.get("modelId")
247-
(vendor, model_name) = modelId.split(".")
268+
vendor, model_name = parse_vendor_and_model_name_from_model_id(modelId)
269+
read_response_body = response.get("body").read()
270+
request_body = json.loads(kwargs.get("body"))
271+
response_body = json.loads(read_response_body)
248272
response["body"] = BufferedStreamBody(
249-
response["body"]._raw_stream, response["body"]._content_length
273+
io.BytesIO(read_response_body), len(read_response_body)
250274
)
251-
request_body = json.loads(kwargs.get("body"))
252-
response_body = json.loads(response.get("body").read())
253275

254276
set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, modelId)
255277
set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, modelId)
256278

257279
if vendor == "amazon":
258-
set_amazon_attributes(span, request_body, response_body)
280+
if model_name.startswith("titan-embed-text"):
281+
set_amazon_embedding_attributes(span, request_body, response_body)
282+
else:
283+
set_amazon_attributes(span, request_body, response_body)
259284

260285
if vendor == "anthropic":
261286
if "prompt" in request_body:
@@ -359,6 +384,27 @@ def set_amazon_attributes(span, request_body, response_body):
359384
set_event_completion(span, completions)
360385

361386

387+
def set_amazon_embedding_attributes(span, request_body, response_body, model_id=None):
    """Record prompt, usage, and model attributes for an Amazon Titan
    embedding call (e.g. ``amazon.titan-embed-text-v1``).

    Args:
        span: the OpenTelemetry span being annotated.
        request_body: parsed JSON request payload; expects ``inputText``.
        response_body: parsed JSON response payload; expects ``embedding``
            (list of floats) and ``inputTextTokenCount``.
        model_id: optional model identifier. The invoke-model request body
            does not carry ``modelId`` — it is a boto3 kwarg — so callers
            should pass it explicitly. The old ``request_body.get("modelId")``
            lookup is kept only as a fallback for backward compatibility.
    """
    input_text = request_body.get("inputText")
    set_event_prompt(span, input_text)

    embeddings = response_body.get("embedding", [])
    input_tokens = response_body.get("inputTextTokenCount")
    set_usage_attributes(
        span,
        {
            "input_tokens": input_tokens,
            # Titan returns a single vector; report its dimensionality.
            "output": len(embeddings),
        },
    )
    # BUG in original: "modelId" is never present in the request body, so
    # these attributes were always set to None. Prefer the explicit arg and
    # skip the write entirely when no value is available (handle_call has
    # already set both attributes from the kwargs-level modelId).
    resolved_model = model_id or request_body.get("modelId")
    if resolved_model is not None:
        set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, resolved_model)
        set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, resolved_model)
)
406+
407+
362408
def set_anthropic_completions_attributes(span, request_body, response_body):
363409
set_span_attribute(
364410
span,
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
interactions:
2+
- request:
3+
body: '{"messages": [{"role": "user", "content": "Say this is a test three times"}],
4+
"anthropic_version": "bedrock-2023-05-31", "max_tokens": 100}'
5+
headers:
6+
Accept:
7+
- !!binary |
8+
YXBwbGljYXRpb24vanNvbg==
9+
Content-Length:
10+
- '139'
11+
Content-Type:
12+
- !!binary |
13+
YXBwbGljYXRpb24vanNvbg==
14+
User-Agent:
15+
- !!binary |
16+
Qm90bzMvMS4zOC4xOCBtZC9Cb3RvY29yZSMxLjM4LjE4IHVhLzIuMSBvcy9tYWNvcyMyNC40LjAg
17+
bWQvYXJjaCNhcm02NCBsYW5nL3B5dGhvbiMzLjEzLjEgbWQvcHlpbXBsI0NQeXRob24gbS9aLGIg
18+
Y2ZnL3JldHJ5LW1vZGUjc3RhbmRhcmQgQm90b2NvcmUvMS4zOC4xOA==
19+
method: POST
20+
uri: https://bedrock-runtime.us-east-1.amazonaws.com/model/us.anthropic.claude-3-7-sonnet-20250219-v1%3A0/invoke
21+
response:
22+
body:
23+
string: '{"id":"msg_bdrk_01NJB1bDTLkFh6pgfoAD5hkb","type":"message","role":"assistant","model":"claude-3-7-sonnet-20250219","content":[{"type":"text","text":"This
24+
is a test.\nThis is a test.\nThis is a test."}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":14,"cache_creation_input_tokens":0,"cache_read_input_tokens":0,"output_tokens":20}}'
25+
headers:
26+
Connection:
27+
- keep-alive
28+
Content-Length:
29+
- '355'
30+
Content-Type:
31+
- application/json
32+
Date:
33+
- Mon, 19 May 2025 16:42:05 GMT
34+
X-Amzn-Bedrock-Input-Token-Count:
35+
- '14'
36+
X-Amzn-Bedrock-Invocation-Latency:
37+
- '926'
38+
X-Amzn-Bedrock-Output-Token-Count:
39+
- '20'
40+
x-amzn-RequestId:
41+
- c0a92363-ec28-4a8b-9c09-571131d946b0
42+
status:
43+
code: 200
44+
message: OK
45+
version: 1

src/tests/aws_bedrock/cassettes/test_generate_embedding.yaml

Lines changed: 41 additions & 0 deletions
Large diffs are not rendered by default.

src/tests/aws_bedrock/conftest.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Unit tests configuration module."""
2+
3+
import pytest
4+
import os
5+
6+
from boto3.session import Session
7+
from botocore.config import Config
8+
9+
from langtrace_python_sdk.instrumentation.aws_bedrock.instrumentation import (
10+
AWSBedrockInstrumentation,
11+
)
12+
13+
14+
@pytest.fixture(autouse=True)
def environment():
    """Ensure AWS_ACCESS_KEY_ID is populated so boto3 client creation works."""
    existing = os.getenv("AWS_ACCESS_KEY_ID")
    if not existing:
        os.environ["AWS_ACCESS_KEY_ID"] = "test_api_key"
18+
19+
20+
@pytest.fixture
def aws_bedrock_client():
    """Build a bedrock-runtime client with generous timeouts for cassette replay."""
    client_config = Config(
        region_name="us-east-1",
        connect_timeout=300,
        read_timeout=300,
        retries={"total_max_attempts": 2, "mode": "standard"},
    )
    session = Session()
    return session.client("bedrock-runtime", config=client_config)
29+
30+
31+
@pytest.fixture(scope="module")
def vcr_config():
    """Strip AWS auth material from recorded VCR cassettes."""
    sensitive_headers = [
        "authorization",
        "X-Amz-Date",
        "X-Amz-Security-Token",
        "amz-sdk-invocation-id",
        "amz-sdk-request",
    ]
    return {"filter_headers": sensitive_headers}
42+
43+
44+
@pytest.fixture(scope="session", autouse=True)
def instrument():
    """Install the AWS Bedrock instrumentation once for the whole test session."""
    AWSBedrockInstrumentation().instrument()
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import pytest
2+
import json
3+
from tests.utils import (
4+
assert_completion_in_events,
5+
assert_prompt_in_events,
6+
assert_token_count,
7+
)
8+
from importlib_metadata import version as v
9+
10+
from langtrace.trace_attributes import SpanAttributes
11+
from langtrace_python_sdk.constants.instrumentation.aws_bedrock import APIS
12+
13+
ANTHROPIC_VERSION = "bedrock-2023-05-31"
14+
15+
16+
@pytest.mark.vcr()
def test_chat_completion(exporter, aws_bedrock_client):
    """invoke_model with an Anthropic chat payload emits a fully-attributed span."""
    model_id = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
    payload = {
        "messages": [{"role": "user", "content": "Say this is a test three times"}],
        "anthropic_version": ANTHROPIC_VERSION,
        "max_tokens": 100,
    }

    aws_bedrock_client.invoke_model(
        modelId=model_id,
        accept="application/json",
        contentType="application/json",
        body=json.dumps(payload),
    )

    span = exporter.get_finished_spans()[-1]
    assert span.name == "aws_bedrock.invoke_model"

    attrs = span.attributes
    assert attrs.get(SpanAttributes.LANGTRACE_SDK_NAME) == "langtrace-python-sdk"
    assert attrs.get(SpanAttributes.LANGTRACE_SERVICE_NAME) == "anthropic"
    assert attrs.get(SpanAttributes.LANGTRACE_SERVICE_TYPE) == "framework"
    assert attrs.get(SpanAttributes.LANGTRACE_SERVICE_VERSION) == v("boto3")
    assert attrs.get(SpanAttributes.LANGTRACE_VERSION) == v("langtrace-python-sdk")
    assert attrs.get(SpanAttributes.LLM_PATH) == APIS["INVOKE_MODEL"]["ENDPOINT"]
    assert attrs.get(SpanAttributes.LLM_RESPONSE_MODEL) == model_id
    assert attrs.get(SpanAttributes.LLM_IS_STREAMING) is False
    assert_prompt_in_events(span.events)
    assert_completion_in_events(span.events)
    assert_token_count(attrs)
52+
53+
54+
@pytest.mark.skip(reason="Skipping streaming test due to no streaming support in vcrpy")
def test_chat_completion_streaming(exporter, aws_bedrock_client):
    """invoke_model_with_response_stream emits a streaming span once consumed."""
    model_id = "us.anthropic.claude-3-7-sonnet-20250219-v1:0"
    payload = {
        "messages": [{"role": "user", "content": "Say this is a test three times"}],
        "anthropic_version": ANTHROPIC_VERSION,
        "max_tokens": 100,
    }

    response = aws_bedrock_client.invoke_model_with_response_stream(
        modelId=model_id,
        accept="application/json",
        contentType="application/json",
        body=json.dumps(payload),
    )

    # Drain the stream so the instrumentation can finish the span.
    chunk_count = sum(1 for chunk in response["body"] if chunk)

    span = exporter.get_finished_spans()[-1]
    assert span.name == "aws_bedrock.invoke_model_with_response_stream"

    attrs = span.attributes
    assert attrs.get(SpanAttributes.LANGTRACE_SDK_NAME) == "langtrace-python-sdk"
    assert attrs.get(SpanAttributes.LANGTRACE_SERVICE_NAME) == "anthropic"
    assert attrs.get(SpanAttributes.LANGTRACE_SERVICE_TYPE) == "framework"
    assert attrs.get(SpanAttributes.LANGTRACE_SERVICE_VERSION) == v("boto3")
    assert attrs.get(SpanAttributes.LANGTRACE_VERSION) == v("langtrace-python-sdk")
    assert (
        attrs.get(SpanAttributes.LLM_PATH)
        == APIS["INVOKE_MODEL_WITH_RESPONSE_STREAM"]["ENDPOINT"]
    )
    # stream_finished records the parsed model name, not the full modelId.
    assert (
        attrs.get(SpanAttributes.LLM_RESPONSE_MODEL)
        == "claude-3-7-sonnet-20250219-v1:0"
    )
    assert attrs.get(SpanAttributes.LLM_IS_STREAMING) is True
    assert_prompt_in_events(span.events)
    assert_completion_in_events(span.events)
    assert_token_count(attrs)
102+
103+
104+
@pytest.mark.vcr()
def test_generate_embedding(exporter, aws_bedrock_client):
    """invoke_model with a Titan embedding payload emits an embedding span."""
    model_id = "amazon.titan-embed-text-v1"
    payload = {"inputText": "Say this is a test three times"}

    aws_bedrock_client.invoke_model(
        modelId=model_id,
        accept="application/json",
        contentType="application/json",
        body=json.dumps(payload),
    )

    span = exporter.get_finished_spans()[-1]
    assert span.name == "aws_bedrock.invoke_model"

    attrs = span.attributes
    assert attrs.get(SpanAttributes.LANGTRACE_SDK_NAME) == "langtrace-python-sdk"
    assert attrs.get(SpanAttributes.LANGTRACE_SERVICE_NAME) == "amazon"
    assert attrs.get(SpanAttributes.LANGTRACE_SERVICE_TYPE) == "framework"
    assert attrs.get(SpanAttributes.LANGTRACE_SERVICE_VERSION) == v("boto3")
    assert attrs.get(SpanAttributes.LANGTRACE_VERSION) == v("langtrace-python-sdk")
    assert attrs.get(SpanAttributes.LLM_PATH) == APIS["INVOKE_MODEL"]["ENDPOINT"]
    assert attrs.get(SpanAttributes.LLM_RESPONSE_MODEL) == model_id
    assert attrs.get(SpanAttributes.LLM_IS_STREAMING) is False
    assert_prompt_in_events(span.events)
    assert_token_count(attrs)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import pytest
2+
from langtrace_python_sdk.instrumentation.aws_bedrock.patch import (
3+
parse_vendor_and_model_name_from_model_id,
4+
)
5+
6+
7+
def test_model_id_parsing():
    """A plain vendor.model id splits on the dot."""
    assert parse_vendor_and_model_name_from_model_id(
        "anthropic.claude-3-opus-20240229"
    ) == ("anthropic", "claude-3-opus-20240229")
12+
13+
14+
def test_model_id_parsing_cross_region_inference():
    """A cross-region id keeps the last two segments and drops the region prefix."""
    assert parse_vendor_and_model_name_from_model_id(
        "us.anthropic.claude-3-opus-20240229"
    ) == ("anthropic", "claude-3-opus-20240229")
19+
20+
21+
def test_model_id_parsing_arn_custom_model_inference():
    """A custom-model ARN yields vendor/model from its resource segment."""
    arn = "arn:aws:bedrock:us-east-1:123456789012:custom-model/amazon.my-model/abc123"
    assert parse_vendor_and_model_name_from_model_id(arn) == ("amazon", "my-model")
28+
29+
30+
def test_model_id_parsing_arn_foundation_model_inference():
    """A foundation-model ARN yields vendor/model from its resource segment."""
    arn = (
        "arn:aws:bedrock:us-east-1:123456789012:"
        "foundation-model/anthropic.claude-3-opus-20240229"
    )
    assert parse_vendor_and_model_name_from_model_id(arn) == (
        "anthropic",
        "claude-3-opus-20240229",
    )

0 commit comments

Comments
 (0)