Skip to content

Commit 2c90c1c

Browse files
committed
patch server_cursor_factory
1 parent 29ef6a9 commit 2c90c1c

File tree

2 files changed

+160
-6
lines changed

2 files changed

+160
-6
lines changed

instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137

138138
_logger = logging.getLogger(__name__)
139139
_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"
140+
_OTEL_SERVER_CURSOR_FACTORY_KEY = "_otel_orig_server_cursor_factory"
140141

141142

142143
class PsycopgInstrumentor(BaseInstrumentor):
@@ -231,9 +232,17 @@ def instrument_connection(connection, tracer_provider=None):
231232
setattr(
232233
connection, _OTEL_CURSOR_FACTORY_KEY, connection.cursor_factory
233234
)
235+
setattr(
236+
connection,
237+
_OTEL_SERVER_CURSOR_FACTORY_KEY,
238+
connection.server_cursor_factory,
239+
)
234240
connection.cursor_factory = _new_cursor_factory(
235241
tracer_provider=tracer_provider
236242
)
243+
connection.server_cursor_factory = _new_cursor_factory(
244+
tracer_provider=tracer_provider
245+
)
237246
connection._is_instrumented_by_opentelemetry = True
238247
else:
239248
_logger.warning(
@@ -247,6 +256,9 @@ def uninstrument_connection(connection):
247256
connection.cursor_factory = getattr(
248257
connection, _OTEL_CURSOR_FACTORY_KEY, None
249258
)
259+
connection.server_cursor_factory = getattr(
260+
connection, _OTEL_SERVER_CURSOR_FACTORY_KEY, None
261+
)
250262

251263
return connection
252264

@@ -267,6 +279,12 @@ def wrapped_connection(
267279
kwargs["cursor_factory"] = _new_cursor_factory(**new_factory_kwargs)
268280
connection = connect_method(*args, **kwargs)
269281
self.get_connection_attributes(connection)
282+
283+
connection.server_cursor_factory = _new_cursor_factory(
284+
db_api=self,
285+
base_factory=getattr(connection, "server_cursor_factory", None),
286+
)
287+
270288
return connection
271289

272290

@@ -287,6 +305,11 @@ async def wrapped_connection(
287305
)
288306
connection = await connect_method(*args, **kwargs)
289307
self.get_connection_attributes(connection)
308+
309+
connection.server_cursor_factory = _new_cursor_async_factory(
310+
db_api=self,
311+
base_factory=getattr(connection, "server_cursor_factory", None),
312+
)
290313
return connection
291314

292315

instrumentation/opentelemetry-instrumentation-psycopg/tests/test_psycopg_integration.py

Lines changed: 137 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import types
16+
from typing import Optional
1617
from unittest import IsolatedAsyncioTestCase, mock
1718

1819
import psycopg
@@ -83,10 +84,14 @@ class MockConnection:
8384

8485
def __init__(self, *args, **kwargs):
8586
self.cursor_factory = kwargs.pop("cursor_factory", None)
87+
self.server_cursor_factory = lambda _: MockCursor()
8688

87-
def cursor(self):
88-
if self.cursor_factory:
89+
def cursor(self, name: Optional[str] = None):
90+
if not name and self.cursor_factory:
8991
return self.cursor_factory(self)
92+
93+
if name and self.server_cursor_factory:
94+
return self.server_cursor_factory(self)
9095
return MockCursor()
9196

9297
def get_dsn_parameters(self): # pylint: disable=no-self-use
@@ -102,15 +107,18 @@ class MockAsyncConnection:
102107

103108
def __init__(self, *args, **kwargs):
104109
self.cursor_factory = kwargs.pop("cursor_factory", None)
110+
self.server_cursor_factory = lambda _: MockAsyncCursor()
105111

106112
@staticmethod
107113
async def connect(*args, **kwargs):
108114
return MockAsyncConnection(**kwargs)
109115

110-
def cursor(self):
111-
if self.cursor_factory:
112-
cur = self.cursor_factory(self)
113-
return cur
116+
def cursor(self, name: Optional[str] = None):
117+
if not name and self.cursor_factory:
118+
return self.cursor_factory(self)
119+
120+
if name and self.server_cursor_factory:
121+
return self.server_cursor_factory(self)
114122
return MockAsyncCursor()
115123

116124
def execute(self, query, params=None, *, prepare=None, binary=False):
@@ -197,6 +205,36 @@ def test_instrumentor(self):
197205
spans_list = self.memory_exporter.get_finished_spans()
198206
self.assertEqual(len(spans_list), 1)
199207

208+
def test_instrumentor_with_named_cursor(self):
209+
PsycopgInstrumentor().instrument()
210+
211+
cnx = psycopg.connect(database="test")
212+
213+
cursor = cnx.cursor(name="named_cursor")
214+
215+
query = "SELECT * FROM test"
216+
cursor.execute(query)
217+
218+
spans_list = self.memory_exporter.get_finished_spans()
219+
self.assertEqual(len(spans_list), 1)
220+
span = spans_list[0]
221+
222+
# Check version and name in span's instrumentation info
223+
self.assertEqualSpanInstrumentationScope(
224+
span, opentelemetry.instrumentation.psycopg
225+
)
226+
227+
# check that no spans are generated after uninstrument
228+
PsycopgInstrumentor().uninstrument()
229+
230+
cnx = psycopg.connect(database="test")
231+
cursor = cnx.cursor(name="named_cursor")
232+
query = "SELECT * FROM test"
233+
cursor.execute(query)
234+
235+
spans_list = self.memory_exporter.get_finished_spans()
236+
self.assertEqual(len(spans_list), 1)
237+
200238
# pylint: disable=unused-argument
201239
def test_instrumentor_with_connection_class(self):
202240
PsycopgInstrumentor().instrument()
@@ -228,6 +266,36 @@ def test_instrumentor_with_connection_class(self):
228266
spans_list = self.memory_exporter.get_finished_spans()
229267
self.assertEqual(len(spans_list), 1)
230268

269+
def test_instrumentor_with_connection_class_and_named_cursor(self):
270+
PsycopgInstrumentor().instrument()
271+
272+
cnx = psycopg.Connection.connect(database="test")
273+
274+
cursor = cnx.cursor(name="named_cursor")
275+
276+
query = "SELECT * FROM test"
277+
cursor.execute(query)
278+
279+
spans_list = self.memory_exporter.get_finished_spans()
280+
self.assertEqual(len(spans_list), 1)
281+
span = spans_list[0]
282+
283+
# Check version and name in span's instrumentation info
284+
self.assertEqualSpanInstrumentationScope(
285+
span, opentelemetry.instrumentation.psycopg
286+
)
287+
288+
# check that no spans are generated after uninstrument
289+
PsycopgInstrumentor().uninstrument()
290+
291+
cnx = psycopg.Connection.connect(database="test")
292+
cursor = cnx.cursor(name="named_cursor")
293+
query = "SELECT * FROM test"
294+
cursor.execute(query)
295+
296+
spans_list = self.memory_exporter.get_finished_spans()
297+
self.assertEqual(len(spans_list), 1)
298+
231299
def test_span_name(self):
232300
PsycopgInstrumentor().instrument()
233301

@@ -314,6 +382,23 @@ def test_instrument_connection(self):
314382
spans_list = self.memory_exporter.get_finished_spans()
315383
self.assertEqual(len(spans_list), 1)
316384

385+
# pylint: disable=unused-argument
386+
def test_instrument_connection_with_named_cursor(self):
387+
cnx = psycopg.connect(database="test")
388+
query = "SELECT * FROM test"
389+
cursor = cnx.cursor(name="named_cursor")
390+
cursor.execute(query)
391+
392+
spans_list = self.memory_exporter.get_finished_spans()
393+
self.assertEqual(len(spans_list), 0)
394+
395+
cnx = PsycopgInstrumentor().instrument_connection(cnx)
396+
cursor = cnx.cursor(name="named_cursor")
397+
cursor.execute(query)
398+
399+
spans_list = self.memory_exporter.get_finished_spans()
400+
self.assertEqual(len(spans_list), 1)
401+
317402
# pylint: disable=unused-argument
318403
def test_instrument_connection_with_instrument(self):
319404
cnx = psycopg.connect(database="test")
@@ -368,6 +453,25 @@ def test_uninstrument_connection_with_instrument_connection(self):
368453
spans_list = self.memory_exporter.get_finished_spans()
369454
self.assertEqual(len(spans_list), 1)
370455

456+
def test_uninstrument_connection_with_instrument_connection_and_named_cursor(
457+
self,
458+
):
459+
cnx = psycopg.connect(database="test")
460+
PsycopgInstrumentor().instrument_connection(cnx)
461+
query = "SELECT * FROM test"
462+
cursor = cnx.cursor(name="named_cursor")
463+
cursor.execute(query)
464+
465+
spans_list = self.memory_exporter.get_finished_spans()
466+
self.assertEqual(len(spans_list), 1)
467+
468+
cnx = PsycopgInstrumentor().uninstrument_connection(cnx)
469+
cursor = cnx.cursor(name="named_cursor")
470+
cursor.execute(query)
471+
472+
spans_list = self.memory_exporter.get_finished_spans()
473+
self.assertEqual(len(spans_list), 1)
474+
371475
@mock.patch("opentelemetry.instrumentation.dbapi.wrap_connect")
372476
def test_sqlcommenter_enabled(self, event_mocked):
373477
cnx = psycopg.connect(database="test")
@@ -419,6 +523,33 @@ async def test_async_connection():
419523
spans_list = self.memory_exporter.get_finished_spans()
420524
self.assertEqual(len(spans_list), 1)
421525

526+
async def test_wrap_async_connection_class_with_named_cursor(self):
527+
PsycopgInstrumentor().instrument()
528+
529+
async def test_async_connection():
530+
acnx = await psycopg.AsyncConnection.connect("test")
531+
async with acnx as cnx:
532+
async with cnx.cursor(name="named_cursor") as cursor:
533+
await cursor.execute("SELECT * FROM test")
534+
535+
await test_async_connection()
536+
spans_list = self.memory_exporter.get_finished_spans()
537+
self.assertEqual(len(spans_list), 1)
538+
span = spans_list[0]
539+
540+
# Check version and name in span's instrumentation info
541+
self.assertEqualSpanInstrumentationScope(
542+
span, opentelemetry.instrumentation.psycopg
543+
)
544+
545+
# check that no spans are generated after uninstrument
546+
PsycopgInstrumentor().uninstrument()
547+
548+
await test_async_connection()
549+
550+
spans_list = self.memory_exporter.get_finished_spans()
551+
self.assertEqual(len(spans_list), 1)
552+
422553
# pylint: disable=unused-argument
423554
async def test_instrumentor_with_async_connection_class(self):
424555
PsycopgInstrumentor().instrument()

0 commit comments

Comments
 (0)