1313# limitations under the License.
1414
1515import types
16+ from typing import Optional
1617from unittest import IsolatedAsyncioTestCase , mock
1718
1819import psycopg
@@ -83,10 +84,14 @@ class MockConnection:
8384
8485 def __init__ (self , * args , ** kwargs ):
8586 self .cursor_factory = kwargs .pop ("cursor_factory" , None )
87+ self .server_cursor_factory = lambda _ : MockCursor ()
8688
87- def cursor (self ):
88- if self .cursor_factory :
89+ def cursor (self , name : Optional [ str ] = None ):
90+ if not name and self .cursor_factory :
8991 return self .cursor_factory (self )
92+
93+ if name and self .server_cursor_factory :
94+ return self .server_cursor_factory (self )
9095 return MockCursor ()
9196
9297 def get_dsn_parameters (self ): # pylint: disable=no-self-use
@@ -102,15 +107,18 @@ class MockAsyncConnection:
102107
103108 def __init__ (self , * args , ** kwargs ):
104109 self .cursor_factory = kwargs .pop ("cursor_factory" , None )
110+ self .server_cursor_factory = lambda _ : MockAsyncCursor ()
105111
106112 @staticmethod
107113 async def connect (* args , ** kwargs ):
108114 return MockAsyncConnection (** kwargs )
109115
110- def cursor (self ):
111- if self .cursor_factory :
112- cur = self .cursor_factory (self )
113- return cur
116+ def cursor (self , name : Optional [str ] = None ):
117+ if not name and self .cursor_factory :
118+ return self .cursor_factory (self )
119+
120+ if name and self .server_cursor_factory :
121+ return self .server_cursor_factory (self )
114122 return MockAsyncCursor ()
115123
116124 def execute (self , query , params = None , * , prepare = None , binary = False ):
@@ -197,6 +205,36 @@ def test_instrumentor(self):
197205 spans_list = self .memory_exporter .get_finished_spans ()
198206 self .assertEqual (len (spans_list ), 1 )
199207
208+ def test_instrumentor_with_named_cursor (self ):
209+ PsycopgInstrumentor ().instrument ()
210+
211+ cnx = psycopg .connect (database = "test" )
212+
213+ cursor = cnx .cursor (name = "named_cursor" )
214+
215+ query = "SELECT * FROM test"
216+ cursor .execute (query )
217+
218+ spans_list = self .memory_exporter .get_finished_spans ()
219+ self .assertEqual (len (spans_list ), 1 )
220+ span = spans_list [0 ]
221+
222+ # Check version and name in span's instrumentation info
223+ self .assertEqualSpanInstrumentationScope (
224+ span , opentelemetry .instrumentation .psycopg
225+ )
226+
227+ # check that no spans are generated after uninstrument
228+ PsycopgInstrumentor ().uninstrument ()
229+
230+ cnx = psycopg .connect (database = "test" )
231+ cursor = cnx .cursor (name = "named_cursor" )
232+ query = "SELECT * FROM test"
233+ cursor .execute (query )
234+
235+ spans_list = self .memory_exporter .get_finished_spans ()
236+ self .assertEqual (len (spans_list ), 1 )
237+
200238 # pylint: disable=unused-argument
201239 def test_instrumentor_with_connection_class (self ):
202240 PsycopgInstrumentor ().instrument ()
@@ -228,6 +266,36 @@ def test_instrumentor_with_connection_class(self):
228266 spans_list = self .memory_exporter .get_finished_spans ()
229267 self .assertEqual (len (spans_list ), 1 )
230268
269+ def test_instrumentor_with_connection_class_and_named_cursor (self ):
270+ PsycopgInstrumentor ().instrument ()
271+
272+ cnx = psycopg .Connection .connect (database = "test" )
273+
274+ cursor = cnx .cursor (name = "named_cursor" )
275+
276+ query = "SELECT * FROM test"
277+ cursor .execute (query )
278+
279+ spans_list = self .memory_exporter .get_finished_spans ()
280+ self .assertEqual (len (spans_list ), 1 )
281+ span = spans_list [0 ]
282+
283+ # Check version and name in span's instrumentation info
284+ self .assertEqualSpanInstrumentationScope (
285+ span , opentelemetry .instrumentation .psycopg
286+ )
287+
288+ # check that no spans are generated after uninstrument
289+ PsycopgInstrumentor ().uninstrument ()
290+
291+ cnx = psycopg .Connection .connect (database = "test" )
292+ cursor = cnx .cursor (name = "named_cursor" )
293+ query = "SELECT * FROM test"
294+ cursor .execute (query )
295+
296+ spans_list = self .memory_exporter .get_finished_spans ()
297+ self .assertEqual (len (spans_list ), 1 )
298+
231299 def test_span_name (self ):
232300 PsycopgInstrumentor ().instrument ()
233301
@@ -314,6 +382,23 @@ def test_instrument_connection(self):
314382 spans_list = self .memory_exporter .get_finished_spans ()
315383 self .assertEqual (len (spans_list ), 1 )
316384
385+ # pylint: disable=unused-argument
386+ def test_instrument_connection_with_named_cursor (self ):
387+ cnx = psycopg .connect (database = "test" )
388+ query = "SELECT * FROM test"
389+ cursor = cnx .cursor (name = "named_cursor" )
390+ cursor .execute (query )
391+
392+ spans_list = self .memory_exporter .get_finished_spans ()
393+ self .assertEqual (len (spans_list ), 0 )
394+
395+ cnx = PsycopgInstrumentor ().instrument_connection (cnx )
396+ cursor = cnx .cursor (name = "named_cursor" )
397+ cursor .execute (query )
398+
399+ spans_list = self .memory_exporter .get_finished_spans ()
400+ self .assertEqual (len (spans_list ), 1 )
401+
317402 # pylint: disable=unused-argument
318403 def test_instrument_connection_with_instrument (self ):
319404 cnx = psycopg .connect (database = "test" )
@@ -368,6 +453,25 @@ def test_uninstrument_connection_with_instrument_connection(self):
368453 spans_list = self .memory_exporter .get_finished_spans ()
369454 self .assertEqual (len (spans_list ), 1 )
370455
456+ def test_uninstrument_connection_with_instrument_connection_and_named_cursor (
457+ self ,
458+ ):
459+ cnx = psycopg .connect (database = "test" )
460+ PsycopgInstrumentor ().instrument_connection (cnx )
461+ query = "SELECT * FROM test"
462+ cursor = cnx .cursor (name = "named_cursor" )
463+ cursor .execute (query )
464+
465+ spans_list = self .memory_exporter .get_finished_spans ()
466+ self .assertEqual (len (spans_list ), 1 )
467+
468+ cnx = PsycopgInstrumentor ().uninstrument_connection (cnx )
469+ cursor = cnx .cursor (name = "named_cursor" )
470+ cursor .execute (query )
471+
472+ spans_list = self .memory_exporter .get_finished_spans ()
473+ self .assertEqual (len (spans_list ), 1 )
474+
371475 @mock .patch ("opentelemetry.instrumentation.dbapi.wrap_connect" )
372476 def test_sqlcommenter_enabled (self , event_mocked ):
373477 cnx = psycopg .connect (database = "test" )
@@ -419,6 +523,33 @@ async def test_async_connection():
419523 spans_list = self .memory_exporter .get_finished_spans ()
420524 self .assertEqual (len (spans_list ), 1 )
421525
526+ async def test_wrap_async_connection_class_with_named_cursor (self ):
527+ PsycopgInstrumentor ().instrument ()
528+
529+ async def test_async_connection ():
530+ acnx = await psycopg .AsyncConnection .connect ("test" )
531+ async with acnx as cnx :
532+ async with cnx .cursor (name = "named_cursor" ) as cursor :
533+ await cursor .execute ("SELECT * FROM test" )
534+
535+ await test_async_connection ()
536+ spans_list = self .memory_exporter .get_finished_spans ()
537+ self .assertEqual (len (spans_list ), 1 )
538+ span = spans_list [0 ]
539+
540+ # Check version and name in span's instrumentation info
541+ self .assertEqualSpanInstrumentationScope (
542+ span , opentelemetry .instrumentation .psycopg
543+ )
544+
545+ # check that no spans are generated after uninstrument
546+ PsycopgInstrumentor ().uninstrument ()
547+
548+ await test_async_connection ()
549+
550+ spans_list = self .memory_exporter .get_finished_spans ()
551+ self .assertEqual (len (spans_list ), 1 )
552+
422553 # pylint: disable=unused-argument
423554 async def test_instrumentor_with_async_connection_class (self ):
424555 PsycopgInstrumentor ().instrument ()
0 commit comments