Skip to content

Commit 5e11096

Browse files
committed
patch server_cursor_factory
1 parent 29ef6a9 commit 5e11096

File tree

2 files changed

+154
-6
lines changed

2 files changed

+154
-6
lines changed

instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@
137137

138138
_logger = logging.getLogger(__name__)
139139
_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"
140+
_OTEL_SERVER_CURSOR_FACTORY_KEY = "_otel_orig_server_cursor_factory"
140141

141142

142143
class PsycopgInstrumentor(BaseInstrumentor):
@@ -231,9 +232,15 @@ def instrument_connection(connection, tracer_provider=None):
231232
setattr(
232233
connection, _OTEL_CURSOR_FACTORY_KEY, connection.cursor_factory
233234
)
235+
setattr(
236+
connection, _OTEL_SERVER_CURSOR_FACTORY_KEY, connection.server_cursor_factory
237+
)
234238
connection.cursor_factory = _new_cursor_factory(
235239
tracer_provider=tracer_provider
236240
)
241+
connection.server_cursor_factory = _new_cursor_factory(
242+
tracer_provider=tracer_provider
243+
)
237244
connection._is_instrumented_by_opentelemetry = True
238245
else:
239246
_logger.warning(
@@ -247,6 +254,9 @@ def uninstrument_connection(connection):
247254
connection.cursor_factory = getattr(
248255
connection, _OTEL_CURSOR_FACTORY_KEY, None
249256
)
257+
connection.server_cursor_factory = getattr(
258+
connection, _OTEL_SERVER_CURSOR_FACTORY_KEY, None
259+
)
250260

251261
return connection
252262

@@ -267,6 +277,11 @@ def wrapped_connection(
267277
kwargs["cursor_factory"] = _new_cursor_factory(**new_factory_kwargs)
268278
connection = connect_method(*args, **kwargs)
269279
self.get_connection_attributes(connection)
280+
281+
connection.server_cursor_factory = _new_cursor_factory(
282+
db_api=self, base_factory=getattr(connection, "server_cursor_factory", None)
283+
)
284+
270285
return connection
271286

272287

@@ -287,6 +302,10 @@ async def wrapped_connection(
287302
)
288303
connection = await connect_method(*args, **kwargs)
289304
self.get_connection_attributes(connection)
305+
306+
connection.server_cursor_factory = _new_cursor_async_factory(
307+
db_api=self, base_factory=getattr(connection, "server_cursor_factory", None)
308+
)
290309
return connection
291310

292311

instrumentation/opentelemetry-instrumentation-psycopg/tests/test_psycopg_integration.py

Lines changed: 135 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import types
16+
from typing import Optional
1617
from unittest import IsolatedAsyncioTestCase, mock
1718

1819
import psycopg
@@ -83,10 +84,14 @@ class MockConnection:
8384

8485
def __init__(self, *args, **kwargs):
8586
self.cursor_factory = kwargs.pop("cursor_factory", None)
87+
self.server_cursor_factory = None
8688

87-
def cursor(self):
88-
if self.cursor_factory:
89+
def cursor(self, name: Optional[str] = None):
90+
if not name and self.cursor_factory:
8991
return self.cursor_factory(self)
92+
93+
if name and self.server_cursor_factory:
94+
return self.server_cursor_factory(self)
9095
return MockCursor()
9196

9297
def get_dsn_parameters(self): # pylint: disable=no-self-use
@@ -102,15 +107,18 @@ class MockAsyncConnection:
102107

103108
def __init__(self, *args, **kwargs):
104109
self.cursor_factory = kwargs.pop("cursor_factory", None)
110+
self.server_cursor_factory = None
105111

106112
@staticmethod
107113
async def connect(*args, **kwargs):
108114
return MockAsyncConnection(**kwargs)
109115

110-
def cursor(self):
111-
if self.cursor_factory:
112-
cur = self.cursor_factory(self)
113-
return cur
116+
def cursor(self, name: Optional[str] = None):
117+
if not name and self.cursor_factory:
118+
return self.cursor_factory(self)
119+
120+
if name and self.server_cursor_factory:
121+
return self.server_cursor_factory(self)
114122
return MockAsyncCursor()
115123

116124
def execute(self, query, params=None, *, prepare=None, binary=False):
@@ -197,6 +205,36 @@ def test_instrumentor(self):
197205
spans_list = self.memory_exporter.get_finished_spans()
198206
self.assertEqual(len(spans_list), 1)
199207

208+
def test_instrumentor_with_named_cursor(self):
209+
PsycopgInstrumentor().instrument()
210+
211+
cnx = psycopg.connect(database="test")
212+
213+
cursor = cnx.cursor(name="named_cursor")
214+
215+
query = "SELECT * FROM test"
216+
cursor.execute(query)
217+
218+
spans_list = self.memory_exporter.get_finished_spans()
219+
self.assertEqual(len(spans_list), 1)
220+
span = spans_list[0]
221+
222+
# Check version and name in span's instrumentation info
223+
self.assertEqualSpanInstrumentationScope(
224+
span, opentelemetry.instrumentation.psycopg
225+
)
226+
227+
# check that no spans are generated after uninstrument
228+
PsycopgInstrumentor().uninstrument()
229+
230+
cnx = psycopg.connect(database="test")
231+
cursor = cnx.cursor(name="named_cursor")
232+
query = "SELECT * FROM test"
233+
cursor.execute(query)
234+
235+
spans_list = self.memory_exporter.get_finished_spans()
236+
self.assertEqual(len(spans_list), 1)
237+
200238
# pylint: disable=unused-argument
201239
def test_instrumentor_with_connection_class(self):
202240
PsycopgInstrumentor().instrument()
@@ -228,6 +266,36 @@ def test_instrumentor_with_connection_class(self):
228266
spans_list = self.memory_exporter.get_finished_spans()
229267
self.assertEqual(len(spans_list), 1)
230268

269+
def test_instrumentor_with_connection_class_and_named_cursor(self):
270+
PsycopgInstrumentor().instrument()
271+
272+
cnx = psycopg.Connection.connect(database="test")
273+
274+
cursor = cnx.cursor(name="named_cursor")
275+
276+
query = "SELECT * FROM test"
277+
cursor.execute(query)
278+
279+
spans_list = self.memory_exporter.get_finished_spans()
280+
self.assertEqual(len(spans_list), 1)
281+
span = spans_list[0]
282+
283+
# Check version and name in span's instrumentation info
284+
self.assertEqualSpanInstrumentationScope(
285+
span, opentelemetry.instrumentation.psycopg
286+
)
287+
288+
# check that no spans are generated after uninstrument
289+
PsycopgInstrumentor().uninstrument()
290+
291+
cnx = psycopg.Connection.connect(database="test")
292+
cursor = cnx.cursor(name="named_cursor")
293+
query = "SELECT * FROM test"
294+
cursor.execute(query)
295+
296+
spans_list = self.memory_exporter.get_finished_spans()
297+
self.assertEqual(len(spans_list), 1)
298+
231299
def test_span_name(self):
232300
PsycopgInstrumentor().instrument()
233301

@@ -314,6 +382,23 @@ def test_instrument_connection(self):
314382
spans_list = self.memory_exporter.get_finished_spans()
315383
self.assertEqual(len(spans_list), 1)
316384

385+
# pylint: disable=unused-argument
386+
def test_instrument_connection_with_named_cursor(self):
387+
cnx = psycopg.connect(database="test")
388+
query = "SELECT * FROM test"
389+
cursor = cnx.cursor(name="named_cursor")
390+
cursor.execute(query)
391+
392+
spans_list = self.memory_exporter.get_finished_spans()
393+
self.assertEqual(len(spans_list), 0)
394+
395+
cnx = PsycopgInstrumentor().instrument_connection(cnx)
396+
cursor = cnx.cursor(name="named_cursor")
397+
cursor.execute(query)
398+
399+
spans_list = self.memory_exporter.get_finished_spans()
400+
self.assertEqual(len(spans_list), 1)
401+
317402
# pylint: disable=unused-argument
318403
def test_instrument_connection_with_instrument(self):
319404
cnx = psycopg.connect(database="test")
@@ -368,6 +453,23 @@ def test_uninstrument_connection_with_instrument_connection(self):
368453
spans_list = self.memory_exporter.get_finished_spans()
369454
self.assertEqual(len(spans_list), 1)
370455

456+
def test_uninstrument_connection_with_instrument_connection_and_named_cursor(self):
457+
cnx = psycopg.connect(database="test")
458+
PsycopgInstrumentor().instrument_connection(cnx)
459+
query = "SELECT * FROM test"
460+
cursor = cnx.cursor(name="named_cursor")
461+
cursor.execute(query)
462+
463+
spans_list = self.memory_exporter.get_finished_spans()
464+
self.assertEqual(len(spans_list), 1)
465+
466+
cnx = PsycopgInstrumentor().uninstrument_connection(cnx)
467+
cursor = cnx.cursor(name="named_cursor")
468+
cursor.execute(query)
469+
470+
spans_list = self.memory_exporter.get_finished_spans()
471+
self.assertEqual(len(spans_list), 1)
472+
371473
@mock.patch("opentelemetry.instrumentation.dbapi.wrap_connect")
372474
def test_sqlcommenter_enabled(self, event_mocked):
373475
cnx = psycopg.connect(database="test")
@@ -419,6 +521,33 @@ async def test_async_connection():
419521
spans_list = self.memory_exporter.get_finished_spans()
420522
self.assertEqual(len(spans_list), 1)
421523

524+
async def test_wrap_async_connection_class_with_named_cursor(self):
525+
PsycopgInstrumentor().instrument()
526+
527+
async def test_async_connection():
528+
acnx = await psycopg.AsyncConnection.connect("test")
529+
async with acnx as cnx:
530+
async with cnx.cursor(name="named_cursor") as cursor:
531+
await cursor.execute("SELECT * FROM test")
532+
533+
await test_async_connection()
534+
spans_list = self.memory_exporter.get_finished_spans()
535+
self.assertEqual(len(spans_list), 1)
536+
span = spans_list[0]
537+
538+
# Check version and name in span's instrumentation info
539+
self.assertEqualSpanInstrumentationScope(
540+
span, opentelemetry.instrumentation.psycopg
541+
)
542+
543+
# check that no spans are generated after uninstrument
544+
PsycopgInstrumentor().uninstrument()
545+
546+
await test_async_connection()
547+
548+
spans_list = self.memory_exporter.get_finished_spans()
549+
self.assertEqual(len(spans_list), 1)
550+
422551
# pylint: disable=unused-argument
423552
async def test_instrumentor_with_async_connection_class(self):
424553
PsycopgInstrumentor().instrument()

0 commit comments

Comments
 (0)