118118_logger  =  logging .getLogger (__name__ )
119119_OTEL_CURSOR_FACTORY_KEY  =  "_otel_orig_cursor_factory" 
120120
121- Connection  =  TypeVar ("Connection" , psycopg .Connection , psycopg .AsyncConnection )
122- Cursor  =  TypeVar ("Cursor" , psycopg .Cursor , psycopg .AsyncCursor )
121+ ConnectionT  =  TypeVar (
122+     "ConnectionT" , psycopg .Connection , psycopg .AsyncConnection 
123+ )
124+ CursorT  =  TypeVar ("CursorT" , psycopg .Cursor , psycopg .AsyncCursor )
123125
124126
125127class  PsycopgInstrumentor (BaseInstrumentor ):
@@ -195,8 +197,8 @@ def _uninstrument(self, **kwargs: Any):
195197    # TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql 
196198    @staticmethod  
197199    def  instrument_connection (
198-         connection : Connection , tracer_provider : TracerProvider  |  None  =  None 
199-     ) ->  Connection :
200+         connection : ConnectionT , tracer_provider : TracerProvider  |  None  =  None 
201+     ) ->  ConnectionT :
200202        if  not  hasattr (connection , "_is_instrumented_by_opentelemetry" ):
201203            connection ._is_instrumented_by_opentelemetry  =  False 
202204
@@ -216,7 +218,7 @@ def instrument_connection(
216218
217219    # TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql 
218220    @staticmethod  
219-     def  uninstrument_connection (connection : Connection ) ->  Connection :
221+     def  uninstrument_connection (connection : ConnectionT ) ->  ConnectionT :
220222        connection .cursor_factory  =  getattr (
221223            connection , _OTEL_CURSOR_FACTORY_KEY , None 
222224        )
@@ -264,7 +266,7 @@ async def wrapped_connection(
264266
265267
266268class  CursorTracer (dbapi .CursorTracer ):
267-     def  get_operation_name (self , cursor : Cursor , args : list [Any ]) ->  str :
269+     def  get_operation_name (self , cursor : CursorT , args : list [Any ]) ->  str :
268270        if  not  args :
269271            return  "" 
270272
@@ -279,7 +281,7 @@ def get_operation_name(self, cursor: Cursor, args: list[Any]) -> str:
279281
280282        return  "" 
281283
282-     def  get_statement (self , cursor : Cursor , args : list [Any ]) ->  str :
284+     def  get_statement (self , cursor : CursorT , args : list [Any ]) ->  str :
283285        if  not  args :
284286            return  "" 
285287
0 commit comments