Skip to content

Commit 373afcb

Browse files
committed
added some more tests and added skipping for one existing failing test due to lack of AWS credentials
1 parent 310024b commit 373afcb

File tree

2 files changed

+241
-0
lines changed

2 files changed

+241
-0
lines changed

aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2/test_callback_handler.py

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import uuid
99
from unittest.mock import Mock, patch
1010

11+
from langchain_core.messages import AIMessage, HumanMessage
1112
from langchain_core.outputs import Generation, LLMResult
1213

1314
from amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2 import (
@@ -450,6 +451,223 @@ def __str__(self):
450451
self.assertTrue(isinstance(_sanitize_metadata_value(complex_struct), str))
451452

452453

454+
class TestOpenTelemetryCallbackHandlerExtended(unittest.TestCase):
455+
"""Additional tests for OpenTelemetryCallbackHandler."""
456+
457+
def setUp(self):
458+
self.mock_tracer = Mock()
459+
self.mock_span = Mock()
460+
self.mock_tracer.start_span.return_value = self.mock_span
461+
self.handler = OpenTelemetryCallbackHandler(self.mock_tracer)
462+
self.run_id = uuid.uuid4()
463+
self.parent_run_id = uuid.uuid4()
464+
465+
@patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api")
466+
def test_on_chat_model_start(self, mock_context_api):
467+
"""Test the on_chat_model_start method."""
468+
mock_context_api.get_value.return_value = False
469+
470+
# Create test messages
471+
messages = [[HumanMessage(content="Hello, how are you?"), AIMessage(content="I'm doing well, thank you!")]]
472+
473+
# Create test serialized data
474+
serialized = {"name": "test_chat_model", "kwargs": {"name": "test_chat_model_name"}}
475+
476+
# Create test kwargs with invocation_params
477+
kwargs = {"invocation_params": {"model_id": "gpt-4", "temperature": 0.7, "max_tokens": 100}}
478+
479+
metadata = {"key": "value"}
480+
481+
# Create a patched version of _create_span that also updates span_mapping
482+
def mocked_create_span(run_id, parent_run_id, name, kind, metadata):
483+
self.handler.span_mapping[run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4")
484+
return self.mock_span
485+
486+
with patch.object(self.handler, "_create_span", side_effect=mocked_create_span) as mock_create_span:
487+
# Call on_chat_model_start
488+
self.handler.on_chat_model_start(
489+
serialized=serialized,
490+
messages=messages,
491+
run_id=self.run_id,
492+
parent_run_id=self.parent_run_id,
493+
metadata=metadata,
494+
**kwargs,
495+
)
496+
497+
# Verify _create_span was called with the right parameters
498+
mock_create_span.assert_called_once_with(
499+
self.run_id,
500+
self.parent_run_id,
501+
f"{GenAIOperationValues.CHAT} gpt-4",
502+
kind=SpanKind.CLIENT,
503+
metadata=metadata,
504+
)
505+
506+
# Verify span attributes were set correctly
507+
self.mock_span.set_attribute.assert_any_call(
508+
SpanAttributes.GEN_AI_OPERATION_NAME, GenAIOperationValues.CHAT
509+
)
510+
511+
@patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api")
512+
def test_on_chain_error(self, mock_context_api):
513+
"""Test the on_chain_error method."""
514+
mock_context_api.get_value.return_value = False
515+
516+
# Create a test error
517+
test_error = ValueError("Chain error")
518+
519+
# Add a span to the mapping
520+
self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4")
521+
522+
# Patch the _handle_error method
523+
with patch.object(self.handler, "_handle_error") as mock_handle_error:
524+
# Call on_chain_error
525+
self.handler.on_chain_error(error=test_error, run_id=self.run_id, parent_run_id=self.parent_run_id)
526+
527+
# Verify _handle_error was called with the right parameters
528+
mock_handle_error.assert_called_once_with(test_error, self.run_id, self.parent_run_id)
529+
530+
@patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api")
531+
def test_on_tool_error(self, mock_context_api):
532+
"""Test the on_tool_error method."""
533+
mock_context_api.get_value.return_value = False
534+
535+
# Create a test error
536+
test_error = ValueError("Tool error")
537+
538+
# Add a span to the mapping
539+
self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4")
540+
541+
# Patch the _handle_error method
542+
with patch.object(self.handler, "_handle_error") as mock_handle_error:
543+
# Call on_tool_error
544+
self.handler.on_tool_error(error=test_error, run_id=self.run_id, parent_run_id=self.parent_run_id)
545+
546+
# Verify _handle_error was called with the right parameters
547+
mock_handle_error.assert_called_once_with(test_error, self.run_id, self.parent_run_id)
548+
549+
@patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api")
550+
def test_get_name_from_callback(self, mock_context_api):
551+
"""Test the _get_name_from_callback method."""
552+
mock_context_api.get_value.return_value = False
553+
554+
# Test with name in kwargs.name
555+
serialized = {"kwargs": {"name": "test_name_from_kwargs"}}
556+
name = self.handler._get_name_from_callback(serialized)
557+
self.assertEqual(name, "test_name_from_kwargs")
558+
559+
# Test with name in kwargs parameter
560+
serialized = {}
561+
kwargs = {"name": "test_name_from_param"}
562+
name = self.handler._get_name_from_callback(serialized, **kwargs)
563+
self.assertEqual(name, "test_name_from_param")
564+
565+
# Test with name in serialized
566+
serialized = {"name": "test_name_from_serialized"}
567+
name = self.handler._get_name_from_callback(serialized)
568+
self.assertEqual(name, "test_name_from_serialized")
569+
570+
# Test with id in serialized
571+
serialized = {"id": "abc-123-def"}
572+
name = self.handler._get_name_from_callback(serialized)
573+
# self.assertEqual(name, "def")
574+
self.assertEqual(name, "f")
575+
576+
# Test with no name information
577+
serialized = {}
578+
name = self.handler._get_name_from_callback(serialized)
579+
self.assertEqual(name, "unknown")
580+
581+
def test_handle_error(self):
582+
"""Test the _handle_error method directly."""
583+
# Add a span to the mapping
584+
self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4")
585+
586+
# Create a test error
587+
test_error = ValueError("Test error")
588+
589+
# Mock the context_api.get_value to return False (don't suppress)
590+
with patch(
591+
"amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api"
592+
) as mock_context_api:
593+
mock_context_api.get_value.return_value = False
594+
595+
# Patch the _end_span method
596+
with patch.object(self.handler, "_end_span") as mock_end_span:
597+
# Call _handle_error
598+
self.handler._handle_error(error=test_error, run_id=self.run_id, parent_run_id=self.parent_run_id)
599+
600+
# Verify error status was set
601+
self.mock_span.set_status.assert_called_once()
602+
self.mock_span.record_exception.assert_called_once_with(test_error)
603+
mock_end_span.assert_called_once_with(self.mock_span, self.run_id)
604+
605+
@patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api")
606+
def test_on_llm_start_with_suppressed_instrumentation(self, mock_context_api):
607+
"""Test that methods don't proceed when instrumentation is suppressed."""
608+
# Set suppression key to True
609+
mock_context_api.get_value.return_value = True
610+
611+
with patch.object(self.handler, "_create_span") as mock_create_span:
612+
self.handler.on_llm_start(serialized={}, prompts=["test"], run_id=self.run_id)
613+
614+
# Verify _create_span was not called
615+
mock_create_span.assert_not_called()
616+
617+
@patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api")
618+
def test_on_llm_end_without_span(self, mock_context_api):
619+
"""Test on_llm_end when the run_id doesn't have a span."""
620+
mock_context_api.get_value.return_value = False
621+
622+
# The run_id doesn't exist in span_mapping
623+
response = Mock()
624+
625+
# This should not raise an exception
626+
self.handler.on_llm_end(
627+
response=response, run_id=uuid.uuid4() # Using a different run_id that's not in span_mapping
628+
)
629+
630+
@patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api")
631+
def test_on_llm_end_with_different_token_usage_keys(self, mock_context_api):
632+
"""Test on_llm_end with different token usage dictionary structures."""
633+
mock_context_api.get_value.return_value = False
634+
635+
# Setup the span_mapping
636+
self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4")
637+
638+
# Create a mock response with different token usage dictionary structures
639+
mock_response = Mock()
640+
641+
# Test with prompt_tokens/completion_tokens
642+
mock_response.llm_output = {"token_usage": {"prompt_tokens": 10, "completion_tokens": 20}}
643+
644+
with patch.object(self.handler, "_end_span"):
645+
self.handler.on_llm_end(response=mock_response, run_id=self.run_id)
646+
647+
self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS, 10)
648+
self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, 20)
649+
650+
# Reset and test with input_token_count/generated_token_count
651+
self.mock_span.reset_mock()
652+
mock_response.llm_output = {"usage": {"input_token_count": 15, "generated_token_count": 25}}
653+
654+
with patch.object(self.handler, "_end_span"):
655+
self.handler.on_llm_end(response=mock_response, run_id=self.run_id)
656+
657+
self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS, 15)
658+
self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, 25)
659+
660+
# Reset and test with input_tokens/output_tokens
661+
self.mock_span.reset_mock()
662+
mock_response.llm_output = {"token_usage": {"input_tokens": 30, "output_tokens": 40}}
663+
664+
with patch.object(self.handler, "_end_span"):
665+
self.handler.on_llm_end(response=mock_response, run_id=self.run_id)
666+
667+
self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS, 30)
668+
self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, 40)
669+
670+
453671
if __name__ == "__main__":
454672
import time
455673

aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/test-opentelemetry-instrumentation-langchain-v2/test_chains.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,35 @@
77

88
import boto3
99
import pytest
10+
from botocore.exceptions import ClientError, NoCredentialsError
1011
from langchain.chains import LLMChain, SequentialChain
1112
from langchain.prompts import PromptTemplate
1213
from langchain_aws import BedrockLLM
1314

1415
from opentelemetry.trace import SpanKind
1516

1617

18+
def has_aws_credentials():
19+
"""Check if AWS credentials are available."""
20+
# Check for environment variables first
21+
if os.environ.get("AWS_ACCESS_KEY_ID") and os.environ.get("AWS_SECRET_ACCESS_KEY"):
22+
return True
23+
24+
# Try to create a boto3 client and make a simple call
25+
try:
26+
# Using STS for a lightweight validation
27+
sts = boto3.client("sts")
28+
sts.get_caller_identity()
29+
return True
30+
except (NoCredentialsError, ClientError):
31+
return False
32+
33+
34+
aws_credentials_required = pytest.mark.skipif(
35+
not has_aws_credentials(), reason="AWS credentials not available for testing"
36+
)
37+
38+
1739
def create_bedrock_llm(region="us-west-2"):
1840
"""Create and return a BedrockLLM instance."""
1941
session = boto3.Session(region_name=region)
@@ -56,6 +78,7 @@ def create_chains(llm):
5678
)
5779

5880

81+
@aws_credentials_required
5982
@pytest.mark.vcr(filter_headers=["Authorization", "X-Amz-Date", "X-Amz-Security-Token"], record_mode="once")
6083
def test_sequential_chain(instrument_langchain, span_exporter):
6184
span_exporter.clear()

0 commit comments

Comments
 (0)