Skip to content

Commit 3334e74

Browse files
committed
patch server_cursor_factory
1 parent c59b514 commit 3334e74

File tree

2 files changed

+160
-6
lines changed

2 files changed

+160
-6
lines changed

instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@
157157

158158
_logger = logging.getLogger(__name__)
159159
_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"
160+
_OTEL_SERVER_CURSOR_FACTORY_KEY = "_otel_orig_server_cursor_factory"
160161

161162

162163
class PsycopgInstrumentor(BaseInstrumentor):
@@ -257,9 +258,17 @@ def instrument_connection(connection, tracer_provider=None):
257258
setattr(
258259
connection, _OTEL_CURSOR_FACTORY_KEY, connection.cursor_factory
259260
)
261+
setattr(
262+
connection,
263+
_OTEL_SERVER_CURSOR_FACTORY_KEY,
264+
connection.server_cursor_factory,
265+
)
260266
connection.cursor_factory = _new_cursor_factory(
261267
tracer_provider=tracer_provider
262268
)
269+
connection.server_cursor_factory = _new_cursor_factory(
270+
tracer_provider=tracer_provider
271+
)
263272
connection._is_instrumented_by_opentelemetry = True
264273
else:
265274
_logger.warning(
@@ -273,6 +282,9 @@ def uninstrument_connection(connection):
273282
connection.cursor_factory = getattr(
274283
connection, _OTEL_CURSOR_FACTORY_KEY, None
275284
)
285+
connection.server_cursor_factory = getattr(
286+
connection, _OTEL_SERVER_CURSOR_FACTORY_KEY, None
287+
)
276288

277289
return connection
278290

@@ -293,6 +305,12 @@ def wrapped_connection(
293305
kwargs["cursor_factory"] = _new_cursor_factory(**new_factory_kwargs)
294306
connection = connect_method(*args, **kwargs)
295307
self.get_connection_attributes(connection)
308+
309+
connection.server_cursor_factory = _new_cursor_factory(
310+
db_api=self,
311+
base_factory=getattr(connection, "server_cursor_factory", None),
312+
)
313+
296314
return connection
297315

298316

@@ -313,6 +331,11 @@ async def wrapped_connection(
313331
)
314332
connection = await connect_method(*args, **kwargs)
315333
self.get_connection_attributes(connection)
334+
335+
connection.server_cursor_factory = _new_cursor_async_factory(
336+
db_api=self,
337+
base_factory=getattr(connection, "server_cursor_factory", None),
338+
)
316339
return connection
317340

318341

instrumentation/opentelemetry-instrumentation-psycopg/tests/test_psycopg_integration.py

Lines changed: 137 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import types
16+
from typing import Optional
1617
from unittest import IsolatedAsyncioTestCase, mock
1718

1819
import psycopg
@@ -83,10 +84,14 @@ class MockConnection:
8384

8485
def __init__(self, *args, **kwargs):
8586
self.cursor_factory = kwargs.pop("cursor_factory", None)
87+
self.server_cursor_factory = lambda _: MockCursor()
8688

87-
def cursor(self):
88-
if self.cursor_factory:
89+
def cursor(self, name: Optional[str] = None):
90+
if not name and self.cursor_factory:
8991
return self.cursor_factory(self)
92+
93+
if name and self.server_cursor_factory:
94+
return self.server_cursor_factory(self)
9095
return MockCursor()
9196

9297
def get_dsn_parameters(self): # pylint: disable=no-self-use
@@ -102,15 +107,18 @@ class MockAsyncConnection:
102107

103108
def __init__(self, *args, **kwargs):
104109
self.cursor_factory = kwargs.pop("cursor_factory", None)
110+
self.server_cursor_factory = lambda _: MockAsyncCursor()
105111

106112
@staticmethod
107113
async def connect(*args, **kwargs):
108114
return MockAsyncConnection(**kwargs)
109115

110-
def cursor(self):
111-
if self.cursor_factory:
112-
cur = self.cursor_factory(self)
113-
return cur
116+
def cursor(self, name: Optional[str] = None):
117+
if not name and self.cursor_factory:
118+
return self.cursor_factory(self)
119+
120+
if name and self.server_cursor_factory:
121+
return self.server_cursor_factory(self)
114122
return MockAsyncCursor()
115123

116124
def execute(self, query, params=None, *, prepare=None, binary=False):
@@ -197,6 +205,36 @@ def test_instrumentor(self):
197205
spans_list = self.memory_exporter.get_finished_spans()
198206
self.assertEqual(len(spans_list), 1)
199207

208+
def test_instrumentor_with_named_cursor(self):
209+
PsycopgInstrumentor().instrument()
210+
211+
cnx = psycopg.connect(database="test")
212+
213+
cursor = cnx.cursor(name="named_cursor")
214+
215+
query = "SELECT * FROM test"
216+
cursor.execute(query)
217+
218+
spans_list = self.memory_exporter.get_finished_spans()
219+
self.assertEqual(len(spans_list), 1)
220+
span = spans_list[0]
221+
222+
# Check version and name in span's instrumentation info
223+
self.assertEqualSpanInstrumentationScope(
224+
span, opentelemetry.instrumentation.psycopg
225+
)
226+
227+
# check that no spans are generated after uninstrument
228+
PsycopgInstrumentor().uninstrument()
229+
230+
cnx = psycopg.connect(database="test")
231+
cursor = cnx.cursor(name="named_cursor")
232+
query = "SELECT * FROM test"
233+
cursor.execute(query)
234+
235+
spans_list = self.memory_exporter.get_finished_spans()
236+
self.assertEqual(len(spans_list), 1)
237+
200238
# pylint: disable=unused-argument
201239
def test_instrumentor_with_connection_class(self):
202240
PsycopgInstrumentor().instrument()
@@ -228,6 +266,36 @@ def test_instrumentor_with_connection_class(self):
228266
spans_list = self.memory_exporter.get_finished_spans()
229267
self.assertEqual(len(spans_list), 1)
230268

269+
def test_instrumentor_with_connection_class_and_named_cursor(self):
270+
PsycopgInstrumentor().instrument()
271+
272+
cnx = psycopg.Connection.connect(database="test")
273+
274+
cursor = cnx.cursor(name="named_cursor")
275+
276+
query = "SELECT * FROM test"
277+
cursor.execute(query)
278+
279+
spans_list = self.memory_exporter.get_finished_spans()
280+
self.assertEqual(len(spans_list), 1)
281+
span = spans_list[0]
282+
283+
# Check version and name in span's instrumentation info
284+
self.assertEqualSpanInstrumentationScope(
285+
span, opentelemetry.instrumentation.psycopg
286+
)
287+
288+
# check that no spans are generated after uninstrument
289+
PsycopgInstrumentor().uninstrument()
290+
291+
cnx = psycopg.Connection.connect(database="test")
292+
cursor = cnx.cursor(name="named_cursor")
293+
query = "SELECT * FROM test"
294+
cursor.execute(query)
295+
296+
spans_list = self.memory_exporter.get_finished_spans()
297+
self.assertEqual(len(spans_list), 1)
298+
231299
def test_span_name(self):
232300
PsycopgInstrumentor().instrument()
233301

@@ -314,6 +382,23 @@ def test_instrument_connection(self):
314382
spans_list = self.memory_exporter.get_finished_spans()
315383
self.assertEqual(len(spans_list), 1)
316384

385+
# pylint: disable=unused-argument
386+
def test_instrument_connection_with_named_cursor(self):
387+
cnx = psycopg.connect(database="test")
388+
query = "SELECT * FROM test"
389+
cursor = cnx.cursor(name="named_cursor")
390+
cursor.execute(query)
391+
392+
spans_list = self.memory_exporter.get_finished_spans()
393+
self.assertEqual(len(spans_list), 0)
394+
395+
cnx = PsycopgInstrumentor().instrument_connection(cnx)
396+
cursor = cnx.cursor(name="named_cursor")
397+
cursor.execute(query)
398+
399+
spans_list = self.memory_exporter.get_finished_spans()
400+
self.assertEqual(len(spans_list), 1)
401+
317402
# pylint: disable=unused-argument
318403
def test_instrument_connection_with_instrument(self):
319404
cnx = psycopg.connect(database="test")
@@ -368,6 +453,25 @@ def test_uninstrument_connection_with_instrument_connection(self):
368453
spans_list = self.memory_exporter.get_finished_spans()
369454
self.assertEqual(len(spans_list), 1)
370455

456+
def test_uninstrument_connection_with_instrument_connection_and_named_cursor(
457+
self,
458+
):
459+
cnx = psycopg.connect(database="test")
460+
PsycopgInstrumentor().instrument_connection(cnx)
461+
query = "SELECT * FROM test"
462+
cursor = cnx.cursor(name="named_cursor")
463+
cursor.execute(query)
464+
465+
spans_list = self.memory_exporter.get_finished_spans()
466+
self.assertEqual(len(spans_list), 1)
467+
468+
cnx = PsycopgInstrumentor().uninstrument_connection(cnx)
469+
cursor = cnx.cursor(name="named_cursor")
470+
cursor.execute(query)
471+
472+
spans_list = self.memory_exporter.get_finished_spans()
473+
self.assertEqual(len(spans_list), 1)
474+
371475
@mock.patch("opentelemetry.instrumentation.dbapi.wrap_connect")
372476
def test_sqlcommenter_enabled(self, event_mocked):
373477
cnx = psycopg.connect(database="test")
@@ -419,6 +523,33 @@ async def test_async_connection():
419523
spans_list = self.memory_exporter.get_finished_spans()
420524
self.assertEqual(len(spans_list), 1)
421525

526+
async def test_wrap_async_connection_class_with_named_cursor(self):
527+
PsycopgInstrumentor().instrument()
528+
529+
async def test_async_connection():
530+
acnx = await psycopg.AsyncConnection.connect("test")
531+
async with acnx as cnx:
532+
async with cnx.cursor(name="named_cursor") as cursor:
533+
await cursor.execute("SELECT * FROM test")
534+
535+
await test_async_connection()
536+
spans_list = self.memory_exporter.get_finished_spans()
537+
self.assertEqual(len(spans_list), 1)
538+
span = spans_list[0]
539+
540+
# Check version and name in span's instrumentation info
541+
self.assertEqualSpanInstrumentationScope(
542+
span, opentelemetry.instrumentation.psycopg
543+
)
544+
545+
# check that no spans are generated after uninstrument
546+
PsycopgInstrumentor().uninstrument()
547+
548+
await test_async_connection()
549+
550+
spans_list = self.memory_exporter.get_finished_spans()
551+
self.assertEqual(len(spans_list), 1)
552+
422553
# pylint: disable=unused-argument
423554
async def test_instrumentor_with_async_connection_class(self):
424555
PsycopgInstrumentor().instrument()

0 commit comments

Comments
 (0)