@@ -592,9 +592,15 @@ def test_instrument_connection(self):
592592 connection2 = dbapi .instrument_connection (self .tracer , connection , "-" )
593593 self .assertIs (connection2 .__wrapped__ , connection )
594594
595+ @mock .patch (
596+ "opentelemetry.instrumentation.dbapi.get_traced_connection_proxy"
597+ )
595598 @mock .patch ("opentelemetry.instrumentation.dbapi.DatabaseApiIntegration" )
596- def test_instrument_connection_kwargs_defaults (self , mock_dbapiint ):
597- dbapi .instrument_connection (self .tracer , mock .Mock (), "foo" )
599+ def test_instrument_connection_kwargs_defaults (
600+ self , mock_dbapiint , mock_get_cnx_proxy
601+ ):
602+ mock_get_cnx_proxy .return_value = "foo_cnx"
603+ cnx = dbapi .instrument_connection (self .tracer , mock .Mock (), "foo" )
598604 kwargs = mock_dbapiint .call_args [1 ]
599605 self .assertEqual (kwargs ["connection_attributes" ], None )
600606 self .assertEqual (kwargs ["version" ], "" )
@@ -603,11 +609,19 @@ def test_instrument_connection_kwargs_defaults(self, mock_dbapiint):
603609 self .assertEqual (kwargs ["enable_commenter" ], False )
604610 self .assertEqual (kwargs ["commenter_options" ], None )
605611 self .assertEqual (kwargs ["connect_module" ], None )
612+ assert cnx == "foo_cnx"
606613
614+ @mock .patch (
615+ "opentelemetry.instrumentation.dbapi.get_traced_connection_proxy"
616+ )
607617 @mock .patch ("opentelemetry.instrumentation.dbapi.DatabaseApiIntegration" )
608- def test_instrument_connection_kwargs_provided (self , mock_dbapiint ):
618+ def test_instrument_connection_kwargs_provided (
619+ self , mock_dbapiint , mock_get_cnx_proxy
620+ ):
609621 mock_tracer_provider = mock .MagicMock ()
610622 mock_connect_module = mock .MagicMock ()
623+ mock_custom_dbapiint = mock .MagicMock ()
624+ mock_custom_get_cnx_proxy = mock .MagicMock ()
611625 dbapi .instrument_connection (
612626 self .tracer ,
613627 mock .Mock (),
@@ -619,15 +633,20 @@ def test_instrument_connection_kwargs_provided(self, mock_dbapiint):
619633 enable_commenter = True ,
620634 commenter_options = {"foo" : "bar" },
621635 connect_module = mock_connect_module ,
636+ db_api_integration_factory = mock_custom_dbapiint ,
637+ get_cnx_proxy = mock_custom_get_cnx_proxy ,
622638 )
623- kwargs = mock_dbapiint .call_args [1 ]
639+ mock_dbapiint .assert_not_called ()
640+ kwargs = mock_custom_dbapiint .call_args [1 ]
624641 self .assertEqual (kwargs ["connection_attributes" ], {"foo" : "bar" })
625642 self .assertEqual (kwargs ["version" ], "test" )
626643 self .assertIs (kwargs ["tracer_provider" ], mock_tracer_provider )
627644 self .assertEqual (kwargs ["capture_parameters" ], True )
628645 self .assertEqual (kwargs ["enable_commenter" ], True )
629646 self .assertEqual (kwargs ["commenter_options" ], {"foo" : "bar" })
630647 self .assertIs (kwargs ["connect_module" ], mock_connect_module )
648+ mock_get_cnx_proxy .assert_not_called ()
649+ mock_custom_get_cnx_proxy .assert_called_once ()
631650
632651 def test_uninstrument_connection (self ):
633652 connection = mock .Mock ()
0 commit comments