1313# limitations under the License.
1414
1515import types
16+ from typing import Optional
1617from unittest import IsolatedAsyncioTestCase , mock
1718
1819import psycopg
@@ -83,10 +84,14 @@ class MockConnection:
8384
8485 def __init__ (self , * args , ** kwargs ):
8586 self .cursor_factory = kwargs .pop ("cursor_factory" , None )
87+ self .server_cursor_factory = None
8688
87- def cursor (self ):
88- if self .cursor_factory :
89+ def cursor (self , name : Optional [ str ] = None ):
90+ if not name and self .cursor_factory :
8991 return self .cursor_factory (self )
92+
93+ if name and self .server_cursor_factory :
94+ return self .server_cursor_factory (self )
9095 return MockCursor ()
9196
9297 def get_dsn_parameters (self ): # pylint: disable=no-self-use
@@ -102,15 +107,18 @@ class MockAsyncConnection:
102107
103108 def __init__ (self , * args , ** kwargs ):
104109 self .cursor_factory = kwargs .pop ("cursor_factory" , None )
110+ self .server_cursor_factory = None
105111
106112 @staticmethod
107113 async def connect (* args , ** kwargs ):
108114 return MockAsyncConnection (** kwargs )
109115
110- def cursor (self ):
111- if self .cursor_factory :
112- cur = self .cursor_factory (self )
113- return cur
116+ def cursor (self , name : Optional [str ] = None ):
117+ if not name and self .cursor_factory :
118+ return self .cursor_factory (self )
119+
120+ if name and self .server_cursor_factory :
121+ return self .server_cursor_factory (self )
114122 return MockAsyncCursor ()
115123
116124 def execute (self , query , params = None , * , prepare = None , binary = False ):
@@ -197,6 +205,36 @@ def test_instrumentor(self):
197205 spans_list = self .memory_exporter .get_finished_spans ()
198206 self .assertEqual (len (spans_list ), 1 )
199207
208+ def test_instrumentor_with_named_cursor (self ):
209+ PsycopgInstrumentor ().instrument ()
210+
211+ cnx = psycopg .connect (database = "test" )
212+
213+ cursor = cnx .cursor (name = "named_cursor" )
214+
215+ query = "SELECT * FROM test"
216+ cursor .execute (query )
217+
218+ spans_list = self .memory_exporter .get_finished_spans ()
219+ self .assertEqual (len (spans_list ), 1 )
220+ span = spans_list [0 ]
221+
222+ # Check version and name in span's instrumentation info
223+ self .assertEqualSpanInstrumentationScope (
224+ span , opentelemetry .instrumentation .psycopg
225+ )
226+
227+ # check that no spans are generated after uninstrument
228+ PsycopgInstrumentor ().uninstrument ()
229+
230+ cnx = psycopg .connect (database = "test" )
231+ cursor = cnx .cursor (name = "named_cursor" )
232+ query = "SELECT * FROM test"
233+ cursor .execute (query )
234+
235+ spans_list = self .memory_exporter .get_finished_spans ()
236+ self .assertEqual (len (spans_list ), 1 )
237+
200238 # pylint: disable=unused-argument
201239 def test_instrumentor_with_connection_class (self ):
202240 PsycopgInstrumentor ().instrument ()
@@ -228,6 +266,36 @@ def test_instrumentor_with_connection_class(self):
228266 spans_list = self .memory_exporter .get_finished_spans ()
229267 self .assertEqual (len (spans_list ), 1 )
230268
269+ def test_instrumentor_with_connection_class_and_named_cursor (self ):
270+ PsycopgInstrumentor ().instrument ()
271+
272+ cnx = psycopg .Connection .connect (database = "test" )
273+
274+ cursor = cnx .cursor (name = "named_cursor" )
275+
276+ query = "SELECT * FROM test"
277+ cursor .execute (query )
278+
279+ spans_list = self .memory_exporter .get_finished_spans ()
280+ self .assertEqual (len (spans_list ), 1 )
281+ span = spans_list [0 ]
282+
283+ # Check version and name in span's instrumentation info
284+ self .assertEqualSpanInstrumentationScope (
285+ span , opentelemetry .instrumentation .psycopg
286+ )
287+
288+ # check that no spans are generated after uninstrument
289+ PsycopgInstrumentor ().uninstrument ()
290+
291+ cnx = psycopg .Connection .connect (database = "test" )
292+ cursor = cnx .cursor (name = "named_cursor" )
293+ query = "SELECT * FROM test"
294+ cursor .execute (query )
295+
296+ spans_list = self .memory_exporter .get_finished_spans ()
297+ self .assertEqual (len (spans_list ), 1 )
298+
231299 def test_span_name (self ):
232300 PsycopgInstrumentor ().instrument ()
233301
@@ -314,6 +382,23 @@ def test_instrument_connection(self):
314382 spans_list = self .memory_exporter .get_finished_spans ()
315383 self .assertEqual (len (spans_list ), 1 )
316384
385+ # pylint: disable=unused-argument
386+ def test_instrument_connection_with_named_cursor (self ):
387+ cnx = psycopg .connect (database = "test" )
388+ query = "SELECT * FROM test"
389+ cursor = cnx .cursor (name = "named_cursor" )
390+ cursor .execute (query )
391+
392+ spans_list = self .memory_exporter .get_finished_spans ()
393+ self .assertEqual (len (spans_list ), 0 )
394+
395+ cnx = PsycopgInstrumentor ().instrument_connection (cnx )
396+ cursor = cnx .cursor (name = "named_cursor" )
397+ cursor .execute (query )
398+
399+ spans_list = self .memory_exporter .get_finished_spans ()
400+ self .assertEqual (len (spans_list ), 1 )
401+
317402 # pylint: disable=unused-argument
318403 def test_instrument_connection_with_instrument (self ):
319404 cnx = psycopg .connect (database = "test" )
@@ -368,6 +453,23 @@ def test_uninstrument_connection_with_instrument_connection(self):
368453 spans_list = self .memory_exporter .get_finished_spans ()
369454 self .assertEqual (len (spans_list ), 1 )
370455
456+ def test_uninstrument_connection_with_instrument_connection_and_named_cursor (self ):
457+ cnx = psycopg .connect (database = "test" )
458+ PsycopgInstrumentor ().instrument_connection (cnx )
459+ query = "SELECT * FROM test"
460+ cursor = cnx .cursor (name = "named_cursor" )
461+ cursor .execute (query )
462+
463+ spans_list = self .memory_exporter .get_finished_spans ()
464+ self .assertEqual (len (spans_list ), 1 )
465+
466+ cnx = PsycopgInstrumentor ().uninstrument_connection (cnx )
467+ cursor = cnx .cursor (name = "named_cursor" )
468+ cursor .execute (query )
469+
470+ spans_list = self .memory_exporter .get_finished_spans ()
471+ self .assertEqual (len (spans_list ), 1 )
472+
371473 @mock .patch ("opentelemetry.instrumentation.dbapi.wrap_connect" )
372474 def test_sqlcommenter_enabled (self , event_mocked ):
373475 cnx = psycopg .connect (database = "test" )
@@ -419,6 +521,33 @@ async def test_async_connection():
419521 spans_list = self .memory_exporter .get_finished_spans ()
420522 self .assertEqual (len (spans_list ), 1 )
421523
524+ async def test_wrap_async_connection_class_with_named_cursor (self ):
525+ PsycopgInstrumentor ().instrument ()
526+
527+ async def test_async_connection ():
528+ acnx = await psycopg .AsyncConnection .connect ("test" )
529+ async with acnx as cnx :
530+ async with cnx .cursor (name = "named_cursor" ) as cursor :
531+ await cursor .execute ("SELECT * FROM test" )
532+
533+ await test_async_connection ()
534+ spans_list = self .memory_exporter .get_finished_spans ()
535+ self .assertEqual (len (spans_list ), 1 )
536+ span = spans_list [0 ]
537+
538+ # Check version and name in span's instrumentation info
539+ self .assertEqualSpanInstrumentationScope (
540+ span , opentelemetry .instrumentation .psycopg
541+ )
542+
543+ # check that no spans are generated after uninstrument
544+ PsycopgInstrumentor ().uninstrument ()
545+
546+ await test_async_connection ()
547+
548+ spans_list = self .memory_exporter .get_finished_spans ()
549+ self .assertEqual (len (spans_list ), 1 )
550+
422551 # pylint: disable=unused-argument
423552 async def test_instrumentor_with_async_connection_class (self ):
424553 PsycopgInstrumentor ().instrument ()
0 commit comments