Skip to content

Commit eef9d09

Browse files
wconti27mabdinur
andauthored
fix(psycopg3): resolve attribute error raised when psycopg3 is patched [backport #5833 to 1.13] (#5866)
Backport of #5833 to 1.13 Resolves: #5821 - Fixes incorrect dbapi_async and psycopg async patches. - Adds tests to ensure new patching is correct. Changes test to use mark_asyncio - Change psycopg async connection patch to use a factory, and get the pin object from the psycopg module - Fix psycop unpatch to correctly rebind the original connect and async connect class methods. - Update dbapi_async to check during `TracedAsyncCursor.__init__` if the cursor is a `FetchTracedAsyncCursor`, in which case we don't enable analytics sample rate stats. Previously this check was for a `FetchTracedCursor` type, which was incorrect. ## Checklist - [x] Change(s) are motivated and described in the PR description. - [x] Testing strategy is described if automated tests are not included in the PR. - [x] Risk is outlined (performance impact, potential for breakage, maintainability, etc). - [x] Change is maintainable (easy to change, telemetry, documentation). - [x] [Library release note guidelines](https://ddtrace.readthedocs.io/en/stable/contributing.html#Release-Note-Guidelines) are followed. - [x] Documentation is included (in-code, generated user docs, [public corp docs](https://github.com/DataDog/documentation/)). - [x] OPTIONAL: PR description includes explicit acknowledgement of the performance implications of the change as reported in the benchmarks PR comment. ## Reviewer Checklist - [x] Title is accurate. - [x] No unnecessary changes are introduced. - [x] Description motivates each change. - [x] Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes unless absolutely necessary. - [x] Testing strategy adequately addresses listed risk(s). - [x] Change is maintainable (easy to change, telemetry, documentation). - [x] Release note makes sense to a user of the library. - [x] Reviewer has explicitly acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment. --------- Co-authored-by: Munir Abdinur <[email protected]>
1 parent ac444eb commit eef9d09

File tree

12 files changed

+299
-254
lines changed

12 files changed

+299
-254
lines changed

.riot/requirements/119287b.txt

Lines changed: 0 additions & 27 deletions
This file was deleted.

.riot/requirements/1ad8e2d.txt

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
#
55
# pip-compile --no-annotate --resolver=backtracking .riot/requirements/1ad8e2d.in
66
#
7-
attrs==22.2.0
8-
coverage[toml]==7.2.2
9-
docutils==0.19
7+
attrs==23.1.0
8+
coverage[toml]==7.2.5
9+
docutils==0.20
1010
exceptiongroup==1.1.1
1111
flake8==3.8.4
1212
flake8-blind-except==0.2.1
@@ -19,15 +19,15 @@ hypothesis==6.45.0
1919
iniconfig==2.0.0
2020
isort==5.12.0
2121
mccabe==0.6.1
22-
mock==5.0.1
22+
mock==5.0.2
2323
opentracing==2.4.0
24-
packaging==23.0
24+
packaging==23.1
2525
pluggy==1.0.0
2626
pycodestyle==2.6.0
2727
pydocstyle==6.3.0
2828
pyflakes==2.2.0
29-
pygments==2.14.0
30-
pytest==7.2.2
29+
pygments==2.15.1
30+
pytest==7.3.1
3131
pytest-cov==4.0.0
3232
pytest-mock==3.10.0
3333
restructuredtext-lint==1.4.0

.riot/requirements/bc7c1d4.txt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,19 @@
44
#
55
# pip-compile --no-annotate --resolver=backtracking .riot/requirements/bc7c1d4.in
66
#
7-
attrs==22.2.0
8-
coverage[toml]==7.2.2
7+
attrs==23.1.0
8+
coverage[toml]==7.2.5
99
envier==0.4.0
1010
exceptiongroup==1.1.1
1111
hypothesis==6.45.0
1212
iniconfig==2.0.0
13-
mock==5.0.1
13+
mock==5.0.2
1414
mypy==0.991
1515
mypy-extensions==1.0.0
1616
opentracing==2.4.0
17-
packaging==23.0
17+
packaging==23.1
1818
pluggy==1.0.0
19-
pytest==7.2.2
19+
pytest==7.3.1
2020
pytest-cov==4.0.0
2121
pytest-mock==3.10.0
2222
sortedcontainers==2.4.0

ddtrace/contrib/dbapi_async/__init__.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from ddtrace import config
12
from ddtrace.appsec.iast._util import _is_iast_enabled
23
from ddtrace.internal.constants import COMPONENT
34

@@ -10,7 +11,6 @@
1011
from ...internal.utils import ArgumentError
1112
from ...internal.utils import get_argument_value
1213
from ...pin import Pin
13-
from ..dbapi import FetchTracedCursor
1414
from ..dbapi import TracedConnection
1515
from ..dbapi import TracedCursor
1616
from ..trace_utils import ext_service
@@ -21,6 +21,20 @@
2121

2222

2323
class TracedAsyncCursor(TracedCursor):
24+
async def __aenter__(self):
25+
# previous versions of the dbapi didn't support context managers. let's
26+
# reference the func that would be called to ensure that error
27+
# messages will be the same.
28+
await self.__wrapped__.__aenter__()
29+
30+
return self
31+
32+
async def __aexit__(self, exc_type, exc_val, exc_tb):
33+
# previous versions of the dbapi didn't support context managers. let's
34+
# reference the func that would be called to ensure that error
35+
# messages will be the same.
36+
return await self.__wrapped__.__aexit__()
37+
2438
async def _trace_method(self, method, name, resource, extra_tags, dbm_propagator, *args, **kwargs):
2539
"""
2640
Internal function to trace the call to the underlying cursor method
@@ -61,7 +75,7 @@ async def _trace_method(self, method, name, resource, extra_tags, dbm_propagator
6175
SqlInjection.report(evidence_value=args[0])
6276

6377
# set analytics sample rate if enabled but only for non-FetchTracedCursor
64-
if not isinstance(self, FetchTracedCursor):
78+
if not isinstance(self, FetchTracedAsyncCursor):
6579
s.set_tag(ANALYTICS_SAMPLE_RATE_KEY, self._self_config.get_analytics_sample_rate())
6680

6781
if dbm_propagator:
@@ -147,8 +161,11 @@ async def fetchmany(self, *args, **kwargs):
147161

148162

149163
class TracedAsyncConnection(TracedConnection):
150-
traced_cursor_cls = TracedCursor
151-
traced_fetch_cursor_cls = FetchTracedCursor
164+
def __init__(self, conn, pin=None, cfg=config.dbapi2, cursor_cls=None):
165+
if not cursor_cls:
166+
# Do not trace `fetch*` methods by default
167+
cursor_cls = FetchTracedAsyncCursor if cfg.trace_fetch_methods else TracedAsyncCursor
168+
super(TracedAsyncConnection, self).__init__(conn, pin, cfg, cursor_cls)
152169

153170
async def __aenter__(self):
154171
"""Context management is not defined by the dbapi spec.
@@ -165,7 +182,7 @@ async def __aenter__(self):
165182
- pymysql doesn't implement it.
166183
- sqlite3 returns the connection.
167184
"""
168-
r = self.__wrapped__.__aenter__()
185+
r = await self.__wrapped__.__aenter__()
169186

170187
if hasattr(r, "cursor"):
171188
# r is Connection-like.
@@ -190,7 +207,7 @@ async def __aenter__(self):
190207
pin = Pin.get_from(self)
191208
if not pin:
192209
return r
193-
return await self._self_cursor_cls(r, pin, self._self_config)
210+
return self._self_cursor_cls(r, pin, self._self_config)
194211
else:
195212
# Otherwise r is some other object, so maintain the functionality
196213
# of the original.
@@ -204,10 +221,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
204221
# previous versions of the dbapi didn't support context managers. let's
205222
# reference the func that would be called to ensure that errors
206223
# messages will be the same.
207-
self.__wrapped__.__aenter__
208-
209-
# and finally, yield the traced cursor.
210-
return self
224+
return await self.__wrapped__.__aexit__()
211225

212226
async def _trace_method(self, method, name, extra_tags, *args, **kwargs):
213227
pin = Pin.get_from(self)

ddtrace/contrib/psycopg/async_connection.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,41 +29,38 @@ async def execute(self, *args, **kwargs):
2929

3030
async def patched_execute(*args, **kwargs):
3131
try:
32-
cur = await self.cursor()
32+
cur = self.cursor()
3333
if kwargs.get("binary", None):
3434
cur.format = 1 # set to 1 for binary or 0 if not
35-
return cur.execute(*args, **kwargs)
35+
return await cur.execute(*args, **kwargs)
3636
except Exception as ex:
3737
raise ex.with_traceback(None)
3838

3939
return await self._trace_method(patched_execute, span_name, {}, *args, **kwargs)
4040

4141

42-
async def patched_connect_async(connect_func, _, args, kwargs):
43-
traced_conn_cls = Psycopg3TracedAsyncConnection
42+
def patched_connect_async_factory(psycopg_module):
43+
async def patched_connect_async(connect_func, _, args, kwargs):
44+
traced_conn_cls = Psycopg3TracedAsyncConnection
4445

45-
_config = globals()["config"]._config
46-
module_name = (
47-
connect_func.__module__
48-
if len(connect_func.__module__.split(".")) == 1
49-
else connect_func.__module__.split(".")[0]
50-
)
51-
pin = Pin.get_from(_config[module_name].base_module)
46+
pin = Pin.get_from(psycopg_module)
5247

53-
if not pin or not pin.enabled() or not pin._config.trace_connect:
54-
conn = await connect_func(*args, **kwargs)
55-
else:
56-
with pin.tracer.trace(
57-
"{}.{}".format(connect_func.__module__, connect_func.__name__),
58-
service=ext_service(pin, pin._config),
59-
span_type=SpanTypes.SQL,
60-
) as span:
61-
span.set_tag_str(SPAN_KIND, SpanKind.CLIENT)
62-
span.set_tag_str(COMPONENT, pin._config.integration_name)
63-
if span.get_tag(db.SYSTEM) is None:
64-
span.set_tag_str(db.SYSTEM, pin._config.dbms_name)
65-
66-
span.set_tag(SPAN_MEASURED_KEY)
48+
if not pin or not pin.enabled() or not pin._config.trace_connect:
6749
conn = await connect_func(*args, **kwargs)
50+
else:
51+
with pin.tracer.trace(
52+
"{}.{}".format(connect_func.__module__, connect_func.__name__),
53+
service=ext_service(pin, pin._config),
54+
span_type=SpanTypes.SQL,
55+
) as span:
56+
span.set_tag_str(SPAN_KIND, SpanKind.CLIENT)
57+
span.set_tag_str(COMPONENT, pin._config.integration_name)
58+
if span.get_tag(db.SYSTEM) is None:
59+
span.set_tag_str(db.SYSTEM, pin._config.dbms_name)
60+
61+
span.set_tag(SPAN_MEASURED_KEY)
62+
conn = await connect_func(*args, **kwargs)
63+
64+
return patch_conn(conn, pin=pin, traced_conn_cls=traced_conn_cls)
6865

69-
return patch_conn(conn, pin=pin, traced_conn_cls=traced_conn_cls)
66+
return patched_connect_async

ddtrace/contrib/psycopg/patch.py

Lines changed: 28 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
try:
11-
from ddtrace.contrib.psycopg.async_connection import patched_connect_async
11+
from ddtrace.contrib.psycopg.async_connection import patched_connect_async_factory
1212
from ddtrace.contrib.psycopg.async_cursor import Psycopg3FetchTracedAsyncCursor
1313
from ddtrace.contrib.psycopg.async_cursor import Psycopg3TracedAsyncCursor
1414
# catch async function syntax errors when using Python<3.7 with no async support
@@ -28,6 +28,18 @@
2828
from ...propagation._database_monitoring import _DBM_Propagator
2929

3030

31+
try:
32+
psycopg_import = import_module("psycopg")
33+
34+
# must get the original connect class method from the class __dict__ to use later in unpatch
35+
# Python 3.11 and wrapt result in the class method being rebinded as an instance method when
36+
# using unwrap
37+
_original_connect = psycopg_import.Connection.__dict__["connect"]
38+
_original_async_connect = psycopg_import.AsyncConnection.__dict__["connect"]
39+
except ImportError:
40+
pass
41+
42+
3143
def _psycopg_sql_injector(dbm_comment, sql_statement):
3244
for psycopg_module in config.psycopg["_patched_modules"]:
3345
if (
@@ -45,7 +57,6 @@ def _psycopg_sql_injector(dbm_comment, sql_statement):
4557
_default_service="postgres",
4658
_dbapi_span_name_prefix="postgres",
4759
_patched_modules=set(),
48-
_patched_functions=dict(),
4960
trace_fetch_methods=asbool(
5061
os.getenv("DD_PSYCOPG_TRACE_FETCH_METHODS", default=False)
5162
or os.getenv("DD_PSYCOPG2_TRACE_FETCH_METHODS", default=False)
@@ -72,7 +83,6 @@ def _psycopg_modules():
7283
pass
7384

7485

75-
# NB: We are patching the default elasticsearch.transport module
7686
def patch():
7787
for psycopg_module in _psycopg_modules():
7888
_patch(psycopg_module)
@@ -89,7 +99,6 @@ def _patch(psycopg_module):
8999
Pin(_config=config.psycopg).onto(psycopg_module)
90100

91101
if psycopg_module.__name__ == "psycopg2":
92-
config.psycopg["_patched_functions"].update({"psycopg2.connect": psycopg_module.connect})
93102

94103
# patch all psycopg2 extensions
95104
_psycopg2_extensions = get_psycopg2_extensions(psycopg_module)
@@ -100,23 +109,14 @@ def _patch(psycopg_module):
100109

101110
config.psycopg["_patched_modules"].add(psycopg_module)
102111
else:
103-
config.psycopg["_patched_functions"].update(
104-
{
105-
"psycopg.connect": psycopg_module.connect,
106-
"psycopg.Connection": psycopg_module.Connection,
107-
"psycopg.Cursor": psycopg_module.Cursor,
108-
"psycopg.AsyncConnection": psycopg_module.AsyncConnection,
109-
"psycopg.AsyncCursor": psycopg_module.AsyncCursor,
110-
}
111-
)
112112

113113
_w(psycopg_module, "connect", patched_connect_factory(psycopg_module))
114-
_w(psycopg_module.Connection, "connect", patched_connect_factory(psycopg_module))
115114
_w(psycopg_module, "Cursor", init_cursor_from_connection_factory(psycopg_module))
116-
117-
_w(psycopg_module.AsyncConnection, "connect", patched_connect_async)
118115
_w(psycopg_module, "AsyncCursor", init_cursor_from_connection_factory(psycopg_module))
119116

117+
_w(psycopg_module.Connection, "connect", patched_connect_factory(psycopg_module))
118+
_w(psycopg_module.AsyncConnection, "connect", patched_connect_async_factory(psycopg_module))
119+
120120
config.psycopg["_patched_modules"].add(psycopg_module)
121121

122122

@@ -139,18 +139,10 @@ def _unpatch(psycopg_module):
139139
_u(psycopg_module, "Cursor")
140140
_u(psycopg_module, "AsyncCursor")
141141

142-
try:
143-
_u(psycopg_module.Connection, "connect")
144-
_u(psycopg_module.AsyncConnection, "connect")
145-
146-
# _u throws an attribute error for Python 3.11 on method objects because of
147-
# no __get__ method on the BoundFunctionWrapper
148-
except AttributeError:
149-
_original_connection_class = config.psycopg["_patched_functions"]["psycopg.Connection"]
150-
_original_asyncconnection_class = config.psycopg["_patched_functions"]["psycopg.AsyncConnection"]
151-
152-
psycopg_module.Connection = _original_connection_class
153-
psycopg_module.AsyncConnection = _original_asyncconnection_class
142+
# _u throws an attribute error for Python 3.11, no __get__ on the BoundFunctionWrapper
143+
# unlike Python Class Methods which implement __get__
144+
setattr(psycopg_module.Connection, "connect", _original_connect)
145+
setattr(psycopg_module.AsyncConnection, "connect", _original_async_connect)
154146

155147
pin = Pin.get_from(psycopg_module)
156148
if pin:
@@ -162,7 +154,14 @@ def init_cursor_from_connection(wrapped_cursor_cls, _, args, kwargs):
162154
connection = kwargs.pop("connection", None)
163155
if not connection:
164156
args = list(args)
165-
connection = args.pop(next((i for i, x in enumerate(args) if isinstance(x, dbapi.TracedConnection)), None))
157+
index = next((i for i, x in enumerate(args) if isinstance(x, dbapi.TracedConnection)), None)
158+
if index is not None:
159+
connection = args.pop(index)
160+
161+
# if we do not have an example of a traced connection, call the original cursor function
162+
if not connection:
163+
return wrapped_cursor_cls(*args, **kwargs)
164+
166165
pin = Pin.get_from(connection).clone()
167166
cfg = config.psycopg
168167

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
fixes:
3+
- |
4+
psycopg: Resolves an issue where an AttributeError is raised when ``psycopg.AsyncConnection`` is traced.
5+

riotfile.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2403,7 +2403,7 @@ def select_pys(min_version=MIN_PYTHON_VERSION, max_version=MAX_PYTHON_VERSION):
24032403
Venv(
24042404
name="dbapi_async",
24052405
command="pytest {cmdargs} tests/contrib/dbapi_async",
2406-
pys=select_pys(min_version="3.5"),
2406+
pys=select_pys(min_version="3.6"),
24072407
env={
24082408
"DD_IAST_REQUEST_SAMPLING": "100", # Override default 30% to analyze all IAST requests
24092409
},
@@ -2412,7 +2412,7 @@ def select_pys(min_version=MIN_PYTHON_VERSION, max_version=MAX_PYTHON_VERSION):
24122412
},
24132413
venvs=[
24142414
Venv(
2415-
pys=["3.5", "3.6", "3.8", "3.9", "3.10"],
2415+
pys=["3.6", "3.8", "3.9", "3.10"],
24162416
),
24172417
Venv(pys=["3.11"], pkgs={"attrs": latest}),
24182418
],

0 commit comments

Comments
 (0)