Skip to content

Commit b433989

Browse files
author
Stefan Tjarks
committed
contrib.psycopg: handle extensions.adapt
this function does c-level checks of the type of the argument so it must be a raw psycopg connection passed in. This adds a hook to transparently downgrade.
1 parent 3e78683 commit b433989

File tree

2 files changed

+44
-0
lines changed

2 files changed

+44
-0
lines changed

ddtrace/contrib/psycopg/patch.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,28 @@ def _unroll_args(obj, scope=None):
9494
return func(obj, scope) if scope else func(obj)
9595

9696

97+
def _extensions_adapt(func, _, args, kwargs):
98+
adapt = func(*args, **kwargs)
99+
if hasattr(adapt, 'prepare'):
100+
return AdapterWrapper(adapt)
101+
return adapt
102+
103+
104+
class AdapterWrapper(wrapt.ObjectProxy):
105+
def prepare(self, *args, **kwargs):
106+
func = self.__wrapped__.prepare
107+
if not args:
108+
return func(*args, **kwargs)
109+
conn = args[0]
110+
111+
# prepare performs a c-level check of the object type so
112+
# we must be sure to pass in the actual db connection
113+
if isinstance(conn, wrapt.ObjectProxy):
114+
conn = conn.__wrapped__
115+
116+
return func(conn, *args[1:], **kwargs)
117+
118+
97119
# extension hooks
98120
_psycopg2_extensions = [
99121
(psycopg2.extensions.register_type,
@@ -105,4 +127,7 @@ def _unroll_args(obj, scope=None):
105127
(psycopg2._json.register_type,
106128
psycopg2._json, 'register_type',
107129
_extensions_register_type),
130+
(psycopg2.extensions.adapt,
131+
psycopg2.extensions, 'adapt',
132+
_extensions_adapt),
108133
]

tests/contrib/psycopg/test_psycopg.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
# 3p
55
import psycopg2
6+
from psycopg2 import _psycopg
7+
from psycopg2 import extensions
68
from psycopg2 import extras
79
from nose.tools import eq_
810

@@ -120,6 +122,23 @@ def test_manual_wrap_extension_types(self):
120122
# TypeError: argument 2 must be a connection, cursor or None
121123
extras.register_default_json(conn)
122124

125+
126+
def test_manual_wrap_extension_adapt(self):
127+
conn, _ = self._get_conn_and_tracer()
128+
# NOTE: this will crash if it doesn't work.
129+
# items = _ext.adapt([1, 2, 3])
130+
# items.prepare(conn)
131+
# TypeError: argument 2 must be a connection, cursor or None
132+
items = extensions.adapt([1, 2, 3])
133+
items.prepare(conn)
134+
135+
# NOTE: this will crash if it doesn't work.
136+
# binary = _ext.adapt(b'12345)
137+
# binary.prepare(conn)
138+
# TypeError: argument 2 must be a connection, cursor or None
139+
binary = extensions.adapt(b'12345')
140+
binary.prepare(conn)
141+
123142
def test_connect_factory(self):
124143
tracer = get_dummy_tracer()
125144

0 commit comments

Comments
 (0)