101101---
102102"""
103103
104+ from __future__ import annotations
105+
104106import logging
105- import typing
106- from typing import Collection
107+ from typing import Any , Callable , Collection , TypeVar
107108
108109import psycopg # pylint: disable=import-self
109- from psycopg import (
110- AsyncCursor as pg_async_cursor , # pylint: disable=import-self,no-name-in-module
111- )
112- from psycopg import (
113- Cursor as pg_cursor , # pylint: disable=no-name-in-module,import-self
114- )
115110from psycopg .sql import Composed # pylint: disable=no-name-in-module
116111
117112from opentelemetry .instrumentation import dbapi
118113from opentelemetry .instrumentation .instrumentor import BaseInstrumentor
119114from opentelemetry .instrumentation .psycopg .package import _instruments
120115from opentelemetry .instrumentation .psycopg .version import __version__
116+ from opentelemetry .trace import TracerProvider
121117
122118_logger = logging .getLogger (__name__ )
123119_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"
124120
121+ Connection = TypeVar ("Connection" , psycopg .Connection , psycopg .AsyncConnection )
122+ Cursor = TypeVar ("Cursor" , psycopg .Cursor , psycopg .AsyncCursor )
123+
125124
126125class PsycopgInstrumentor (BaseInstrumentor ):
127126 _CONNECTION_ATTRIBUTES = {
@@ -136,7 +135,7 @@ class PsycopgInstrumentor(BaseInstrumentor):
136135 def instrumentation_dependencies (self ) -> Collection [str ]:
137136 return _instruments
138137
139- def _instrument (self , ** kwargs ):
138+ def _instrument (self , ** kwargs : Any ):
140139 """Integrate with PostgreSQL Psycopg library.
141140 Psycopg: http://initd.org/psycopg/
142141 """
@@ -181,7 +180,7 @@ def _instrument(self, **kwargs):
181180 commenter_options = commenter_options ,
182181 )
183182
184- def _uninstrument (self , ** kwargs ):
183+ def _uninstrument (self , ** kwargs : Any ):
185184 """ "Disable Psycopg instrumentation"""
186185 dbapi .unwrap_connect (psycopg , "connect" ) # pylint: disable=no-member
187186 dbapi .unwrap_connect (
@@ -195,7 +194,9 @@ def _uninstrument(self, **kwargs):
195194
196195 # TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
197196 @staticmethod
198- def instrument_connection (connection , tracer_provider = None ):
197+ def instrument_connection (
198+ connection : Connection , tracer_provider : TracerProvider | None = None
199+ ) -> Connection :
199200 if not hasattr (connection , "_is_instrumented_by_opentelemetry" ):
200201 connection ._is_instrumented_by_opentelemetry = False
201202
@@ -215,7 +216,7 @@ def instrument_connection(connection, tracer_provider=None):
215216
216217 # TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
217218 @staticmethod
218- def uninstrument_connection (connection ) :
219+ def uninstrument_connection (connection : Connection ) -> Connection :
219220 connection .cursor_factory = getattr (
220221 connection , _OTEL_CURSOR_FACTORY_KEY , None
221222 )
@@ -227,9 +228,9 @@ def uninstrument_connection(connection):
227228class DatabaseApiIntegration (dbapi .DatabaseApiIntegration ):
228229 def wrapped_connection (
229230 self ,
230- connect_method : typing . Callable [..., typing . Any ],
231- args : typing . Tuple [ typing . Any , typing . Any ],
232- kwargs : typing . Dict [ typing . Any , typing . Any ],
231+ connect_method : Callable [..., Any ],
232+ args : tuple [ Any , Any ],
233+ kwargs : dict [ Any , Any ],
233234 ):
234235 """Add object proxy to connection object."""
235236 base_cursor_factory = kwargs .pop ("cursor_factory" , None )
@@ -245,9 +246,9 @@ def wrapped_connection(
245246class DatabaseApiAsyncIntegration (dbapi .DatabaseApiIntegration ):
246247 async def wrapped_connection (
247248 self ,
248- connect_method : typing . Callable [..., typing . Any ],
249- args : typing . Tuple [ typing . Any , typing . Any ],
250- kwargs : typing . Dict [ typing . Any , typing . Any ],
249+ connect_method : Callable [..., Any ],
250+ args : tuple [ Any , Any ],
251+ kwargs : dict [ Any , Any ],
251252 ):
252253 """Add object proxy to connection object."""
253254 base_cursor_factory = kwargs .pop ("cursor_factory" , None )
@@ -263,7 +264,7 @@ async def wrapped_connection(
263264
264265
265266class CursorTracer (dbapi .CursorTracer ):
266- def get_operation_name (self , cursor , args ) :
267+ def get_operation_name (self , cursor : Cursor , args : list [ Any ]) -> str :
267268 if not args :
268269 return ""
269270
@@ -278,7 +279,7 @@ def get_operation_name(self, cursor, args):
278279
279280 return ""
280281
281- def get_statement (self , cursor , args ) :
282+ def get_statement (self , cursor : Cursor , args : list [ Any ]) -> str :
282283 if not args :
283284 return ""
284285
@@ -288,7 +289,11 @@ def get_statement(self, cursor, args):
288289 return statement
289290
290291
291- def _new_cursor_factory (db_api = None , base_factory = None , tracer_provider = None ):
292+ def _new_cursor_factory (
293+ db_api : DatabaseApiIntegration | None = None ,
294+ base_factory : type [psycopg .Cursor ] | None = None ,
295+ tracer_provider : TracerProvider | None = None ,
296+ ):
292297 if not db_api :
293298 db_api = DatabaseApiIntegration (
294299 __name__ ,
@@ -298,21 +303,21 @@ def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
298303 tracer_provider = tracer_provider ,
299304 )
300305
301- base_factory = base_factory or pg_cursor
306+ base_factory = base_factory or psycopg . Cursor
302307 _cursor_tracer = CursorTracer (db_api )
303308
304309 class TracedCursorFactory (base_factory ):
305- def execute (self , * args , ** kwargs ):
310+ def execute (self , * args : Any , ** kwargs : Any ):
306311 return _cursor_tracer .traced_execution (
307312 self , super ().execute , * args , ** kwargs
308313 )
309314
310- def executemany (self , * args , ** kwargs ):
315+ def executemany (self , * args : Any , ** kwargs : Any ):
311316 return _cursor_tracer .traced_execution (
312317 self , super ().executemany , * args , ** kwargs
313318 )
314319
315- def callproc (self , * args , ** kwargs ):
320+ def callproc (self , * args : Any , ** kwargs : Any ):
316321 return _cursor_tracer .traced_execution (
317322 self , super ().callproc , * args , ** kwargs
318323 )
@@ -321,7 +326,9 @@ def callproc(self, *args, **kwargs):
321326
322327
323328def _new_cursor_async_factory (
324- db_api = None , base_factory = None , tracer_provider = None
329+ db_api : DatabaseApiAsyncIntegration | None = None ,
330+ base_factory : type [psycopg .AsyncCursor ] | None = None ,
331+ tracer_provider : TracerProvider | None = None ,
325332):
326333 if not db_api :
327334 db_api = DatabaseApiAsyncIntegration (
@@ -331,21 +338,21 @@ def _new_cursor_async_factory(
331338 version = __version__ ,
332339 tracer_provider = tracer_provider ,
333340 )
334- base_factory = base_factory or pg_async_cursor
341+ base_factory = base_factory or psycopg . AsyncCursor
335342 _cursor_tracer = CursorTracer (db_api )
336343
337344 class TracedCursorAsyncFactory (base_factory ):
338- async def execute (self , * args , ** kwargs ):
345+ async def execute (self , * args : Any , ** kwargs : Any ):
339346 return await _cursor_tracer .traced_execution (
340347 self , super ().execute , * args , ** kwargs
341348 )
342349
343- async def executemany (self , * args , ** kwargs ):
350+ async def executemany (self , * args : Any , ** kwargs : Any ):
344351 return await _cursor_tracer .traced_execution (
345352 self , super ().executemany , * args , ** kwargs
346353 )
347354
348- async def callproc (self , * args , ** kwargs ):
355+ async def callproc (self , * args : Any , ** kwargs : Any ):
349356 return await _cursor_tracer .traced_execution (
350357 self , super ().callproc , * args , ** kwargs
351358 )
0 commit comments