Skip to content

Commit 66a1747

Browse files
committed
Add type hints to Psycopg
1 parent f393546 commit 66a1747

File tree

2 files changed

+39
-32
lines changed
  • instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg

2 files changed

+39
-32
lines changed

instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/__init__.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -101,27 +101,26 @@
101101
---
102102
"""
103103

104+
from __future__ import annotations
105+
104106
import logging
105-
import typing
106-
from typing import Collection
107+
from typing import Any, Callable, Collection, TypeVar
107108

108109
import psycopg # pylint: disable=import-self
109-
from psycopg import (
110-
AsyncCursor as pg_async_cursor, # pylint: disable=import-self,no-name-in-module
111-
)
112-
from psycopg import (
113-
Cursor as pg_cursor, # pylint: disable=no-name-in-module,import-self
114-
)
115110
from psycopg.sql import Composed # pylint: disable=no-name-in-module
116111

117112
from opentelemetry.instrumentation import dbapi
118113
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
119114
from opentelemetry.instrumentation.psycopg.package import _instruments
120115
from opentelemetry.instrumentation.psycopg.version import __version__
116+
from opentelemetry.trace import TracerProvider
121117

122118
_logger = logging.getLogger(__name__)
123119
_OTEL_CURSOR_FACTORY_KEY = "_otel_orig_cursor_factory"
124120

121+
Connection = TypeVar("Connection", psycopg.Connection, psycopg.AsyncConnection)
122+
Cursor = TypeVar("Cursor", psycopg.Cursor, psycopg.AsyncCursor)
123+
125124

126125
class PsycopgInstrumentor(BaseInstrumentor):
127126
_CONNECTION_ATTRIBUTES = {
@@ -136,7 +135,7 @@ class PsycopgInstrumentor(BaseInstrumentor):
136135
def instrumentation_dependencies(self) -> Collection[str]:
137136
return _instruments
138137

139-
def _instrument(self, **kwargs):
138+
def _instrument(self, **kwargs: Any):
140139
"""Integrate with PostgreSQL Psycopg library.
141140
Psycopg: http://initd.org/psycopg/
142141
"""
@@ -181,7 +180,7 @@ def _instrument(self, **kwargs):
181180
commenter_options=commenter_options,
182181
)
183182

184-
def _uninstrument(self, **kwargs):
183+
def _uninstrument(self, **kwargs: Any):
185184
""" "Disable Psycopg instrumentation"""
186185
dbapi.unwrap_connect(psycopg, "connect") # pylint: disable=no-member
187186
dbapi.unwrap_connect(
@@ -195,7 +194,9 @@ def _uninstrument(self, **kwargs):
195194

196195
# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
197196
@staticmethod
198-
def instrument_connection(connection, tracer_provider=None):
197+
def instrument_connection(
198+
connection: Connection, tracer_provider: TracerProvider | None = None
199+
) -> Connection:
199200
if not hasattr(connection, "_is_instrumented_by_opentelemetry"):
200201
connection._is_instrumented_by_opentelemetry = False
201202

@@ -215,7 +216,7 @@ def instrument_connection(connection, tracer_provider=None):
215216

216217
# TODO(owais): check if core dbapi can do this for all dbapi implementations e.g, pymysql and mysql
217218
@staticmethod
218-
def uninstrument_connection(connection):
219+
def uninstrument_connection(connection: Connection) -> Connection:
219220
connection.cursor_factory = getattr(
220221
connection, _OTEL_CURSOR_FACTORY_KEY, None
221222
)
@@ -227,9 +228,9 @@ def uninstrument_connection(connection):
227228
class DatabaseApiIntegration(dbapi.DatabaseApiIntegration):
228229
def wrapped_connection(
229230
self,
230-
connect_method: typing.Callable[..., typing.Any],
231-
args: typing.Tuple[typing.Any, typing.Any],
232-
kwargs: typing.Dict[typing.Any, typing.Any],
231+
connect_method: Callable[..., Any],
232+
args: tuple[Any, Any],
233+
kwargs: dict[Any, Any],
233234
):
234235
"""Add object proxy to connection object."""
235236
base_cursor_factory = kwargs.pop("cursor_factory", None)
@@ -245,9 +246,9 @@ def wrapped_connection(
245246
class DatabaseApiAsyncIntegration(dbapi.DatabaseApiIntegration):
246247
async def wrapped_connection(
247248
self,
248-
connect_method: typing.Callable[..., typing.Any],
249-
args: typing.Tuple[typing.Any, typing.Any],
250-
kwargs: typing.Dict[typing.Any, typing.Any],
249+
connect_method: Callable[..., Any],
250+
args: tuple[Any, Any],
251+
kwargs: dict[Any, Any],
251252
):
252253
"""Add object proxy to connection object."""
253254
base_cursor_factory = kwargs.pop("cursor_factory", None)
@@ -263,7 +264,7 @@ async def wrapped_connection(
263264

264265

265266
class CursorTracer(dbapi.CursorTracer):
266-
def get_operation_name(self, cursor, args):
267+
def get_operation_name(self, cursor: Cursor, args: list[Any]) -> str:
267268
if not args:
268269
return ""
269270

@@ -278,7 +279,7 @@ def get_operation_name(self, cursor, args):
278279

279280
return ""
280281

281-
def get_statement(self, cursor, args):
282+
def get_statement(self, cursor: Cursor, args: list[Any]) -> str:
282283
if not args:
283284
return ""
284285

@@ -288,7 +289,11 @@ def get_statement(self, cursor, args):
288289
return statement
289290

290291

291-
def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
292+
def _new_cursor_factory(
293+
db_api: DatabaseApiIntegration | None = None,
294+
base_factory: type[psycopg.Cursor] | None = None,
295+
tracer_provider: TracerProvider | None = None,
296+
):
292297
if not db_api:
293298
db_api = DatabaseApiIntegration(
294299
__name__,
@@ -298,21 +303,21 @@ def _new_cursor_factory(db_api=None, base_factory=None, tracer_provider=None):
298303
tracer_provider=tracer_provider,
299304
)
300305

301-
base_factory = base_factory or pg_cursor
306+
base_factory = base_factory or psycopg.Cursor
302307
_cursor_tracer = CursorTracer(db_api)
303308

304309
class TracedCursorFactory(base_factory):
305-
def execute(self, *args, **kwargs):
310+
def execute(self, *args: Any, **kwargs: Any):
306311
return _cursor_tracer.traced_execution(
307312
self, super().execute, *args, **kwargs
308313
)
309314

310-
def executemany(self, *args, **kwargs):
315+
def executemany(self, *args: Any, **kwargs: Any):
311316
return _cursor_tracer.traced_execution(
312317
self, super().executemany, *args, **kwargs
313318
)
314319

315-
def callproc(self, *args, **kwargs):
320+
def callproc(self, *args: Any, **kwargs: Any):
316321
return _cursor_tracer.traced_execution(
317322
self, super().callproc, *args, **kwargs
318323
)
@@ -321,7 +326,9 @@ def callproc(self, *args, **kwargs):
321326

322327

323328
def _new_cursor_async_factory(
324-
db_api=None, base_factory=None, tracer_provider=None
329+
db_api: DatabaseApiAsyncIntegration | None = None,
330+
base_factory: type[psycopg.AsyncCursor] | None = None,
331+
tracer_provider: TracerProvider | None = None,
325332
):
326333
if not db_api:
327334
db_api = DatabaseApiAsyncIntegration(
@@ -331,21 +338,21 @@ def _new_cursor_async_factory(
331338
version=__version__,
332339
tracer_provider=tracer_provider,
333340
)
334-
base_factory = base_factory or pg_async_cursor
341+
base_factory = base_factory or psycopg.AsyncCursor
335342
_cursor_tracer = CursorTracer(db_api)
336343

337344
class TracedCursorAsyncFactory(base_factory):
338-
async def execute(self, *args, **kwargs):
345+
async def execute(self, *args: Any, **kwargs: Any):
339346
return await _cursor_tracer.traced_execution(
340347
self, super().execute, *args, **kwargs
341348
)
342349

343-
async def executemany(self, *args, **kwargs):
350+
async def executemany(self, *args: Any, **kwargs: Any):
344351
return await _cursor_tracer.traced_execution(
345352
self, super().executemany, *args, **kwargs
346353
)
347354

348-
async def callproc(self, *args, **kwargs):
355+
async def callproc(self, *args: Any, **kwargs: Any):
349356
return await _cursor_tracer.traced_execution(
350357
self, super().callproc, *args, **kwargs
351358
)

instrumentation/opentelemetry-instrumentation-psycopg/src/opentelemetry/instrumentation/psycopg/package.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from __future__ import annotations
1415

15-
16-
_instruments = ("psycopg >= 3.1.0",)
16+
_instruments: tuple[str, ...] = ("psycopg >= 3.1.0",)

0 commit comments

Comments
 (0)