|
8 | 8 | import uuid
|
9 | 9 | from unittest.mock import Mock, patch
|
10 | 10 |
|
| 11 | +from langchain_core.messages import AIMessage, HumanMessage |
11 | 12 | from langchain_core.outputs import Generation, LLMResult
|
12 | 13 |
|
13 | 14 | from amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2 import (
|
@@ -450,6 +451,223 @@ def __str__(self):
|
450 | 451 | self.assertTrue(isinstance(_sanitize_metadata_value(complex_struct), str))
|
451 | 452 |
|
452 | 453 |
|
| 454 | +class TestOpenTelemetryCallbackHandlerExtended(unittest.TestCase): |
| 455 | + """Additional tests for OpenTelemetryCallbackHandler.""" |
| 456 | + |
| 457 | + def setUp(self): |
| 458 | + self.mock_tracer = Mock() |
| 459 | + self.mock_span = Mock() |
| 460 | + self.mock_tracer.start_span.return_value = self.mock_span |
| 461 | + self.handler = OpenTelemetryCallbackHandler(self.mock_tracer) |
| 462 | + self.run_id = uuid.uuid4() |
| 463 | + self.parent_run_id = uuid.uuid4() |
| 464 | + |
| 465 | + @patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api") |
| 466 | + def test_on_chat_model_start(self, mock_context_api): |
| 467 | + """Test the on_chat_model_start method.""" |
| 468 | + mock_context_api.get_value.return_value = False |
| 469 | + |
| 470 | + # Create test messages |
| 471 | + messages = [[HumanMessage(content="Hello, how are you?"), AIMessage(content="I'm doing well, thank you!")]] |
| 472 | + |
| 473 | + # Create test serialized data |
| 474 | + serialized = {"name": "test_chat_model", "kwargs": {"name": "test_chat_model_name"}} |
| 475 | + |
| 476 | + # Create test kwargs with invocation_params |
| 477 | + kwargs = {"invocation_params": {"model_id": "gpt-4", "temperature": 0.7, "max_tokens": 100}} |
| 478 | + |
| 479 | + metadata = {"key": "value"} |
| 480 | + |
| 481 | + # Create a patched version of _create_span that also updates span_mapping |
| 482 | + def mocked_create_span(run_id, parent_run_id, name, kind, metadata): |
| 483 | + self.handler.span_mapping[run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4") |
| 484 | + return self.mock_span |
| 485 | + |
| 486 | + with patch.object(self.handler, "_create_span", side_effect=mocked_create_span) as mock_create_span: |
| 487 | + # Call on_chat_model_start |
| 488 | + self.handler.on_chat_model_start( |
| 489 | + serialized=serialized, |
| 490 | + messages=messages, |
| 491 | + run_id=self.run_id, |
| 492 | + parent_run_id=self.parent_run_id, |
| 493 | + metadata=metadata, |
| 494 | + **kwargs, |
| 495 | + ) |
| 496 | + |
| 497 | + # Verify _create_span was called with the right parameters |
| 498 | + mock_create_span.assert_called_once_with( |
| 499 | + self.run_id, |
| 500 | + self.parent_run_id, |
| 501 | + f"{GenAIOperationValues.CHAT} gpt-4", |
| 502 | + kind=SpanKind.CLIENT, |
| 503 | + metadata=metadata, |
| 504 | + ) |
| 505 | + |
| 506 | + # Verify span attributes were set correctly |
| 507 | + self.mock_span.set_attribute.assert_any_call( |
| 508 | + SpanAttributes.GEN_AI_OPERATION_NAME, GenAIOperationValues.CHAT |
| 509 | + ) |
| 510 | + |
| 511 | + @patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api") |
| 512 | + def test_on_chain_error(self, mock_context_api): |
| 513 | + """Test the on_chain_error method.""" |
| 514 | + mock_context_api.get_value.return_value = False |
| 515 | + |
| 516 | + # Create a test error |
| 517 | + test_error = ValueError("Chain error") |
| 518 | + |
| 519 | + # Add a span to the mapping |
| 520 | + self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4") |
| 521 | + |
| 522 | + # Patch the _handle_error method |
| 523 | + with patch.object(self.handler, "_handle_error") as mock_handle_error: |
| 524 | + # Call on_chain_error |
| 525 | + self.handler.on_chain_error(error=test_error, run_id=self.run_id, parent_run_id=self.parent_run_id) |
| 526 | + |
| 527 | + # Verify _handle_error was called with the right parameters |
| 528 | + mock_handle_error.assert_called_once_with(test_error, self.run_id, self.parent_run_id) |
| 529 | + |
| 530 | + @patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api") |
| 531 | + def test_on_tool_error(self, mock_context_api): |
| 532 | + """Test the on_tool_error method.""" |
| 533 | + mock_context_api.get_value.return_value = False |
| 534 | + |
| 535 | + # Create a test error |
| 536 | + test_error = ValueError("Tool error") |
| 537 | + |
| 538 | + # Add a span to the mapping |
| 539 | + self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4") |
| 540 | + |
| 541 | + # Patch the _handle_error method |
| 542 | + with patch.object(self.handler, "_handle_error") as mock_handle_error: |
| 543 | + # Call on_tool_error |
| 544 | + self.handler.on_tool_error(error=test_error, run_id=self.run_id, parent_run_id=self.parent_run_id) |
| 545 | + |
| 546 | + # Verify _handle_error was called with the right parameters |
| 547 | + mock_handle_error.assert_called_once_with(test_error, self.run_id, self.parent_run_id) |
| 548 | + |
| 549 | + @patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api") |
| 550 | + def test_get_name_from_callback(self, mock_context_api): |
| 551 | + """Test the _get_name_from_callback method.""" |
| 552 | + mock_context_api.get_value.return_value = False |
| 553 | + |
| 554 | + # Test with name in kwargs.name |
| 555 | + serialized = {"kwargs": {"name": "test_name_from_kwargs"}} |
| 556 | + name = self.handler._get_name_from_callback(serialized) |
| 557 | + self.assertEqual(name, "test_name_from_kwargs") |
| 558 | + |
| 559 | + # Test with name in kwargs parameter |
| 560 | + serialized = {} |
| 561 | + kwargs = {"name": "test_name_from_param"} |
| 562 | + name = self.handler._get_name_from_callback(serialized, **kwargs) |
| 563 | + self.assertEqual(name, "test_name_from_param") |
| 564 | + |
| 565 | + # Test with name in serialized |
| 566 | + serialized = {"name": "test_name_from_serialized"} |
| 567 | + name = self.handler._get_name_from_callback(serialized) |
| 568 | + self.assertEqual(name, "test_name_from_serialized") |
| 569 | + |
| 570 | + # Test with id in serialized |
| 571 | + serialized = {"id": "abc-123-def"} |
| 572 | + name = self.handler._get_name_from_callback(serialized) |
| 573 | + # self.assertEqual(name, "def") |
| 574 | + self.assertEqual(name, "f") |
| 575 | + |
| 576 | + # Test with no name information |
| 577 | + serialized = {} |
| 578 | + name = self.handler._get_name_from_callback(serialized) |
| 579 | + self.assertEqual(name, "unknown") |
| 580 | + |
| 581 | + def test_handle_error(self): |
| 582 | + """Test the _handle_error method directly.""" |
| 583 | + # Add a span to the mapping |
| 584 | + self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4") |
| 585 | + |
| 586 | + # Create a test error |
| 587 | + test_error = ValueError("Test error") |
| 588 | + |
| 589 | + # Mock the context_api.get_value to return False (don't suppress) |
| 590 | + with patch( |
| 591 | + "amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api" |
| 592 | + ) as mock_context_api: |
| 593 | + mock_context_api.get_value.return_value = False |
| 594 | + |
| 595 | + # Patch the _end_span method |
| 596 | + with patch.object(self.handler, "_end_span") as mock_end_span: |
| 597 | + # Call _handle_error |
| 598 | + self.handler._handle_error(error=test_error, run_id=self.run_id, parent_run_id=self.parent_run_id) |
| 599 | + |
| 600 | + # Verify error status was set |
| 601 | + self.mock_span.set_status.assert_called_once() |
| 602 | + self.mock_span.record_exception.assert_called_once_with(test_error) |
| 603 | + mock_end_span.assert_called_once_with(self.mock_span, self.run_id) |
| 604 | + |
| 605 | + @patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api") |
| 606 | + def test_on_llm_start_with_suppressed_instrumentation(self, mock_context_api): |
| 607 | + """Test that methods don't proceed when instrumentation is suppressed.""" |
| 608 | + # Set suppression key to True |
| 609 | + mock_context_api.get_value.return_value = True |
| 610 | + |
| 611 | + with patch.object(self.handler, "_create_span") as mock_create_span: |
| 612 | + self.handler.on_llm_start(serialized={}, prompts=["test"], run_id=self.run_id) |
| 613 | + |
| 614 | + # Verify _create_span was not called |
| 615 | + mock_create_span.assert_not_called() |
| 616 | + |
| 617 | + @patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api") |
| 618 | + def test_on_llm_end_without_span(self, mock_context_api): |
| 619 | + """Test on_llm_end when the run_id doesn't have a span.""" |
| 620 | + mock_context_api.get_value.return_value = False |
| 621 | + |
| 622 | + # The run_id doesn't exist in span_mapping |
| 623 | + response = Mock() |
| 624 | + |
| 625 | + # This should not raise an exception |
| 626 | + self.handler.on_llm_end( |
| 627 | + response=response, run_id=uuid.uuid4() # Using a different run_id that's not in span_mapping |
| 628 | + ) |
| 629 | + |
| 630 | + @patch("amazon.opentelemetry.distro.opentelemetry.instrumentation.langchain_v2.callback_handler.context_api") |
| 631 | + def test_on_llm_end_with_different_token_usage_keys(self, mock_context_api): |
| 632 | + """Test on_llm_end with different token usage dictionary structures.""" |
| 633 | + mock_context_api.get_value.return_value = False |
| 634 | + |
| 635 | + # Setup the span_mapping |
| 636 | + self.handler.span_mapping[self.run_id] = SpanHolder(self.mock_span, [], time.time(), "gpt-4") |
| 637 | + |
| 638 | + # Create a mock response with different token usage dictionary structures |
| 639 | + mock_response = Mock() |
| 640 | + |
| 641 | + # Test with prompt_tokens/completion_tokens |
| 642 | + mock_response.llm_output = {"token_usage": {"prompt_tokens": 10, "completion_tokens": 20}} |
| 643 | + |
| 644 | + with patch.object(self.handler, "_end_span"): |
| 645 | + self.handler.on_llm_end(response=mock_response, run_id=self.run_id) |
| 646 | + |
| 647 | + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS, 10) |
| 648 | + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, 20) |
| 649 | + |
| 650 | + # Reset and test with input_token_count/generated_token_count |
| 651 | + self.mock_span.reset_mock() |
| 652 | + mock_response.llm_output = {"usage": {"input_token_count": 15, "generated_token_count": 25}} |
| 653 | + |
| 654 | + with patch.object(self.handler, "_end_span"): |
| 655 | + self.handler.on_llm_end(response=mock_response, run_id=self.run_id) |
| 656 | + |
| 657 | + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS, 15) |
| 658 | + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, 25) |
| 659 | + |
| 660 | + # Reset and test with input_tokens/output_tokens |
| 661 | + self.mock_span.reset_mock() |
| 662 | + mock_response.llm_output = {"token_usage": {"input_tokens": 30, "output_tokens": 40}} |
| 663 | + |
| 664 | + with patch.object(self.handler, "_end_span"): |
| 665 | + self.handler.on_llm_end(response=mock_response, run_id=self.run_id) |
| 666 | + |
| 667 | + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_INPUT_TOKENS, 30) |
| 668 | + self.mock_span.set_attribute.assert_any_call(SpanAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, 40) |
| 669 | + |
| 670 | + |
453 | 671 | if __name__ == "__main__":
|
454 | 672 | import time
|
455 | 673 |
|
|
0 commit comments