Skip to content

Commit 43abd99

Browse files
committed
Instrument converse streaming
1 parent 08b9430 commit 43abd99

File tree

1 file changed

+119
-77
lines changed

1 file changed

+119
-77
lines changed

newrelic/hooks/external_botocore.py

Lines changed: 119 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -819,6 +819,16 @@ def _wrap_bedrock_runtime_converse(wrapped, instance, args, kwargs):
819819
bedrock_attrs = extract_bedrock_converse_attrs(kwargs, response, response_headers, model, span_id, trace_id)
820820

821821
try:
822+
if response_streaming:
823+
# Wrap EventStream object here to intercept __iter__ method instead of instrumenting class.
824+
# This class is used in numerous other services in botocore, and would cause conflicts.
825+
response["stream"] = stream = EventStreamWrapper(response["stream"])
826+
stream._nr_ft = ft
827+
stream._nr_bedrock_attrs = bedrock_attrs
828+
stream._nr_model_extractor = stream_extractor
829+
stream._nr_is_converse = True
830+
return response
831+
822832
ft.__exit__(None, None, None)
823833
bedrock_attrs["duration"] = ft.duration * 1000
824834
run_bedrock_response_extractor(response_extractor, {}, bedrock_attrs, False, transaction)
@@ -833,6 +843,7 @@ def _wrap_bedrock_runtime_converse(wrapped, instance, args, kwargs):
833843

834844
def extract_bedrock_converse_attrs(kwargs, response, response_headers, model, span_id, trace_id):
835845
input_message_list = []
846+
output_message_list = None
836847
# If a system message is supplied, it is under its own key in kwargs rather than with the other input messages
837848
if "system" in kwargs.keys():
838849
input_message_list.extend({"role": "system", "content": result["text"]} for result in kwargs.get("system", []))
@@ -843,22 +854,26 @@ def extract_bedrock_converse_attrs(kwargs, response, response_headers, model, sp
843854
[{"role": "user", "content": result["text"]} for result in kwargs["messages"][-1].get("content", [])]
844855
)
845856

846-
output_message_list = [
847-
{"role": "assistant", "content": result["text"]}
848-
for result in response.get("output").get("message").get("content", [])
849-
]
857+
if "output" in response:
858+
output_message_list = [
859+
{"role": "assistant", "content": result["text"]}
860+
for result in response.get("output").get("message").get("content", [])
861+
]
850862

851863
bedrock_attrs = {
852864
"request_id": response_headers.get("x-amzn-requestid"),
853865
"model": model,
854866
"span_id": span_id,
855867
"trace_id": trace_id,
856868
"response.choices.finish_reason": response.get("stopReason"),
857-
"output_message_list": output_message_list,
858869
"request.max_tokens": kwargs.get("inferenceConfig", {}).get("maxTokens", None),
859870
"request.temperature": kwargs.get("inferenceConfig", {}).get("temperature", None),
860871
"input_message_list": input_message_list,
861872
}
873+
874+
if output_message_list is not None:
875+
bedrock_attrs["output_message_list"] = output_message_list
876+
862877
return bedrock_attrs
863878

864879

@@ -886,18 +901,105 @@ def __next__(self):
886901
return_val = None
887902
try:
888903
return_val = self.__wrapped__.__next__()
889-
record_stream_chunk(self, return_val, transaction)
904+
self.record_stream_chunk(return_val, transaction)
890905
except StopIteration:
891-
record_events_on_stop_iteration(self, transaction)
906+
self.record_events_on_stop_iteration(transaction)
892907
raise
893908
except Exception as exc:
894-
record_error(self, transaction, exc)
909+
self.record_error(transaction, exc)
895910
raise
896911
return return_val
897912

898913
def close(self):
899914
return super().close()
900915

916+
def record_events_on_stop_iteration(self, transaction):
917+
if hasattr(self, "_nr_ft"):
918+
bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
919+
self._nr_ft.__exit__(None, None, None)
920+
921+
# If there are no bedrock attrs exit early as there's no data to record.
922+
if not bedrock_attrs:
923+
return
924+
925+
try:
926+
bedrock_attrs["duration"] = self._nr_ft.duration * 1000
927+
handle_chat_completion_event(transaction, bedrock_attrs)
928+
except Exception:
929+
_logger.warning(RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE, exc_info=True)
930+
931+
# Clear cached data as this can be very large.
932+
self._nr_bedrock_attrs.clear()
933+
934+
def record_error(self, transaction, exc):
935+
if hasattr(self, "_nr_ft"):
936+
try:
937+
ft = self._nr_ft
938+
error_attributes = getattr(self, "_nr_bedrock_attrs", {})
939+
940+
# If there are no bedrock attrs exit early as there's no data to record.
941+
if not error_attributes:
942+
return
943+
944+
error_attributes = bedrock_error_attributes(exc, error_attributes)
945+
notice_error_attributes = {
946+
"http.statusCode": error_attributes.get("http.statusCode"),
947+
"error.message": error_attributes.get("error.message"),
948+
"error.code": error_attributes.get("error.code"),
949+
}
950+
notice_error_attributes.update({"completion_id": str(uuid.uuid4())})
951+
952+
ft.notice_error(attributes=notice_error_attributes)
953+
954+
ft.__exit__(*sys.exc_info())
955+
error_attributes["duration"] = ft.duration * 1000
956+
957+
handle_chat_completion_event(transaction, error_attributes)
958+
959+
# Clear cached data as this can be very large.
960+
error_attributes.clear()
961+
except Exception:
962+
_logger.warning(EXCEPTION_HANDLING_FAILURE_LOG_MESSAGE, exc_info=True)
963+
964+
def record_stream_chunk(self, event, transaction):
965+
if event:
966+
try:
967+
if getattr(self, "_nr_is_converse", False):
968+
return self.converse_record_stream_chunk(event, transaction)
969+
else:
970+
return self.invoke_record_stream_chunk(event, transaction)
971+
except Exception:
972+
_logger.warning(RESPONSE_EXTRACTOR_FAILURE_LOG_MESSAGE, exc_info=True)
973+
974+
def invoke_record_stream_chunk(self, event, transaction):
975+
bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
976+
chunk = json.loads(event["chunk"]["bytes"].decode("utf-8"))
977+
self._nr_model_extractor(chunk, bedrock_attrs)
978+
# In Langchain, the bedrock iterator exits early if type is "content_block_stop".
979+
# So we need to call the record events here since stop iteration will not be raised.
980+
_type = chunk.get("type")
981+
if _type == "content_block_stop":
982+
self.record_events_on_stop_iteration(transaction)
983+
984+
def converse_record_stream_chunk(self, event, transaction):
985+
bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
986+
if "contentBlockDelta" in event:
987+
if not bedrock_attrs:
988+
return
989+
990+
content = ((event.get("contentBlockDelta") or {}).get("delta") or {}).get("text", "")
991+
if "output_message_list" not in bedrock_attrs:
992+
bedrock_attrs["output_message_list"] = [{"role": "assistant", "content": ""}]
993+
bedrock_attrs["output_message_list"][0]["content"] += content
994+
995+
if "messageStop" in event:
996+
bedrock_attrs["response.choices.finish_reason"] = (event.get("messageStop") or {}).get("stopReason", "")
997+
998+
# TODO: Is this also subject to the content_block_stop behavior from Langchain?
999+
# If so, that would preclude us from ever capturing the messageStop event with the stopReason.
1000+
# if "contentBlockStop" in event:
1001+
# self.record_events_on_stop_iteration(transaction)
1002+
9011003

9021004
class AsyncEventStreamWrapper(ObjectProxy):
9031005
def __aiter__(self):
@@ -909,8 +1011,11 @@ def __aiter__(self):
9091011

9101012

9111013
class AsyncGeneratorProxy(ObjectProxy):
912-
def __init__(self, wrapped):
913-
super().__init__(wrapped)
1014+
# Import these methods from the synchronous GeneratorProxy
1015+
# Avoid direct inheritance so we don't implement both __iter__ and __aiter__
1016+
record_stream_chunk = GeneratorProxy.record_stream_chunk
1017+
record_events_on_stop_iteration = GeneratorProxy.record_events_on_stop_iteration
1018+
record_error = GeneratorProxy.record_error
9141019

9151020
def __aiter__(self):
9161021
return self
@@ -922,83 +1027,19 @@ async def __anext__(self):
9221027
return_val = None
9231028
try:
9241029
return_val = await self.__wrapped__.__anext__()
925-
record_stream_chunk(self, return_val, transaction)
1030+
self.record_stream_chunk(return_val, transaction)
9261031
except StopAsyncIteration:
927-
record_events_on_stop_iteration(self, transaction)
1032+
self.record_events_on_stop_iteration(transaction)
9281033
raise
9291034
except Exception as exc:
930-
record_error(self, transaction, exc)
1035+
self.record_error(transaction, exc)
9311036
raise
9321037
return return_val
9331038

9341039
async def aclose(self):
9351040
return await super().aclose()
9361041

9371042

938-
def record_stream_chunk(self, return_val, transaction):
939-
if return_val:
940-
try:
941-
chunk = json.loads(return_val["chunk"]["bytes"].decode("utf-8"))
942-
self._nr_model_extractor(chunk, self._nr_bedrock_attrs)
943-
# In Langchain, the bedrock iterator exits early if type is "content_block_stop".
944-
# So we need to call the record events here since stop iteration will not be raised.
945-
_type = chunk.get("type")
946-
if _type == "content_block_stop":
947-
record_events_on_stop_iteration(self, transaction)
948-
except Exception:
949-
_logger.warning(RESPONSE_EXTRACTOR_FAILURE_LOG_MESSAGE, exc_info=True)
950-
951-
952-
def record_events_on_stop_iteration(self, transaction):
953-
if hasattr(self, "_nr_ft"):
954-
bedrock_attrs = getattr(self, "_nr_bedrock_attrs", {})
955-
self._nr_ft.__exit__(None, None, None)
956-
957-
# If there are no bedrock attrs exit early as there's no data to record.
958-
if not bedrock_attrs:
959-
return
960-
961-
try:
962-
bedrock_attrs["duration"] = self._nr_ft.duration * 1000
963-
handle_chat_completion_event(transaction, bedrock_attrs)
964-
except Exception:
965-
_logger.warning(RESPONSE_PROCESSING_FAILURE_LOG_MESSAGE, exc_info=True)
966-
967-
# Clear cached data as this can be very large.
968-
self._nr_bedrock_attrs.clear()
969-
970-
971-
def record_error(self, transaction, exc):
972-
if hasattr(self, "_nr_ft"):
973-
try:
974-
ft = self._nr_ft
975-
error_attributes = getattr(self, "_nr_bedrock_attrs", {})
976-
977-
# If there are no bedrock attrs exit early as there's no data to record.
978-
if not error_attributes:
979-
return
980-
981-
error_attributes = bedrock_error_attributes(exc, error_attributes)
982-
notice_error_attributes = {
983-
"http.statusCode": error_attributes.get("http.statusCode"),
984-
"error.message": error_attributes.get("error.message"),
985-
"error.code": error_attributes.get("error.code"),
986-
}
987-
notice_error_attributes.update({"completion_id": str(uuid.uuid4())})
988-
989-
ft.notice_error(attributes=notice_error_attributes)
990-
991-
ft.__exit__(*sys.exc_info())
992-
error_attributes["duration"] = ft.duration * 1000
993-
994-
handle_chat_completion_event(transaction, error_attributes)
995-
996-
# Clear cached data as this can be very large.
997-
error_attributes.clear()
998-
except Exception:
999-
_logger.warning(EXCEPTION_HANDLING_FAILURE_LOG_MESSAGE, exc_info=True)
1000-
1001-
10021043
def handle_embedding_event(transaction, bedrock_attrs):
10031044
embedding_id = str(uuid.uuid4())
10041045

@@ -1529,6 +1570,7 @@ def wrap_serialize_to_request(wrapped, instance, args, kwargs):
15291570
response_streaming=True
15301571
),
15311572
("bedrock-runtime", "converse"): wrap_bedrock_runtime_converse(response_streaming=False),
1573+
("bedrock-runtime", "converse_stream"): wrap_bedrock_runtime_converse(response_streaming=True),
15321574
}
15331575

15341576

0 commit comments

Comments
 (0)