Skip to content

Commit 824140a

Browse files
umaannamalaiTimPansinomergify[bot]
committed
Add Async Bedrock Instrumentation (#1307)
* Add async bedrock instrumentation. * Async bedrock tests * Refactor and add safeguards. * Remove unused imports. * Add review feedback. * Update attr checks. --------- Co-authored-by: Tim Pansino <[email protected]> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent 66d32e8 commit 824140a

File tree

7 files changed

+1704
-62
lines changed

7 files changed

+1704
-62
lines changed

newrelic/hooks/external_aiobotocore.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,35 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import logging
15+
import traceback
16+
import sys
17+
from aiobotocore.response import StreamingBody
18+
from io import BytesIO
1419

1520
from newrelic.api.external_trace import ExternalTrace
1621
from newrelic.common.object_wrapper import wrap_function_wrapper
22+
from newrelic.hooks.external_botocore import (
23+
handle_bedrock_exception,
24+
run_bedrock_response_extractor,
25+
run_bedrock_request_extractor,
26+
RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE,
27+
)
28+
29+
_logger = logging.getLogger(__name__)
30+
31+
32+
# Class from https://github.com/aio-libs/aiobotocore/blob/master/tests/test_response.py
33+
# aiobotocore Apache 2 license: https://github.com/aio-libs/aiobotocore/blob/master/LICENSE
34+
class AsyncBytesIO(BytesIO):
    """In-memory byte buffer whose ``read`` is a coroutine.

    Used to rebuild a response body after the original stream has been
    consumed, so the replacement can still be awaited by the caller.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The buffer also exposes itself as ``content``.
        # NOTE(review): presumably this mirrors the attribute layout of an aiohttp
        # response so StreamingBody can wrap it - confirm against aiobotocore.
        self.content = self

    async def read(self, amt=-1):
        """Return up to ``amt`` bytes; ``-1`` (the default) reads everything left."""
        # aiohttp uses -1 for "read all", the regular BytesIO API uses None
        size = None if amt == -1 else amt
        return super().read(size)
1743

1844

1945
def _bind_make_request_params(operation_model, request_dict, *args, **kwargs):
@@ -44,5 +70,67 @@ async def wrap_endpoint_make_request(wrapped, instance, args, kwargs):
4470
return result
4571

4672

73+
async def wrap_client__make_api_call(wrapped, instance, args, kwargs):
    """Async wrapper around ``AioBaseClient._make_api_call`` for Bedrock runtime clients.

    Context (transaction, trace/span ids, extractors, function trace) is read from
    ``_nr_*`` attributes stored on the client instance. When that context is missing
    or the call parameters cannot be inspected, the wrapped coroutine is awaited
    untouched so instrumentation never breaks the caller's request.

    Returns whatever the wrapped call returns; for non-empty responses the ``body``
    stream is read and replaced with a fresh, re-readable ``StreamingBody``.
    Exceptions raised by the wrapped call are recorded and re-raised.
    """
    # This instrumentation only applies to bedrock runtimes so exit if this method was hit through a different path
    if not hasattr(instance, "_nr_is_bedrock"):
        return await wrapped(*args, **kwargs)

    transaction = getattr(instance, "_nr_txn", None)
    if not transaction:
        return await wrapped(*args, **kwargs)

    # The API parameters are the second positional argument; bail out quietly
    # rather than raise from instrumentation if they are not where we expect.
    api_params = args[1] if len(args) > 1 else kwargs.get("api_params")
    if not hasattr(api_params, "get"):
        return await wrapped(*args, **kwargs)

    # Grab all context data from botocore invoke_model instrumentation off the shared instance
    trace_id = getattr(instance, "_nr_trace_id", "")
    span_id = getattr(instance, "_nr_span_id", "")

    request_extractor = getattr(instance, "_nr_request_extractor", None)
    response_extractor = getattr(instance, "_nr_response_extractor", None)
    ft = getattr(instance, "_nr_ft", None)

    model = api_params.get("modelId")
    # A missing modelId must not raise a TypeError inside the caller's request
    is_embedding = isinstance(model, str) and "embed" in model
    request_body = api_params.get("body")

    try:
        response = await wrapped(*args, **kwargs)
    except Exception as exc:
        handle_bedrock_exception(
            exc, is_embedding, model, span_id, trace_id, request_extractor, request_body, ft, transaction
        )
        # The helper re-raises; this keeps the control flow explicit and guarantees
        # `response` is never referenced while unbound.
        raise

    if not response:
        return response

    response_headers = response.get("ResponseMetadata", {}).get("HTTPHeaders") or {}
    bedrock_attrs = {
        "request_id": response_headers.get("x-amzn-requestid"),
        "model": model,
        "span_id": span_id,
        "trace_id": trace_id,
    }

    run_bedrock_request_extractor(request_extractor, request_body, bedrock_attrs)

    try:
        # Read and replace response streaming bodies
        response_body = await response["body"].read()
        if ft:
            ft.__exit__(None, None, None)
            bedrock_attrs["duration"] = ft.duration * 1000
        response["body"] = StreamingBody(AsyncBytesIO(response_body), len(response_body))

        run_bedrock_response_extractor(response_extractor, response_body, bedrock_attrs, is_embedding, transaction)

    except Exception:
        _logger.warning(RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE % traceback.format_exception(*sys.exc_info()))

    return response
128+
129+
47130
def instrument_aiobotocore_endpoint(module):
    """Wrap ``AioEndpoint.make_request`` on ``module`` with ``wrap_endpoint_make_request``."""
    target = "AioEndpoint.make_request"
    wrap_function_wrapper(module, target, wrap_endpoint_make_request)
132+
133+
134+
def instrument_aiobotocore_client(module):
    """Wrap ``AioBaseClient._make_api_call`` when ``module`` provides ``AioBaseClient``."""
    if not hasattr(module, "AioBaseClient"):
        # Nothing to instrument on this module
        return
    wrap_function_wrapper(module, "AioBaseClient._make_api_call", wrap_client__make_api_call)

newrelic/hooks/external_botocore.py

Lines changed: 90 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import json
16+
import inspect
1617
import logging
1718
import re
1819
import sys
@@ -546,12 +547,77 @@ def extract_bedrock_cohere_model_streaming_response(response_body, bedrock_attrs
546547
]
547548

548549

550+
def handle_bedrock_exception(exc, is_embedding, model, span_id, trace_id, request_extractor, request_body, ft, transaction):
    """Record a failed Bedrock call, then re-raise the exception being handled.

    Builds error attributes from ``exc`` and the request, notices the error on the
    function trace ``ft`` (when one is given), closes the trace, and records an
    embedding or chat completion event on ``transaction``. Any failure while doing
    so is logged as a warning and does not mask the original error.

    Must be called from inside an ``except`` block: the trailing bare ``raise``
    re-raises the exception currently being handled.
    """
    try:
        bedrock_attrs = {
            "model": model,
            "span_id": span_id,
            "trace_id": trace_id,
        }
        # Shared helper: ignores JSON decode errors and logs any other extractor failure
        run_bedrock_request_extractor(request_extractor, request_body, bedrock_attrs)

        error_attributes = bedrock_error_attributes(exc, bedrock_attrs)
        notice_error_attributes = {
            "http.statusCode": error_attributes.get("http.statusCode"),
            "error.message": error_attributes.get("error.message"),
            "error.code": error_attributes.get("error.code"),
        }

        id_key = "embedding_id" if is_embedding else "completion_id"
        notice_error_attributes[id_key] = str(uuid.uuid4())

        if ft:
            ft.notice_error(
                attributes=notice_error_attributes,
            )

            # sys.exc_info() is the caller's in-flight exception here
            ft.__exit__(*sys.exc_info())
            error_attributes["duration"] = ft.duration * 1000

        if is_embedding:
            handle_embedding_event(transaction, error_attributes)
        else:
            handle_chat_completion_event(transaction, error_attributes)
    except Exception:
        _logger.warning(EXCEPTION_HANDLING_FAILURE_LOG_MESSAGE % traceback.format_exception(*sys.exc_info()))

    raise
592+
593+
594+
def run_bedrock_response_extractor(response_extractor, response_body, bedrock_attrs, is_embedding, transaction):
    """Run the extractor for a non-streaming response, then record the LLM event.

    An extractor failure is logged as a warning; the embedding or chat completion
    event is still recorded with whatever ``bedrock_attrs`` holds at that point.
    """
    try:
        response_extractor(response_body, bedrock_attrs)
    except Exception:
        _logger.warning(RESPONSE_EXTRACTOR_FAILURE_LOG_MESSAGE % traceback.format_exception(*sys.exc_info()))

    record_event = handle_embedding_event if is_embedding else handle_chat_completion_event
    record_event(transaction, bedrock_attrs)
605+
606+
607+
def run_bedrock_request_extractor(request_extractor, request_body, bedrock_attrs):
    """Run ``request_extractor`` on the request body, filling ``bedrock_attrs`` in place.

    A JSON decode error from the extractor is ignored silently; any other
    exception is logged as a warning. Nothing is raised to the caller.
    """
    try:
        request_extractor(request_body, bedrock_attrs)
    except Exception as extractor_error:
        if isinstance(extractor_error, json.decoder.JSONDecodeError):
            # Undecodable request bodies are ignored without logging
            return
        _logger.warning(REQUEST_EXTACTOR_FAILURE_LOG_MESSAGE % traceback.format_exception(*sys.exc_info()))
614+
615+
549616
def wrap_bedrock_runtime_invoke_model(response_streaming=False):
550617
@function_wrapper
551618
def _wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
552619
# Wrapped function only takes keyword arguments, no need for binding
553620
transaction = current_transaction()
554-
555621
if not transaction:
556622
return wrapped(*args, **kwargs)
557623

@@ -604,54 +670,32 @@ def _wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
604670
span_id = available_metadata.get("span.id")
605671
trace_id = available_metadata.get("trace.id")
606672

673+
# Store data on instance to pass context to async instrumentation
674+
instance._nr_trace_id = trace_id
675+
instance._nr_span_id = span_id
676+
instance._nr_request_extractor = request_extractor
677+
instance._nr_response_extractor = response_extractor
678+
instance._nr_stream_extractor = stream_extractor
679+
instance._nr_txn = transaction
680+
instance._nr_ft = ft
681+
682+
# Add a bedrock flag to instance so we can determine when make_api_call instrumentation is hit from non-Bedrock paths and bypass it if so
683+
instance._nr_is_bedrock = True
684+
607685
try:
686+
# For aioboto3 clients, this will call make_api_call instrumentation in external_aiobotocore
608687
response = wrapped(*args, **kwargs)
609688
except Exception as exc:
610-
try:
611-
bedrock_attrs = {
612-
"model": model,
613-
"span_id": span_id,
614-
"trace_id": trace_id,
615-
}
616-
try:
617-
request_extractor(request_body, bedrock_attrs)
618-
except json.decoder.JSONDecodeError:
619-
pass
620-
except Exception:
621-
_logger.warning(REQUEST_EXTACTOR_FAILURE_LOG_MESSAGE % traceback.format_exception(*sys.exc_info()))
622-
623-
error_attributes = bedrock_error_attributes(exc, bedrock_attrs)
624-
notice_error_attributes = {
625-
"http.statusCode": error_attributes.get("http.statusCode"),
626-
"error.message": error_attributes.get("error.message"),
627-
"error.code": error_attributes.get("error.code"),
628-
}
629-
630-
if is_embedding:
631-
notice_error_attributes.update({"embedding_id": str(uuid.uuid4())})
632-
else:
633-
notice_error_attributes.update({"completion_id": str(uuid.uuid4())})
634-
635-
ft.notice_error(
636-
attributes=notice_error_attributes,
637-
)
638-
639-
ft.__exit__(*sys.exc_info())
640-
error_attributes["duration"] = ft.duration * 1000
641-
642-
if operation == "embedding":
643-
handle_embedding_event(transaction, error_attributes)
644-
else:
645-
handle_chat_completion_event(transaction, error_attributes)
646-
except Exception:
647-
_logger.warning(EXCEPTION_HANDLING_FAILURE_LOG_MESSAGE % traceback.format_exception(*sys.exc_info()))
648-
649-
raise
689+
handle_bedrock_exception(exc, is_embedding, model, span_id, trace_id, request_extractor, request_body, ft, transaction)
650690

651691
if not response or response_streaming and not settings.ai_monitoring.streaming.enabled:
652692
ft.__exit__(None, None, None)
653693
return response
654694

695+
# Let the instrumentation of make_api_call in the aioboto3 client handle it if we have an async case
696+
if inspect.iscoroutine(response):
697+
return response
698+
655699
if response_streaming and operation == "embedding":
656700
# This combination is not supported at time of writing, but may become
657701
# a supported feature in the future. Instrumentation will need to be written
@@ -668,12 +712,7 @@ def _wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
668712
"trace_id": trace_id,
669713
}
670714

671-
try:
672-
request_extractor(request_body, bedrock_attrs)
673-
except json.decoder.JSONDecodeError:
674-
pass
675-
except Exception:
676-
_logger.warning(REQUEST_EXTACTOR_FAILURE_LOG_MESSAGE % traceback.format_exception(*sys.exc_info()))
715+
run_bedrock_request_extractor(request_extractor, request_body, bedrock_attrs)
677716

678717
try:
679718
if response_streaming:
@@ -691,16 +730,7 @@ def _wrap_bedrock_runtime_invoke_model(wrapped, instance, args, kwargs):
691730
bedrock_attrs["duration"] = ft.duration * 1000
692731
response["body"] = StreamingBody(BytesIO(response_body), len(response_body))
693732

694-
# Run response extractor for non-streaming responses
695-
try:
696-
response_extractor(response_body, bedrock_attrs)
697-
except Exception:
698-
_logger.warning(RESPONSE_EXTRACTOR_FAILURE_LOG_MESSAGE % traceback.format_exception(*sys.exc_info()))
699-
700-
if operation == "embedding":
701-
handle_embedding_event(transaction, bedrock_attrs)
702-
else:
703-
handle_chat_completion_event(transaction, bedrock_attrs)
733+
run_bedrock_response_extractor(response_extractor, response_body, bedrock_attrs, is_embedding, transaction)
704734

705735
except Exception:
706736
_logger.warning(RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE % traceback.format_exception(*sys.exc_info()))
@@ -864,7 +894,6 @@ def handle_chat_completion_event(transaction, bedrock_attrs):
864894
llm_context_attrs = getattr(transaction, "_llm_context_attrs", None)
865895
if llm_context_attrs:
866896
llm_metadata_dict.update(llm_context_attrs)
867-
868897
span_id = bedrock_attrs.get("span_id", None)
869898
trace_id = bedrock_attrs.get("trace_id", None)
870899
request_id = bedrock_attrs.get("request_id", None)
@@ -1009,9 +1038,9 @@ def _nr_dynamodb_datastore_trace_wrapper_(wrapped, instance, args, kwargs):
10091038
partition = "aws-us-gov"
10101039

10111040
if partition and region and account_id and _target:
1012-
agent_attrs["cloud.resource_id"] = (
1013-
f"arn:{partition}:dynamodb:{region}:{account_id:012d}:table/{_target}"
1014-
)
1041+
agent_attrs[
1042+
"cloud.resource_id"
1043+
] = f"arn:{partition}:dynamodb:{region}:{account_id:012d}:table/{_target}"
10151044
agent_attrs["db.system"] = "DynamoDB"
10161045

10171046
except Exception as e:

tests/external_aiobotocore/conftest.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import socket
1818
import threading
1919

20+
import pytest
2021
import moto.server
2122
import werkzeug.serving
2223
from testing_support.fixture.event_loop import ( # noqa: F401, pylint: disable=W0611
@@ -27,6 +28,14 @@
2728
collector_available_fixture,
2829
)
2930

31+
from newrelic.common.package_version_utils import (
32+
get_package_version,
33+
get_package_version_tuple,
34+
)
35+
from external_botocore._mock_external_bedrock_server import MockExternalBedrockServer
36+
37+
BOTOCORE_VERSION = get_package_version("botocore")
38+
3039
PORT = 4443
3140
AWS_ACCESS_KEY_ID = "AAAAAAAAAAAACCESSKEY"
3241
AWS_SECRET_ACCESS_KEY = "AAAAAASECRETKEY" # nosec
@@ -40,6 +49,8 @@
4049
"transaction_tracer.stack_trace_threshold": 0.0,
4150
"debug.log_data_collector_payloads": True,
4251
"debug.record_transaction_failure": True,
52+
"custom_insights_events.max_attribute_value": 4096,
53+
"ai_monitoring.enabled": True,
4354
}
4455
collector_agent_registration = collector_agent_registration_fixture(
4556
app_name="Python Agent Test (external_aiobotocore)",
@@ -146,3 +157,37 @@ async def _stop(self):
146157
self._server.shutdown()
147158

148159
self._thread.join()
160+
161+
162+
# Bedrock Fixtures
163+
@pytest.fixture(scope="session")
def bedrock_server(loop):
    """
    This fixture will create a mocked backend for testing purposes.

    Yields an aiobotocore ``bedrock-runtime`` client pointed at a local
    ``MockExternalBedrockServer``. The fixture is skipped when the installed
    botocore is older than 1.31.57, and recording real responses is not supported here.
    """
    # Import the submodule explicitly: a bare `import aiobotocore` only exposes
    # `aiobotocore.session` if something else happened to import it first.
    import aiobotocore.session

    from newrelic.core.config import _environ_as_bool

    if get_package_version_tuple("botocore") < (1, 31, 57):
        pytest.skip(reason="Bedrock Runtime not available.")

    if _environ_as_bool("NEW_RELIC_TESTING_RECORD_BEDROCK_RESPONSES", False):
        raise NotImplementedError("To record test responses, use botocore instead.")

    # Use mocked Bedrock backend and prerecorded responses
    with MockExternalBedrockServer() as server:
        session = aiobotocore.session.get_session()
        client = loop.run_until_complete(
            session.create_client(
                "bedrock-runtime",
                "us-east-1",
                endpoint_url=f"http://localhost:{server.port}",
                aws_access_key_id="NOT-A-REAL-SECRET",
                aws_secret_access_key="NOT-A-REAL-SECRET",
            ).__aenter__()
        )

        yield client

        # Close the client on the same loop that opened it
        loop.run_until_complete(client.__aexit__(None, None, None))

0 commit comments

Comments
 (0)