@@ -12,31 +12,63 @@ class TestStarlettePatch(TestCase):
1212 @patch ("amazon.opentelemetry.distro.patches._starlette_patches._logger" )
1313 def test_starlette_patch_applied_successfully (self , mock_logger ):
1414 """Test that the Starlette instrumentation patch is applied successfully."""
15- # Create a mock StarletteInstrumentor class
16- mock_instrumentor_class = MagicMock ()
17- mock_instrumentor_class .__name__ = "StarletteInstrumentor"
18-
19- # Create a mock module
20- mock_starlette_module = MagicMock ()
21- mock_starlette_module .StarletteInstrumentor = mock_instrumentor_class
22-
23- # Mock the import
24- with patch .dict ("sys.modules" , {"opentelemetry.instrumentation.starlette" : mock_starlette_module }):
25- # Apply the patch
26- _apply_starlette_instrumentation_patches ()
27-
28- # Verify the instrumentation_dependencies method was replaced
29- self .assertTrue (hasattr (mock_instrumentor_class , "instrumentation_dependencies" ))
30-
31- # Test the patched method returns the expected value
32- mock_instance = MagicMock ()
33- result = mock_instrumentor_class .instrumentation_dependencies (mock_instance )
34- self .assertEqual (result , ("starlette >= 0.13" ,))
35-
36- # Verify logging
37- mock_logger .debug .assert_called_once_with (
38- "Successfully patched Starlette instrumentation_dependencies method"
39- )
15+ for agent_enabled in [True , False ]:
16+ with self .subTest (agent_enabled = agent_enabled ):
17+ with patch .dict ("os.environ" , {"AGENT_OBSERVABILITY_ENABLED" : "true" if agent_enabled else "false" }):
18+ # Create a mock StarletteInstrumentor class
19+ mock_instrumentor_class = MagicMock ()
20+ mock_instrumentor_class .__name__ = "StarletteInstrumentor"
21+
22+ def create_middleware_class ():
23+ class MockMiddleware :
24+ def __init__ (self , app , ** kwargs ):
25+ pass
26+
27+ return MockMiddleware
28+
29+ mock_middleware_class = create_middleware_class ()
30+
31+ mock_starlette_module = MagicMock ()
32+ mock_starlette_module .StarletteInstrumentor = mock_instrumentor_class
33+
34+ mock_asgi_module = MagicMock ()
35+ mock_asgi_module .OpenTelemetryMiddleware = mock_middleware_class
36+
37+ with patch .dict (
38+ "sys.modules" ,
39+ {
40+ "opentelemetry.instrumentation.starlette" : mock_starlette_module ,
41+ "opentelemetry.instrumentation.asgi" : mock_asgi_module ,
42+ },
43+ ):
44+ # Apply the patch
45+ _apply_starlette_instrumentation_patches ()
46+
47+ # Verify the instrumentation_dependencies method was replaced
48+ self .assertTrue (hasattr (mock_instrumentor_class , "instrumentation_dependencies" ))
49+
50+ # Test the patched method returns the expected value
51+ mock_instance = MagicMock ()
52+ result = mock_instrumentor_class .instrumentation_dependencies (mock_instance )
53+ self .assertEqual (result , ("starlette >= 0.13" ,))
54+
55+ mock_middleware_instance = MagicMock ()
56+ mock_middleware_instance .exclude_receive_span = False
57+ mock_middleware_instance .exclude_send_span = False
58+ mock_middleware_class .__init__ (mock_middleware_instance , "app" )
59+
60+ # Test middleware patching sets exclude flags
61+ if agent_enabled :
62+ self .assertTrue (mock_middleware_instance .exclude_receive_span )
63+ self .assertTrue (mock_middleware_instance .exclude_send_span )
64+ else :
65+ self .assertFalse (mock_middleware_instance .exclude_receive_span )
66+ self .assertFalse (mock_middleware_instance .exclude_send_span )
67+
68+ # Verify logging
69+ mock_logger .debug .assert_called_with (
70+ "Successfully patched Starlette instrumentation_dependencies method"
71+ )
4072
4173 @patch ("amazon.opentelemetry.distro.patches._starlette_patches._logger" )
4274 def test_starlette_patch_handles_import_error (self , mock_logger ):
0 commit comments