137137---
138138"""
139139
140+ from __future__ import annotations
141+
140142import logging
141- import typing
142- from typing import Collection
143+ from typing import Any , Callable , Collection , TypeVar
143144
144145import psycopg # pylint: disable=import-self
145- from psycopg import (
146- AsyncCursor as pg_async_cursor , # pylint: disable=import-self,no-name-in-module
147- )
148- from psycopg import (
149- Cursor as pg_cursor , # pylint: disable=no-name-in-module,import-self
150- )
151146from psycopg .sql import Composed # pylint: disable=no-name-in-module
152147
153148from opentelemetry import trace as trace_api
154149from opentelemetry .instrumentation import dbapi
155150from opentelemetry .instrumentation .instrumentor import BaseInstrumentor
156151from opentelemetry .instrumentation .psycopg .package import _instruments
157152from opentelemetry .instrumentation .psycopg .version import __version__
153+ from opentelemetry .trace import TracerProvider
158154
159155_logger = logging .getLogger (__name__ )
160156_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"
161157
158+ ConnectionT = TypeVar (
159+ "ConnectionT" , psycopg .Connection , psycopg .AsyncConnection
160+ )
161+ CursorT = TypeVar ("CursorT" , psycopg .Cursor , psycopg .AsyncCursor )
162+
162163
163164class PsycopgInstrumentor (BaseInstrumentor ):
164165 _CONNECTION_ATTRIBUTES = {
@@ -173,7 +174,7 @@ class PsycopgInstrumentor(BaseInstrumentor):
173174 def instrumentation_dependencies (self ) -> Collection [str ]:
174175 return _instruments
175176
176- def _instrument (self , ** kwargs ):
177+ def _instrument (self , ** kwargs : Any ):
177178 """Integrate with PostgreSQL Psycopg library.
178179 Psycopg: https://www.psycopg.org/psycopg3/docs/
179180 """
@@ -224,7 +225,7 @@ def _instrument(self, **kwargs):
224225 enable_attribute_commenter = enable_attribute_commenter ,
225226 )
226227
227- def _uninstrument (self , ** kwargs ):
228+ def _uninstrument (self , ** kwargs : Any ):
228229 """ "Disable Psycopg instrumentation"""
229230 dbapi .unwrap_connect (psycopg , "connect" ) # pylint: disable=no-member
230231 dbapi .unwrap_connect (
@@ -243,7 +244,7 @@ def instrument_connection(
243244 tracer_provider : typing .Optional [trace_api .TracerProvider ] = None ,
244245 enable_commenter : bool = False ,
245246 commenter_options : dict = None ,
246- enable_attribute_commenter = None ,
247+ enable_attribute_commenter : bool = False ,
247248 ):
248249 """Enable instrumentation of a Psycopg connection.
249250
@@ -285,7 +286,7 @@ def instrument_connection(
285286
286287 # TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
287288 @staticmethod
288- def uninstrument_connection (connection ) :
289+ def uninstrument_connection (connection : ConnectionT ) -> ConnectionT :
289290 connection .cursor_factory = getattr (
290291 connection , _OTEL_CURSOR_FACTORY_KEY , None
291292 )
@@ -297,9 +298,9 @@ def uninstrument_connection(connection):
297298class DatabaseApiIntegration (dbapi .DatabaseApiIntegration ):
298299 def wrapped_connection (
299300 self ,
300- connect_method : typing . Callable [..., typing . Any ],
301- args : typing . Tuple [ typing . Any , typing . Any ],
302- kwargs : typing . Dict [ typing . Any , typing . Any ],
301+ connect_method : Callable [..., Any ],
302+ args : tuple [ Any , Any ],
303+ kwargs : dict [ Any , Any ],
303304 ):
304305 """Add object proxy to connection object."""
305306 base_cursor_factory = kwargs .pop ("cursor_factory" , None )
@@ -315,9 +316,9 @@ def wrapped_connection(
315316class DatabaseApiAsyncIntegration (dbapi .DatabaseApiIntegration ):
316317 async def wrapped_connection (
317318 self ,
318- connect_method : typing . Callable [..., typing . Any ],
319- args : typing . Tuple [ typing . Any , typing . Any ],
320- kwargs : typing . Dict [ typing . Any , typing . Any ],
319+ connect_method : Callable [..., Any ],
320+ args : tuple [ Any , Any ],
321+ kwargs : dict [ Any , Any ],
321322 ):
322323 """Add object proxy to connection object."""
323324 base_cursor_factory = kwargs .pop ("cursor_factory" , None )
@@ -333,7 +334,7 @@ async def wrapped_connection(
333334
334335
335336class CursorTracer (dbapi .CursorTracer ):
336- def get_operation_name (self , cursor , args ) :
337+ def get_operation_name (self , cursor : CursorT , args : list [ Any ]) -> str :
337338 if not args :
338339 return ""
339340
@@ -348,7 +349,7 @@ def get_operation_name(self, cursor, args):
348349
349350 return ""
350351
351- def get_statement (self , cursor , args ) :
352+ def get_statement (self , cursor : CursorT , args : list [ Any ]) -> str :
352353 if not args :
353354 return ""
354355
@@ -379,21 +380,21 @@ def _new_cursor_factory(
379380 enable_attribute_commenter = enable_attribute_commenter ,
380381 )
381382
382- base_factory = base_factory or pg_cursor
383+ base_factory = base_factory or psycopg . Cursor
383384 _cursor_tracer = CursorTracer (db_api )
384385
385386 class TracedCursorFactory (base_factory ):
386- def execute (self , * args , ** kwargs ):
387+ def execute (self , * args : Any , ** kwargs : Any ):
387388 return _cursor_tracer .traced_execution (
388389 self , super ().execute , * args , ** kwargs
389390 )
390391
391- def executemany (self , * args , ** kwargs ):
392+ def executemany (self , * args : Any , ** kwargs : Any ):
392393 return _cursor_tracer .traced_execution (
393394 self , super ().executemany , * args , ** kwargs
394395 )
395396
396- def callproc (self , * args , ** kwargs ):
397+ def callproc (self , * args : Any , ** kwargs : Any ):
397398 return _cursor_tracer .traced_execution (
398399 self , super ().callproc , * args , ** kwargs
399400 )
@@ -402,7 +403,9 @@ def callproc(self, *args, **kwargs):
402403
403404
404405def _new_cursor_async_factory (
405- db_api = None , base_factory = None , tracer_provider = None
406+ db_api : DatabaseApiAsyncIntegration | None = None ,
407+ base_factory : type [psycopg .AsyncCursor ] | None = None ,
408+ tracer_provider : TracerProvider | None = None ,
406409):
407410 if not db_api :
408411 db_api = DatabaseApiAsyncIntegration (
@@ -412,21 +415,21 @@ def _new_cursor_async_factory(
412415 version = __version__ ,
413416 tracer_provider = tracer_provider ,
414417 )
415- base_factory = base_factory or pg_async_cursor
418+ base_factory = base_factory or psycopg . AsyncCursor
416419 _cursor_tracer = CursorTracer (db_api )
417420
418421 class TracedCursorAsyncFactory (base_factory ):
419- async def execute (self , * args , ** kwargs ):
422+ async def execute (self , * args : Any , ** kwargs : Any ):
420423 return await _cursor_tracer .traced_execution (
421424 self , super ().execute , * args , ** kwargs
422425 )
423426
424- async def executemany (self , * args , ** kwargs ):
427+ async def executemany (self , * args : Any , ** kwargs : Any ):
425428 return await _cursor_tracer .traced_execution (
426429 self , super ().executemany , * args , ** kwargs
427430 )
428431
429- async def callproc (self , * args , ** kwargs ):
432+ async def callproc (self , * args : Any , ** kwargs : Any ):
430433 return await _cursor_tracer .traced_execution (
431434 self , super ().callproc , * args , ** kwargs
432435 )
0 commit comments