Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
# SPDX-License-Identifier: Apache-2.0
# Modifications Copyright The OpenTelemetry Authors. Licensed under the Apache License 2.0 License.
import importlib
import json
from typing import Any, Dict, Optional, Sequence

from botocore.exceptions import ClientError

Expand Down Expand Up @@ -32,7 +34,7 @@
_determine_call_context,
_safe_invoke,
)
from opentelemetry.instrumentation.botocore.extensions import _KNOWN_EXTENSIONS, _find_extension
from opentelemetry.instrumentation.botocore.extensions import _KNOWN_EXTENSIONS, _find_extension, bedrock_utils
from opentelemetry.instrumentation.botocore.extensions.dynamodb import _DynamoDbExtension
from opentelemetry.instrumentation.botocore.extensions.lmbd import _LambdaExtension
from opentelemetry.instrumentation.botocore.extensions.sns import _SnsExtension
Expand Down Expand Up @@ -234,7 +236,7 @@ def patch_on_success(self, span: Span, result: _BotoResultT, instrumentor_contex
_SqsExtension.on_success = patch_on_success


def _apply_botocore_bedrock_patch() -> None:
def _apply_botocore_bedrock_patch() -> None: # pylint: disable=too-many-statements
"""Botocore instrumentation patch for Bedrock, Bedrock Agent, and Bedrock Agent Runtime

This patch adds an extension to the upstream's list of known extension for Bedrock.
Expand All @@ -245,7 +247,91 @@ def _apply_botocore_bedrock_patch() -> None:
_KNOWN_EXTENSIONS["bedrock"] = _lazy_load(".", "_BedrockExtension")
_KNOWN_EXTENSIONS["bedrock-agent"] = _lazy_load(".", "_BedrockAgentExtension")
_KNOWN_EXTENSIONS["bedrock-agent-runtime"] = _lazy_load(".", "_BedrockAgentRuntimeExtension")
# bedrock-runtime is handled by upstream

# TODO: The following code is to patch bedrock-runtime bugs that are fixed in
# opentelemetry-instrumentation-botocore==0.56b0 in these PRs:
# https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3548
# https://github.com/open-telemetry/opentelemetry-python-contrib/pull/3544
# Remove this code once we've bumped opentelemetry-instrumentation-botocore dependency to 0.56b0

old_init = bedrock_utils.ConverseStreamWrapper.__init__
old_process_event = bedrock_utils.ConverseStreamWrapper._process_event

# The OpenTelemetry Authors code
def patched_init(self, *args, **kwargs):
old_init(self, *args, **kwargs)
self._tool_json_input_buf = ""

def patched_process_event(self, event):
if "contentBlockStart" in event:
start = event["contentBlockStart"].get("start", {})
if "toolUse" in start:
self._content_block = {"toolUse": start["toolUse"]}
return

if "contentBlockDelta" in event:
if self._record_message:
delta = event["contentBlockDelta"].get("delta", {})
if "text" in delta:
self._content_block.setdefault("text", "")
self._content_block["text"] += delta["text"]
elif "toolUse" in delta:
if (input_buf := delta["toolUse"].get("input")) is not None:
self._tool_json_input_buf += input_buf
return

if "contentBlockStop" in event:
if self._record_message:
if self._tool_json_input_buf:
try:
self._content_block["toolUse"]["input"] = json.loads(self._tool_json_input_buf)
except json.JSONDecodeError:
self._content_block["toolUse"]["input"] = self._tool_json_input_buf
self._message["content"].append(self._content_block)
self._content_block = {}
self._tool_json_input_buf = ""
return

old_process_event(self, event)

def patched_extract_tool_calls(
message: dict[str, Any], capture_content: bool
) -> Optional[Sequence[Dict[str, Any]]]:
content = message.get("content")
if not content:
return None

tool_uses = [item["toolUse"] for item in content if "toolUse" in item]
if not tool_uses:
tool_uses = [item for item in content if isinstance(item, dict) and item.get("type") == "tool_use"]
tool_id_key = "id"
else:
tool_id_key = "toolUseId"

if not tool_uses:
return None

tool_calls = []
for tool_use in tool_uses:
tool_call = {"type": "function"}
if call_id := tool_use.get(tool_id_key):
tool_call["id"] = call_id

if function_name := tool_use.get("name"):
tool_call["function"] = {"name": function_name}

if (function_input := tool_use.get("input")) and capture_content:
tool_call.setdefault("function", {})
tool_call["function"]["arguments"] = function_input

tool_calls.append(tool_call)
return tool_calls

bedrock_utils.ConverseStreamWrapper.__init__ = patched_init
bedrock_utils.ConverseStreamWrapper._process_event = patched_process_event
bedrock_utils.extract_tool_calls = patched_extract_tool_calls

# END The OpenTelemetry Authors code


def _apply_botocore_dynamodb_patch() -> None:
Expand All @@ -270,7 +356,7 @@ def patch_on_success(self, span: Span, result: _BotoResultT, instrumentor_contex


def _apply_botocore_api_call_patch() -> None:
# pylint: disable=too-many-locals
# pylint: disable=too-many-locals,too-many-statements
def patched_api_call(self, original_func, instance, args, kwargs):
"""Botocore instrumentation patch to capture AWS authentication details

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
apply_instrumentation_patches,
)
from opentelemetry.instrumentation.botocore import BotocoreInstrumentor
from opentelemetry.instrumentation.botocore.extensions import _KNOWN_EXTENSIONS
from opentelemetry.instrumentation.botocore.extensions import _KNOWN_EXTENSIONS, bedrock_utils
from opentelemetry.propagate import get_global_textmap
from opentelemetry.semconv.trace import SpanAttributes
from opentelemetry.trace.span import Span
Expand Down Expand Up @@ -84,6 +84,10 @@ def _run_patch_behaviour_tests(self):
self._test_unpatched_botocore_propagator()
self._test_unpatched_gevent_instrumentation()
self._test_unpatched_starlette_instrumentation()
# TODO: remove these tests once we bump botocore instrumentation version to 0.56b0
# Bedrock Runtime tests
self._test_unpatched_converse_stream_wrapper()
self._test_unpatched_extract_tool_calls()

# Apply patches
apply_instrumentation_patches()
Expand Down Expand Up @@ -219,6 +223,11 @@ def _test_patched_botocore_instrumentation(self):
# Bedrock Agent Operation
self._test_patched_bedrock_agent_instrumentation()

# TODO: remove these tests once we bump botocore instrumentation version to 0.56b0
# Bedrock Runtime
self._test_patched_converse_stream_wrapper()
self._test_patched_extract_tool_calls()

# Bedrock Agent Runtime
self.assertTrue("bedrock-agent-runtime" in _KNOWN_EXTENSIONS)
bedrock_agent_runtime_attributes: Dict[str, str] = _do_extract_attributes_bedrock("bedrock-agent-runtime")
Expand Down Expand Up @@ -470,6 +479,127 @@ def _test_patched_bedrock_instrumentation(self):
self.assertEqual(len(bedrock_sucess_attributes), 1)
self.assertEqual(bedrock_sucess_attributes["aws.bedrock.guardrail.id"], _BEDROCK_GUARDRAIL_ID)

def _test_unpatched_extract_tool_calls(self):
"""Test unpatched extract_tool_calls with string content throws AttributeError"""
message_with_string_content = {"role": "assistant", "content": "{"}
with self.assertRaises(AttributeError):
bedrock_utils.extract_tool_calls(message_with_string_content, True)

def _test_unpatched_converse_stream_wrapper(self):
"""Test unpatched bedrock-runtime where input values remain as numbers"""

mock_stream = MagicMock()
mock_span = MagicMock()
mock_stream_error_callback = MagicMock()

wrapper = bedrock_utils.ConverseStreamWrapper(mock_stream, mock_span, mock_stream_error_callback)
wrapper._record_message = True
wrapper._message = {"role": "assistant", "content": []}

start_event = {
"contentBlockStart": {
"start": {
"toolUse": {
"toolUseId": "random_id",
"name": "example",
"input": '{"input": 999999999999999999}',
}
},
"contentBlockIndex": 0,
}
}
wrapper._process_event(start_event)

# Validate that _content_block contains toolUse input that has been JSON decoded
self.assertIn("toolUse", wrapper._content_block)
self.assertIn("input", wrapper._content_block["toolUse"])
self.assertIn("input", wrapper._content_block["toolUse"]["input"])
# Validate that input values are numbers (unpatched behavior)
self.assertIsInstance(wrapper._content_block["toolUse"]["input"]["input"], int)
self.assertEqual(wrapper._content_block["toolUse"]["input"]["input"], 999999999999999999)

stop_event = {"contentBlockStop": {"contentBlockIndex": 0}}
wrapper._process_event(stop_event)

expected_tool_use = {
"toolUseId": "random_id",
"name": "example",
"input": {"input": 999999999999999999},
}
self.assertEqual(len(wrapper._message["content"]), 1)
self.assertEqual(wrapper._message["content"][0]["toolUse"], expected_tool_use)

def _test_patched_converse_stream_wrapper(self):
"""Test patched bedrock-runtime"""

# Create mock arguments for ConverseStreamWrapper
mock_stream = MagicMock()
mock_span = MagicMock()
mock_stream_error_callback = MagicMock()

# Create real ConverseStreamWrapper with mocked arguments
wrapper = bedrock_utils.ConverseStreamWrapper(mock_stream, mock_span, mock_stream_error_callback)
wrapper._record_message = True
wrapper._message = {"role": "assistant", "content": []}

# Test contentBlockStart
start_event = {
"contentBlockStart": {
"start": {
"toolUse": {
"toolUseId": "random_id",
"name": "example",
"input": '{"input": 999999999999999999}',
}
},
"contentBlockIndex": 0,
}
}

wrapper._process_event(start_event)

# Validate that _content_block contains toolUse input as literal string (patched behavior)
self.assertIn("toolUse", wrapper._content_block)
self.assertIn("input", wrapper._content_block["toolUse"])
# Validate that input is a string containing the literal JSON (not decoded)
self.assertIsInstance(wrapper._content_block["toolUse"]["input"], str)
self.assertEqual(wrapper._content_block["toolUse"]["input"], '{"input": 999999999999999999}')

# Test contentBlockDelta events
delta_events = [
{"contentBlockDelta": {"delta": {"toolUse": {"input": '{"in'}}, "contentBlockIndex": 0}},
{"contentBlockDelta": {"delta": {"toolUse": {"input": 'put": 9'}}, "contentBlockIndex": 0}},
{"contentBlockDelta": {"delta": {"toolUse": {"input": "99"}}, "contentBlockIndex": 0}},
{"contentBlockDelta": {"delta": {"toolUse": {"input": "99"}}, "contentBlockIndex": 0}},
]

for delta_event in delta_events:
wrapper._process_event(delta_event)

# Verify accumulated input buffer
self.assertEqual(wrapper._tool_json_input_buf, '{"input": 99999')

# Test contentBlockStop
stop_event = {"contentBlockStop": {"contentBlockIndex": 0}}
wrapper._process_event(stop_event)

# Verify final content_block toolUse value (input becomes the accumulated JSON string)
expected_tool_use = {
"toolUseId": "random_id",
"name": "example",
"input": '{"input": 99999',
}
self.assertEqual(len(wrapper._message["content"]), 1)
self.assertEqual(wrapper._message["content"][0]["toolUse"], expected_tool_use)

def _test_patched_extract_tool_calls(self):
"""Test patched extract_tool_calls with string content"""

# Test extract_tool_calls with string content (should return None)
message_with_string_content = {"role": "assistant", "content": "{"}
result = bedrock_utils.extract_tool_calls(message_with_string_content, True)
self.assertIsNone(result)

def _test_patched_bedrock_agent_instrumentation(self):
"""For bedrock-agent service, both extract_attributes and on_success provides attributes,
the attributes depend on the API being invoked."""
Expand Down
Loading