Skip to content

Commit d125801

Browse files
authored
fix(asyncpg): fix _TracedConnection error when using the connect option in create_pool [backport 3.12] (#14450)
Backports #14436 to 3.12 ## Checklist - [x] PR author has checked that all the criteria below are met - The PR description includes an overview of the change - The PR description articulates the motivation for the change - The change includes tests OR the PR description describes a testing strategy - The PR description notes risks associated with the change, if any - Newly-added code is easy to change - The change follows the [library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) - The change includes or references documentation updates if necessary - Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) ## Reviewer Checklist - [x] Reviewer has checked that all the criteria below are met - Title is accurate - All changes are related to the pull request's stated goal - Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - Testing strategy adequately addresses listed risks - Newly-added code is easy to change - Release note makes sense to a user of the library - If necessary, author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)
1 parent 077d257 commit d125801

File tree

5 files changed

+547
-4
lines changed

5 files changed

+547
-4
lines changed

ddtrace/contrib/internal/asyncpg/patch.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ async def _traced_connect(asyncpg, pin, func, instance, args, kwargs):
9494
9595
connect() is instrumented and patched to return a connection proxy.
9696
"""
97+
# When using a pool, there's a connection_class args
98+
is_pool_context = "connection_class" in kwargs
99+
97100
with pin.tracer.trace(
98101
"postgres.connect", span_type=SpanTypes.SQL, service=ext_service(pin, config.asyncpg)
99102
) as span:
@@ -103,10 +106,23 @@ async def _traced_connect(asyncpg, pin, func, instance, args, kwargs):
103106
# set span.kind to the type of request being performed
104107
span.set_tag_str(SPAN_KIND, SpanKind.CLIENT)
105108

106-
# Need an ObjectProxy since Connection uses slots
107-
conn = _TracedConnection(await func(*args, **kwargs), pin)
108-
span.set_tags(_get_connection_tags(conn))
109-
return conn
109+
raw_conn = await func(*args, **kwargs)
110+
if is_pool_context:
111+
# Return the unwrapped connection to avoid _TracedConnection errors
112+
# when using a pool with a custom connect param
113+
connection_tags = _get_connection_tags(raw_conn)
114+
connection_tags[db.SYSTEM] = DBMS_NAME
115+
conn_pin = pin.clone(tags=connection_tags)
116+
conn_pin.onto(raw_conn._protocol)
117+
span.set_tags(connection_tags)
118+
# Returns a asyncpg.connection.Connection object
119+
return raw_conn
120+
else:
121+
# # Need an ObjectProxy when not using pools since Connection uses slots
122+
conn = _TracedConnection(raw_conn, pin)
123+
span.set_tags(_get_connection_tags(conn))
124+
# Returns a _TracedConnection object
125+
return conn
110126

111127

112128
async def _traced_query(pin, method, query, args, kwargs):
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
---
2+
fixes:
3+
- |
4+
asyncpg: Fix the error "Error: expected pool connect callback to return an instance of 'asyncpg.connection.Connection', got 'ddtrace.contrib.internal.asyncpg.patch._TracedConnection'`" when a pool connection" due to using the custom connect option. With this fix, postgres.connect spans will be created when this option is used.

tests/contrib/asyncpg/test_asyncpg.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from ddtrace.contrib.internal.asyncpg.patch import patch
99
from ddtrace.contrib.internal.asyncpg.patch import unpatch
1010
from ddtrace.contrib.internal.trace_utils import iswrapped
11+
from ddtrace.internal.utils.version import parse_version
1112
from ddtrace.trace import Pin
1213
from ddtrace.trace import tracer
1314
from tests.contrib.asyncio.utils import AsyncioTestCase
@@ -315,6 +316,58 @@ async def test():
315316
assert err == b""
316317

317318

319+
@pytest.mark.skipif(
320+
parse_version(getattr(asyncpg, "__version__", "0.0.0")) < (0, 30, 0),
321+
reason="the custom connect parameter for create_pool requires asyncpg >= 0.30.0",
322+
)
323+
@pytest.mark.snapshot()
324+
@pytest.mark.asyncio
325+
async def test_pool_custom_connect():
326+
"""
327+
Test that if someone uses a custom connect parameter when creating a pool,
328+
the tracer doesn't cause a _TracedConnection error or throw any other errors
329+
The connect option was introduced in 0.30.0
330+
"""
331+
database_url = f"postgresql://{POSTGRES_CONFIG['user']}:{POSTGRES_CONFIG['password']}@{POSTGRES_CONFIG['host']}:{POSTGRES_CONFIG['port']}/{POSTGRES_CONFIG['dbname']}"
332+
333+
try:
334+
# The default is 10 connection pools so the integration will create 10 spans in the snapshot
335+
pool = await asyncpg.create_pool(database_url, connect=asyncpg.connect)
336+
337+
async with pool.acquire() as conn:
338+
result = await conn.fetchval("SELECT 1")
339+
assert result == 1
340+
341+
await pool.close()
342+
except Exception as err:
343+
raise err
344+
345+
assert True
346+
347+
348+
@pytest.mark.snapshot()
349+
@pytest.mark.asyncio
350+
async def test_pool_without_custom_connect():
351+
"""
352+
Test that create_pool without the connect option still works
353+
"""
354+
database_url = f"postgresql://{POSTGRES_CONFIG['user']}:{POSTGRES_CONFIG['password']}@{POSTGRES_CONFIG['host']}:{POSTGRES_CONFIG['port']}/{POSTGRES_CONFIG['dbname']}"
355+
356+
try:
357+
# The default is 10 connection pools so the integration will create 10 spans in the snapshot
358+
pool = await asyncpg.create_pool(database_url)
359+
360+
async with pool.acquire() as conn:
361+
result = await conn.fetchval("SELECT 1")
362+
assert result == 1
363+
364+
await pool.close()
365+
except Exception as err:
366+
raise err
367+
368+
assert True
369+
370+
318371
def test_patch_unpatch_asyncpg():
319372
assert iswrapped(asyncpg.connect)
320373
assert iswrapped(asyncpg.protocol.Protocol.execute)

0 commit comments

Comments
 (0)