|
1 | 1 | # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
2 | 2 | # SPDX-License-Identifier: Apache-2.0
|
3 |
| -import unittest |
| 3 | +from unittest import TestCase |
4 | 4 | from unittest.mock import MagicMock, patch
|
5 | 5 |
|
6 | 6 | from amazon.opentelemetry.distro.patches._starlette_patches import _apply_starlette_instrumentation_patches
|
7 | 7 |
|
8 | 8 |
|
9 |
| -class TestStarlettePatch(unittest.TestCase): |
| 9 | +class TestStarlettePatch(TestCase): |
10 | 10 | """Test the Starlette instrumentation patches."""
|
11 | 11 |
|
12 | 12 | @patch("amazon.opentelemetry.distro.patches._starlette_patches._logger")
|
13 | 13 | def test_starlette_patch_applied_successfully(self, mock_logger):
|
14 | 14 | """Test that the Starlette instrumentation patch is applied successfully."""
|
15 |
| - for agent_enabled in [True, False]: |
16 |
| - with self.subTest(agent_enabled=agent_enabled): |
17 |
| - with patch.dict("os.environ", {"AGENT_OBSERVABILITY_ENABLED": "true" if agent_enabled else "false"}): |
| 15 | + for patched_starlette_enabled in [True, False]: |
| 16 | + with self.subTest(agent_enabled=patched_starlette_enabled): |
| 17 | + with patch.dict( |
| 18 | + "os.environ", {"AGENT_OBSERVABILITY_ENABLED": "true" if patched_starlette_enabled else "false"} |
| 19 | + ): |
18 | 20 | # Create a mock StarletteInstrumentor class
|
19 | 21 | mock_instrumentor_class = MagicMock()
|
20 | 22 | mock_instrumentor_class.__name__ = "StarletteInstrumentor"
|
@@ -52,13 +54,18 @@ def __init__(self, app, **kwargs):
|
52 | 54 | result = mock_instrumentor_class.instrumentation_dependencies(mock_instance)
|
53 | 55 | self.assertEqual(result, ("starlette >= 0.13",))
|
54 | 56 |
|
| 57 | + mock_app = MagicMock() |
| 58 | + if patched_starlette_enabled: |
| 59 | + mock_app.user_middleware = [] |
| 60 | + mock_app.middleware_stack = None |
| 61 | + |
55 | 62 | mock_middleware_instance = MagicMock()
|
56 | 63 | mock_middleware_instance.exclude_receive_span = False
|
57 | 64 | mock_middleware_instance.exclude_send_span = False
|
58 |
| - mock_middleware_class.__init__(mock_middleware_instance, "app") |
| 65 | + mock_middleware_class.__init__(mock_middleware_instance, mock_app) |
59 | 66 |
|
60 | 67 | # Test middleware patching sets exclude flags
|
61 |
| - if agent_enabled: |
| 68 | + if patched_starlette_enabled: |
62 | 69 | self.assertTrue(mock_middleware_instance.exclude_receive_span)
|
63 | 70 | self.assertTrue(mock_middleware_instance.exclude_send_span)
|
64 | 71 | else:
|
|
0 commit comments