Skip to content

Commit 4fb3308

Browse files
Add support for streaming in async Bedrock (#1312)
* Add async bedrock instrumentation.
* Async bedrock tests
* Refactor and add safeguards.
* Remove unused imports.
* Add streaming support for async Bedrock.
* Update attr checks.
* Cleanup.
---------
Co-authored-by: Tim Pansino <[email protected]>
1 parent 9429694 commit 4fb3308

File tree

8 files changed

+1859
-65
lines changed

8 files changed

+1859
-65
lines changed

newrelic/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3993,6 +3993,8 @@ def _process_module_builtin_defaults():
39933993
"aiobotocore.endpoint", "newrelic.hooks.external_aiobotocore", "instrument_aiobotocore_endpoint"
39943994
)
39953995

3996+
_process_module_definition("aiobotocore.client", "newrelic.hooks.external_aiobotocore", "instrument_aiobotocore_client")
3997+
39963998
_process_module_definition("botocore.endpoint", "newrelic.hooks.external_botocore", "instrument_botocore_endpoint")
39973999
_process_module_definition("botocore.client", "newrelic.hooks.external_botocore", "instrument_botocore_client")
39984000

newrelic/hooks/external_aiobotocore.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,37 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import logging
15+
import traceback
16+
import sys
17+
from aiobotocore.response import StreamingBody
18+
from io import BytesIO
1419

1520
from newrelic.api.external_trace import ExternalTrace
1621
from newrelic.common.object_wrapper import wrap_function_wrapper
22+
from newrelic.hooks.external_botocore import (
23+
AsyncEventStreamWrapper,
24+
handle_bedrock_exception,
25+
run_bedrock_response_extractor,
26+
run_bedrock_request_extractor,
27+
EMBEDDING_STREAMING_UNSUPPORTED_LOG_MESSAGE,
28+
RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE,
29+
)
30+
31+
_logger = logging.getLogger(__name__)
32+
33+
34+
# Class from https://github.com/aio-libs/aiobotocore/blob/master/tests/test_response.py
35+
# aiobotocore Apache 2 license: https://github.com/aio-libs/aiobotocore/blob/master/LICENSE
36+
class AsyncBytesIO(BytesIO):
    """In-memory bytes buffer whose ``read`` is a coroutine.

    The buffer also exposes itself through a ``content`` attribute, so code
    that reads via ``obj.content.read(...)`` ends up awaiting this class's
    ``read``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The buffer acts as its own ``content`` object.
        self.content = self

    async def read(self, amt=-1):
        """Return up to ``amt`` bytes; ``-1`` (the default) reads to the end."""
        # aiohttp-style "-1 means everything" is translated to the regular
        # file-object convention of ``None``.
        size = None if amt == -1 else amt
        return super().read(size)
1745

1846

1947
def _bind_make_request_params(operation_model, request_dict, *args, **kwargs):
@@ -44,5 +72,105 @@ async def wrap_endpoint_make_request(wrapped, instance, args, kwargs):
4472
return result
4573

4674

75+
async def wrap_client__make_api_call(wrapped, instance, args, kwargs):
    """Wrap ``AioBaseClient._make_api_call`` to record Bedrock AI monitoring data.

    All instrumentation context (transaction, settings, trace/span ids,
    request/response/stream extractors and the function trace) is read from
    ``_nr_*`` attributes on ``instance``. When any required piece is missing,
    or AI monitoring is disabled, the wrapped call is awaited untouched.

    Args:
        wrapped: The original ``_make_api_call`` coroutine function.
        instance: The client instance the call was made on.
        args: Positional arguments of the call; ``args[1]`` is read as the
            request parameters mapping (``modelId`` / ``body``).
        kwargs: Keyword arguments of the call.

    Returns:
        The response produced by ``wrapped``. For non-streaming responses the
        ``body`` is read and replaced with a fresh ``StreamingBody`` so the
        caller can still consume it; for streaming responses ``body`` is
        wrapped in ``AsyncEventStreamWrapper``.

    Raises:
        Any exception raised by ``wrapped`` is propagated to the caller after
        being reported through ``handle_bedrock_exception``.
    """
    # This instrumentation only applies to bedrock runtimes so exit if this method was hit through a different path
    if not hasattr(instance, "_nr_is_bedrock"):
        return await wrapped(*args, **kwargs)

    transaction = getattr(instance, "_nr_txn", None)
    if not transaction:
        return await wrapped(*args, **kwargs)

    settings = getattr(instance, "_nr_settings", None)

    # Early exit if we can't access the shared settings object from invoke_model instrumentation
    # This settings object helps us determine if AIM was enabled as well as streaming
    if not settings:
        return await wrapped(*args, **kwargs)

    if not settings.ai_monitoring.enabled:
        return await wrapped(*args, **kwargs)

    # Grab all context data from botocore invoke_model instrumentation off the shared instance
    trace_id = getattr(instance, "_nr_trace_id", "")
    span_id = getattr(instance, "_nr_span_id", "")

    request_extractor = getattr(instance, "_nr_request_extractor", None)
    response_extractor = getattr(instance, "_nr_response_extractor", None)
    stream_extractor = getattr(instance, "_nr_stream_extractor", None)
    response_streaming = getattr(instance, "_nr_response_streaming", False)

    ft = getattr(instance, "_nr_ft", None)

    if len(args) >= 2:
        model = args[1].get("modelId")
        request_body = args[1].get("body")
        # ``modelId`` may be absent (None). Instrumentation must not raise a
        # TypeError of its own before the real request is attempted.
        is_embedding = "embed" in (model or "")
    else:
        model = ""
        request_body = None
        is_embedding = False

    try:
        response = await wrapped(*args, **kwargs)
    except Exception as exc:
        handle_bedrock_exception(
            exc, is_embedding, model, span_id, trace_id, request_extractor, request_body, ft, transaction
        )
        # Always surface the original error to the caller. Without this, a
        # handler that returns normally would swallow the exception and
        # ``response`` would be read while unbound below.
        raise

    # Nothing to record for an empty response, or for a streaming response
    # when streaming instrumentation is disabled.
    if not response or (response_streaming and not settings.ai_monitoring.streaming.enabled):
        if ft:
            ft.__exit__(None, None, None)
        return response

    if response_streaming and is_embedding:
        # This combination is not supported at time of writing, but may become
        # a supported feature in the future. Instrumentation will need to be written
        # if this becomes available.
        _logger.warning(EMBEDDING_STREAMING_UNSUPPORTED_LOG_MESSAGE)
        if ft:
            ft.__exit__(None, None, None)
        return response

    response_headers = response.get("ResponseMetadata", {}).get("HTTPHeaders") or {}
    bedrock_attrs = {
        "request_id": response_headers.get("x-amzn-requestid"),
        "model": model,
        "span_id": span_id,
        "trace_id": trace_id,
    }

    run_bedrock_request_extractor(request_extractor, request_body, bedrock_attrs)

    try:
        if response_streaming:
            # Wrap EventStream object here to intercept __iter__ method instead of instrumenting class.
            # This class is used in numerous other services in botocore, and would cause conflicts.
            response["body"] = body = AsyncEventStreamWrapper(response["body"])
            body._nr_ft = ft or None
            body._nr_bedrock_attrs = bedrock_attrs or {}
            body._nr_model_extractor = stream_extractor or None
            return response

        # Read and replace response streaming bodies
        response_body = await response["body"].read()
        # Put a readable copy back immediately, before any bookkeeping that
        # might fail, so the caller never receives an already-drained body.
        response["body"] = StreamingBody(AsyncBytesIO(response_body), len(response_body))

        if ft:
            ft.__exit__(None, None, None)
            bedrock_attrs["duration"] = ft.duration * 1000
        run_bedrock_response_extractor(response_extractor, response_body, bedrock_attrs, is_embedding, transaction)

    except Exception:
        # Best effort: a failure while processing the response must not break the caller's request.
        _logger.warning(RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE % traceback.format_exception(*sys.exc_info()))

    return response
168+
169+
47170
def instrument_aiobotocore_endpoint(module):
    """Wrap ``AioEndpoint.make_request`` on ``module`` with ``wrap_endpoint_make_request``."""
    target = "AioEndpoint.make_request"
    wrap_function_wrapper(module, target, wrap_endpoint_make_request)
172+
173+
174+
def instrument_aiobotocore_client(module):
    """Wrap ``AioBaseClient._make_api_call`` on ``module`` when that class exists.

    Modules that do not define ``AioBaseClient`` are left untouched.
    """
    if not hasattr(module, "AioBaseClient"):
        return
    target = "AioBaseClient._make_api_call"
    wrap_function_wrapper(module, target, wrap_client__make_api_call)

0 commit comments

Comments
 (0)