@@ -12,31 +12,63 @@ class TestStarlettePatch(TestCase):
1212    @patch ("amazon.opentelemetry.distro.patches._starlette_patches._logger" ) 
1313    def  test_starlette_patch_applied_successfully (self , mock_logger ):
1414        """Test that the Starlette instrumentation patch is applied successfully.""" 
15-         # Create a mock StarletteInstrumentor class 
16-         mock_instrumentor_class  =  MagicMock ()
17-         mock_instrumentor_class .__name__  =  "StarletteInstrumentor" 
18- 
19-         # Create a mock module 
20-         mock_starlette_module  =  MagicMock ()
21-         mock_starlette_module .StarletteInstrumentor  =  mock_instrumentor_class 
22- 
23-         # Mock the import 
24-         with  patch .dict ("sys.modules" , {"opentelemetry.instrumentation.starlette" : mock_starlette_module }):
25-             # Apply the patch 
26-             _apply_starlette_instrumentation_patches ()
27- 
28-             # Verify the instrumentation_dependencies method was replaced 
29-             self .assertTrue (hasattr (mock_instrumentor_class , "instrumentation_dependencies" ))
30- 
31-             # Test the patched method returns the expected value 
32-             mock_instance  =  MagicMock ()
33-             result  =  mock_instrumentor_class .instrumentation_dependencies (mock_instance )
34-             self .assertEqual (result , ("starlette >= 0.13" ,))
35- 
36-             # Verify logging 
37-             mock_logger .debug .assert_called_once_with (
38-                 "Successfully patched Starlette instrumentation_dependencies method" 
39-             )
15+         for  agent_enabled  in  [True , False ]:
16+             with  self .subTest (agent_enabled = agent_enabled ):
17+                 with  patch .dict ("os.environ" , {"AGENT_OBSERVABILITY_ENABLED" : "true"  if  agent_enabled  else  "false" }):
18+                     # Create a mock StarletteInstrumentor class 
19+                     mock_instrumentor_class  =  MagicMock ()
20+                     mock_instrumentor_class .__name__  =  "StarletteInstrumentor" 
21+ 
22+                     def  create_middleware_class ():
23+                         class  MockMiddleware :
24+                             def  __init__ (self , app , ** kwargs ):
25+                                 pass 
26+ 
27+                         return  MockMiddleware 
28+ 
29+                     mock_middleware_class  =  create_middleware_class ()
30+ 
31+                     mock_starlette_module  =  MagicMock ()
32+                     mock_starlette_module .StarletteInstrumentor  =  mock_instrumentor_class 
33+ 
34+                     mock_asgi_module  =  MagicMock ()
35+                     mock_asgi_module .OpenTelemetryMiddleware  =  mock_middleware_class 
36+ 
37+                     with  patch .dict (
38+                         "sys.modules" ,
39+                         {
40+                             "opentelemetry.instrumentation.starlette" : mock_starlette_module ,
41+                             "opentelemetry.instrumentation.asgi" : mock_asgi_module ,
42+                         },
43+                     ):
44+                         # Apply the patch 
45+                         _apply_starlette_instrumentation_patches ()
46+ 
47+                         # Verify the instrumentation_dependencies method was replaced 
48+                         self .assertTrue (hasattr (mock_instrumentor_class , "instrumentation_dependencies" ))
49+ 
50+                         # Test the patched method returns the expected value 
51+                         mock_instance  =  MagicMock ()
52+                         result  =  mock_instrumentor_class .instrumentation_dependencies (mock_instance )
53+                         self .assertEqual (result , ("starlette >= 0.13" ,))
54+ 
55+                         mock_middleware_instance  =  MagicMock ()
56+                         mock_middleware_instance .exclude_receive_span  =  False 
57+                         mock_middleware_instance .exclude_send_span  =  False 
58+                         mock_middleware_class .__init__ (mock_middleware_instance , "app" )
59+ 
60+                         # Test middleware patching sets exclude flags 
61+                         if  agent_enabled :
62+                             self .assertTrue (mock_middleware_instance .exclude_receive_span )
63+                             self .assertTrue (mock_middleware_instance .exclude_send_span )
64+                         else :
65+                             self .assertFalse (mock_middleware_instance .exclude_receive_span )
66+                             self .assertFalse (mock_middleware_instance .exclude_send_span )
67+ 
68+                         # Verify logging 
69+                         mock_logger .debug .assert_called_with (
70+                             "Successfully patched Starlette instrumentation_dependencies method" 
71+                         )
4072
4173    @patch ("amazon.opentelemetry.distro.patches._starlette_patches._logger" ) 
4274    def  test_starlette_patch_handles_import_error (self , mock_logger ):
0 commit comments