Skip to content

Commit b62f2d4

Browse files
umaannamalaiTimPansinomergify[bot]
authored
Add support for async Bedrock streaming. (#1315)
* Add async bedrock instrumentation. * Async bedrock tests * Refactor and add safeguards. * Remove newline. * Formatting fixes. --------- Co-authored-by: Tim Pansino <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent b3e6bf1 commit b62f2d4

File tree

3 files changed

+566
-453
lines changed

3 files changed

+566
-453
lines changed

newrelic/hooks/external_aiobotocore.py

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
# Copyright 2010 New Relic, Inc.
23
#
34
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,9 +21,11 @@
2021
from newrelic.api.external_trace import ExternalTrace
2122
from newrelic.common.object_wrapper import wrap_function_wrapper
2223
from newrelic.hooks.external_botocore import (
24+
AsyncEventStreamWrapper,
2325
handle_bedrock_exception,
2426
run_bedrock_response_extractor,
2527
run_bedrock_request_extractor,
28+
EMBEDDING_STREAMING_UNSUPPORTED_LOG_MESSAGE,
2629
RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE,
2730
)
2831

@@ -75,21 +78,36 @@ async def wrap_client__make_api_call(wrapped, instance, args, kwargs):
7578
if not hasattr(instance, "_nr_is_bedrock"):
7679
return await wrapped(*args, **kwargs)
7780

78-
transaction = instance._nr_txn
81+
transaction = getattr(instance, "_nr_txn", None)
7982
if not transaction:
8083
return await wrapped(*args, **kwargs)
8184

85+
settings = getattr(instance, "_nr_settings", None)
86+
87+
# Early exit if we can't access the shared settings object from invoke_model instrumentation
88+
# This settings object helps us determine if AIM was enabled as well as streaming
89+
if not (settings and settings.ai_monitoring.enabled):
90+
return await wrapped(*args, **kwargs)
91+
8292
# Grab all context data from botocore invoke_model instrumentation off the shared instance
8393
trace_id = getattr(instance, "_nr_trace_id", "")
8494
span_id = getattr(instance, "_nr_span_id", "")
8595

8696
request_extractor = getattr(instance, "_nr_request_extractor", None)
8797
response_extractor = getattr(instance, "_nr_response_extractor", None)
98+
stream_extractor = getattr(instance, "_nr_stream_extractor", None)
99+
response_streaming = getattr(instance, "_nr_response_streaming", False)
100+
88101
ft = getattr(instance, "_nr_ft", None)
89102

90-
model = args[1].get("modelId")
91-
is_embedding = "embed" in model
92-
request_body = args[1].get("body")
103+
if len(args) >= 2:
104+
model = args[1].get("modelId")
105+
request_body = args[1].get("body")
106+
is_embedding = "embed" in model
107+
else:
108+
model = ""
109+
request_body = None
110+
is_embedding = False
93111

94112
try:
95113
response = await wrapped(*args, **kwargs)
@@ -98,7 +116,18 @@ async def wrap_client__make_api_call(wrapped, instance, args, kwargs):
98116
exc, is_embedding, model, span_id, trace_id, request_extractor, request_body, ft, transaction
99117
)
100118

101-
if not response:
119+
if not response or response_streaming and not settings.ai_monitoring.streaming.enabled:
120+
if ft:
121+
ft.__exit__(None, None, None)
122+
return response
123+
124+
if response_streaming and is_embedding:
125+
# This combination is not supported at time of writing, but may become
126+
# a supported feature in the future. Instrumentation will need to be written
127+
# if this becomes available.
128+
_logger.warning(EMBEDDING_STREAMING_UNSUPPORTED_LOG_MESSAGE)
129+
if ft:
130+
ft.__exit__(None, None, None)
102131
return response
103132

104133
response_headers = response.get("ResponseMetadata", {}).get("HTTPHeaders") or {}
@@ -112,13 +141,22 @@ async def wrap_client__make_api_call(wrapped, instance, args, kwargs):
112141
run_bedrock_request_extractor(request_extractor, request_body, bedrock_attrs)
113142

114143
try:
144+
if response_streaming:
145+
# Wrap EventStream object here to intercept __iter__ method instead of instrumenting class.
146+
# This class is used in numerous other services in botocore, and would cause conflicts.
147+
response["body"] = body = AsyncEventStreamWrapper(response["body"])
148+
body._nr_ft = ft or None
149+
body._nr_bedrock_attrs = bedrock_attrs or {}
150+
body._nr_model_extractor = stream_extractor or None
151+
return response
152+
115153
# Read and replace response streaming bodies
116154
response_body = await response["body"].read()
155+
117156
if ft:
118157
ft.__exit__(None, None, None)
119158
bedrock_attrs["duration"] = ft.duration * 1000
120159
response["body"] = StreamingBody(AsyncBytesIO(response_body), len(response_body))
121-
122160
run_bedrock_response_extractor(response_extractor, response_body, bedrock_attrs, is_embedding, transaction)
123161

124162
except Exception:

newrelic/hooks/external_botocore.py

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,9 @@ def extract_firehose_agent_attrs(instance, *args, **kwargs):
141141
region = instance._client_config.region_name
142142
if account_id and region:
143143
agent_attrs["cloud.platform"] = "aws_kinesis_delivery_streams"
144-
agent_attrs["cloud.resource_id"] = (
145-
f"arn:aws:firehose:{region}:{account_id}:deliverystream/{stream_name}"
146-
)
144+
agent_attrs[
145+
"cloud.resource_id"
146+
] = f"arn:aws:firehose:{region}:{account_id}:deliverystream/{stream_name}"
147147
except Exception as e:
148148
_logger.debug("Failed to capture AWS Kinesis Delivery Stream (Firehose) info.", exc_info=True)
149149
return agent_attrs
@@ -547,7 +547,9 @@ def extract_bedrock_cohere_model_streaming_response(response_body, bedrock_attrs
547547
]
548548

549549

550-
def handle_bedrock_exception(exc, is_embedding, model, span_id, trace_id, request_extractor, request_body, ft, transaction):
550+
def handle_bedrock_exception(
551+
exc, is_embedding, model, span_id, trace_id, request_extractor, request_body, ft, transaction
552+
):
551553
try:
552554
bedrock_attrs = {
553555
"model": model,
@@ -678,6 +680,8 @@ def _wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
678680
instance._nr_stream_extractor = stream_extractor
679681
instance._nr_txn = transaction
680682
instance._nr_ft = ft
683+
instance._nr_response_streaming = response_streaming
684+
instance._nr_settings = settings
681685

682686
# Add a bedrock flag to instance so we can determine when make_api_call instrumentation is hit from non-Bedrock paths and bypass it if so
683687
instance._nr_is_bedrock = True
@@ -686,7 +690,9 @@ def _wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
686690
# For aioboto3 clients, this will call make_api_call instrumentation in external_aiobotocore
687691
response = wrapped(*args, **kwargs)
688692
except Exception as exc:
689-
handle_bedrock_exception(exc, is_embedding, model, span_id, trace_id, request_extractor, request_body, ft, transaction)
693+
handle_bedrock_exception(
694+
exc, is_embedding, model, span_id, trace_id, request_extractor, request_body, ft, transaction
695+
)
690696

691697
if not response or response_streaming and not settings.ai_monitoring.streaming.enabled:
692698
ft.__exit__(None, None, None)
@@ -777,6 +783,42 @@ def close(self):
777783
return super(GeneratorProxy, self).close()
778784

779785

786+
class AsyncEventStreamWrapper(ObjectProxy):
787+
def __aiter__(self):
788+
g = AsyncGeneratorProxy(self.__wrapped__.__aiter__())
789+
g._nr_ft = getattr(self, "_nr_ft", None)
790+
g._nr_bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
791+
g._nr_model_extractor = getattr(self, "_nr_model_extractor", NULL_EXTRACTOR)
792+
return g
793+
794+
795+
class AsyncGeneratorProxy(ObjectProxy):
796+
def __init__(self, wrapped):
797+
super(AsyncGeneratorProxy, self).__init__(wrapped)
798+
799+
def __aiter__(self):
800+
return self
801+
802+
async def __anext__(self):
803+
transaction = current_transaction()
804+
if not transaction:
805+
return await self.__wrapped__.__anext__()
806+
return_val = None
807+
try:
808+
return_val = await self.__wrapped__.__anext__()
809+
record_stream_chunk(self, return_val, transaction)
810+
except StopAsyncIteration as e:
811+
record_events_on_stop_iteration(self, transaction)
812+
raise
813+
except Exception as exc:
814+
record_error(self, transaction, exc)
815+
raise
816+
return return_val
817+
818+
async def aclose(self):
819+
return await super(AsyncGeneratorProxy, self).aclose()
820+
821+
780822
def record_stream_chunk(self, return_val, transaction):
781823
if return_val:
782824
try:

0 commit comments

Comments
 (0)