Skip to content

Commit 01875d3

Browse files
committed
Merge branch 'development' of github.com:Scale3-Labs/langtrace-python-sdk into ali/s3en-2856-support-for-mongodb-vector-db
2 parents 377c95d + 1deae92 commit 01875d3

File tree

14 files changed

+299
-4
lines changed

14 files changed

+299
-4
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ Langtrace automatically captures traces from the following vendors:
263263
| Langchain | Framework | :x: | :white_check_mark: |
264264
| Langgraph | Framework | :x: | :white_check_mark: |
265265
| LlamaIndex | Framework | :white_check_mark: | :white_check_mark: |
266+
| AWS Bedrock | Framework | :white_check_mark: | :white_check_mark: |
266267
| LiteLLM | Framework | :x: | :white_check_mark: |
267268
| DSPy | Framework | :x: | :white_check_mark: |
268269
| CrewAI | Framework | :x: | :white_check_mark: |

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ classifiers = [
1818
"Operating System :: OS Independent",
1919
]
2020
dependencies = [
21-
'trace-attributes==7.0.4',
21+
'trace-attributes==7.1.0',
2222
'opentelemetry-api>=1.25.0',
2323
'opentelemetry-sdk>=1.25.0',
2424
'opentelemetry-instrumentation>=0.47b0',
@@ -57,7 +57,8 @@ dev = [
5757
"google-generativeai",
5858
"google-cloud-aiplatform",
5959
"mistralai",
60-
"embedchain"
60+
"boto3",
61+
"embedchain",
6162
]
6263

6364
test = ["pytest", "pytest-vcr", "pytest-asyncio"]
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from examples.awsbedrock_examples.converse import use_converse
from langtrace_python_sdk import langtrace, with_langtrace_root_span

langtrace.init()


class AWSBedrockRunner:
    """Runs the AWS Bedrock converse example under a named root span."""

    @with_langtrace_root_span("AWS_Bedrock")
    def run(self):
        # Any spans created inside use_converse() nest under the
        # "AWS_Bedrock" root span installed by the decorator.
        use_converse()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import os
import sys

import boto3

from langtrace_python_sdk import langtrace

langtrace.init(api_key=os.environ["LANGTRACE_API_KEY"])


def use_converse():
    """Send one Converse request to Claude 3 Haiku on Bedrock and print the reply.

    Requires AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY in the environment.
    Exits the process with status 1 if the invocation fails.
    """
    model_id = "anthropic.claude-3-haiku-20240307-v1:0"
    client = boto3.client(
        "bedrock-runtime",
        region_name="us-east-1",
        aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
        aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    )
    conversation = [
        {
            "role": "user",
            "content": [{"text": "Write a story about a magic backpack."}],
        }
    ]

    try:
        response = client.converse(
            modelId=model_id,
            messages=conversation,
            inferenceConfig={"maxTokens": 4096, "temperature": 0},
            additionalModelRequestFields={"top_k": 250},
        )
        response_text = response["output"]["message"]["content"][0]["text"]
        print(response_text)
    # Broad catch is deliberate for a demo script: report and exit non-zero
    # instead of dumping a traceback. sys.exit (not the site-module exit())
    # is the correct way to terminate a script.
    except Exception as e:
        print(f"ERROR: Can't invoke '{model_id}'. Reason: {e}")
        sys.exit(1)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from langtrace.trace_attributes import AWSBedrockMethods

# Registry of instrumented AWS Bedrock runtime operations.
#   METHOD:   span-name constant supplied by the trace-attributes package.
#   ENDPOINT: REST path suffix of the Bedrock runtime operation.
APIS = {
    "CONVERSE": {
        "METHOD": AWSBedrockMethods.CONVERSE.value,
        "ENDPOINT": "/converse",
    },
    "CONVERSE_STREAM": {
        "METHOD": AWSBedrockMethods.CONVERSE_STREAM.value,
        "ENDPOINT": "/converse-stream",
    },
}

src/langtrace_python_sdk/constants/instrumentation/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
"AUTOGEN": "Autogen",
3636
"XAI": "XAI",
3737
"MONGODB": "MongoDB",
38+
"AWS_BEDROCK": "AWS Bedrock",
3839
}
3940

4041
LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY = "langtrace_additional_attributes"

src/langtrace_python_sdk/instrumentation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from .vertexai import VertexAIInstrumentation
1919
from .gemini import GeminiInstrumentation
2020
from .mistral import MistralInstrumentation
21+
from .aws_bedrock import AWSBedrockInstrumentation
2122
from .embedchain import EmbedchainInstrumentation
2223
from .litellm import LiteLLMInstrumentation
2324
from .pymongo import PyMongoInstrumentation
@@ -46,4 +47,5 @@
4647
"GeminiInstrumentation",
4748
"MistralInstrumentation",
4849
"PyMongoInstrumentation",
50+
"AWSBedrockInstrumentation",
4951
]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .instrumentation import AWSBedrockInstrumentation
2+
3+
__all__ = ["AWSBedrockInstrumentation"]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""
2+
Copyright (c) 2024 Scale3 Labs
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import importlib.metadata
18+
import logging
19+
from typing import Collection
20+
21+
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
22+
from opentelemetry.trace import get_tracer
23+
from wrapt import wrap_function_wrapper as _W
24+
25+
from langtrace_python_sdk.instrumentation.aws_bedrock.patch import (
26+
converse, converse_stream
27+
)
28+
29+
logging.basicConfig(level=logging.FATAL)
30+
31+
def _patch_client(client, version: str, tracer) -> None:
    """Swap the client's ``converse`` method for a traced wrapper.

    The decorator factory from patch.py closes over the original bound
    method, so the replacement forwards to it after recording a span.
    """
    client.converse = converse("aws_bedrock.converse", version, tracer)(client.converse)
38+
39+
class AWSBedrockInstrumentation(BaseInstrumentor):
    """Instruments boto3 'bedrock-runtime' clients created via ``boto3.client``."""

    def instrumentation_dependencies(self) -> Collection[str]:
        return ["boto3 >= 1.35.31"]

    def _instrument(self, **kwargs):
        tracer_provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, "", tracer_provider)
        version = importlib.metadata.version("boto3")

        def wrap_create_client(wrapped, instance, args, kwargs):
            result = wrapped(*args, **kwargs)
            # boto3.client accepts the service name either positionally or as
            # the ``service_name`` keyword; the original positional-only check
            # missed clients created with boto3.client(service_name=...).
            service_name = args[0] if args else kwargs.get("service_name")
            if service_name == "bedrock-runtime":
                _patch_client(result, version, tracer)
            return result

        _W("boto3", "client", wrap_create_client)

    def _uninstrument(self, **kwargs):
        pass
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
"""
2+
Copyright (c) 2024 Scale3 Labs
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import json
18+
from functools import wraps
19+
20+
from langtrace.trace_attributes import (
21+
LLMSpanAttributes,
22+
SpanAttributes,
23+
)
24+
from langtrace_python_sdk.utils import set_span_attribute
25+
from langtrace_python_sdk.utils.silently_fail import silently_fail
26+
from opentelemetry import trace
27+
from opentelemetry.trace import SpanKind
28+
from opentelemetry.trace.status import Status, StatusCode
29+
from opentelemetry.trace.propagation import set_span_in_context
30+
from langtrace_python_sdk.constants.instrumentation.common import (
31+
SERVICE_PROVIDERS,
32+
)
33+
from langtrace_python_sdk.constants.instrumentation.aws_bedrock import APIS
34+
from langtrace_python_sdk.utils.llm import (
35+
get_extra_attributes,
36+
get_langtrace_attributes,
37+
get_llm_request_attributes,
38+
get_llm_url,
39+
get_span_name,
40+
set_event_completion,
41+
set_span_attributes,
42+
)
43+
44+
45+
def traced_aws_bedrock_call(api_name: str, operation_name: str):
    """Build a tracing decorator factory for an AWS Bedrock runtime API.

    Args:
        api_name: Key into the ``APIS`` registry (e.g. ``"CONVERSE"``).
        operation_name: GenAI operation name recorded on the span.

    Returns:
        ``decorator(method_name, version, tracer)`` whose result wraps the
        original boto3 client method with span creation and status handling.
    """
    def decorator(method_name: str, version: str, tracer):
        def wrapper(original_method):
            @wraps(original_method)
            def wrapped_method(*args, **kwargs):
                service_provider = SERVICE_PROVIDERS["AWS_BEDROCK"]

                # Flatten Bedrock messages into {role, content} prompt entries.
                # Guard against messages with an empty content list — indexing
                # [0] unconditionally would raise IndexError.
                input_content = []
                for message in kwargs.get("messages", []):
                    blocks = message.get("content", [])
                    text = blocks[0].get("text", "") if blocks else ""
                    input_content.append(
                        {"role": message.get("role", "user"), "content": text}
                    )

                span_attributes = {
                    **get_langtrace_attributes(version, service_provider, vendor_type="framework"),
                    **get_llm_request_attributes(kwargs, operation_name=operation_name, prompts=input_content),
                    **get_llm_url(args[0] if args else None),
                    SpanAttributes.LLM_PATH: APIS[api_name]["ENDPOINT"],
                    **get_extra_attributes(),
                }

                if api_name == "CONVERSE":
                    inference_config = kwargs.get("inferenceConfig", {})
                    span_attributes.update({
                        SpanAttributes.LLM_REQUEST_MODEL: kwargs.get("modelId"),
                        SpanAttributes.LLM_REQUEST_MAX_TOKENS: inference_config.get("maxTokens"),
                        SpanAttributes.LLM_REQUEST_TEMPERATURE: inference_config.get("temperature"),
                        # The Converse API spells this field "topP"; keep the
                        # original "top_p" lookup as a fallback.
                        SpanAttributes.LLM_REQUEST_TOP_P: inference_config.get(
                            "topP", inference_config.get("top_p")
                        ),
                    })

                attributes = LLMSpanAttributes(**span_attributes)

                with tracer.start_as_current_span(
                    name=get_span_name(APIS[api_name]["METHOD"]),
                    kind=SpanKind.CLIENT,
                    context=set_span_in_context(trace.get_current_span()),
                ) as span:
                    set_span_attributes(span, attributes)
                    try:
                        result = original_method(*args, **kwargs)
                        _set_response_attributes(span, kwargs, result)
                        span.set_status(StatusCode.OK)
                        return result
                    except Exception as err:
                        span.record_exception(err)
                        span.set_status(Status(StatusCode.ERROR, str(err)))
                        raise

            return wrapped_method
        return wrapper
    return decorator
97+
98+
99+
# Traced replacement factory for BedrockRuntime.Client.converse.
converse = traced_aws_bedrock_call("CONVERSE", "converse")
100+
101+
102+
def converse_stream(original_method, version, tracer):
    """Wrapt-style wrapper factory tracing the Bedrock ``converse_stream`` call."""

    def traced_method(wrapped, instance, args, kwargs):
        provider = SERVICE_PROVIDERS["AWS_BEDROCK"]

        raw_attributes = {
            **get_langtrace_attributes(version, provider, vendor_type="llm"),
            **get_llm_request_attributes(kwargs),
            **get_llm_url(instance),
            SpanAttributes.LLM_PATH: APIS["CONVERSE_STREAM"]["ENDPOINT"],
            **get_extra_attributes(),
        }
        validated = LLMSpanAttributes(**raw_attributes)

        span_name = get_span_name(APIS["CONVERSE_STREAM"]["METHOD"])
        parent_ctx = set_span_in_context(trace.get_current_span())
        with tracer.start_as_current_span(
            name=span_name, kind=SpanKind.CLIENT, context=parent_ctx
        ) as span:
            set_span_attributes(span, validated)
            try:
                result = wrapped(*args, **kwargs)
                _set_response_attributes(span, kwargs, result)
                span.set_status(StatusCode.OK)
                return result
            except Exception as err:
                span.record_exception(err)
                span.set_status(Status(StatusCode.ERROR, str(err)))
                raise

    return traced_method
134+
135+
136+
@silently_fail
def _set_response_attributes(span, kwargs, result):
    """Copy model, top_k, completion text, and token usage from the Bedrock
    response onto *span*. Decorated with @silently_fail so a malformed
    response never breaks the instrumented call."""
    set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, kwargs.get("modelId"))
    requested_top_k = kwargs.get("additionalModelRequestFields", {}).get("top_k")
    set_span_attribute(span, SpanAttributes.LLM_TOP_K, requested_top_k)

    message = result.get("output", {}).get("message", {})
    content_blocks = message.get("content", [])
    if content_blocks:
        role = message.get("role", "assistant")
        completions = [
            {"role": role, "content": block.get("text", "")}
            for block in content_blocks
        ]
        set_event_completion(span, completions)

    if "usage" in result:
        usage = result["usage"]
        set_span_attributes(
            span,
            {
                SpanAttributes.LLM_USAGE_COMPLETION_TOKENS: usage.get("outputTokens"),
                SpanAttributes.LLM_USAGE_PROMPT_TOKENS: usage.get("inputTokens"),
                SpanAttributes.LLM_USAGE_TOTAL_TOKENS: usage.get("totalTokens"),
            },
        )

0 commit comments

Comments
 (0)