Skip to content

Commit 94d699a

Browse files
TimPansinolrafeei
andauthored
Fix crashes in aioredis transactions (#633)
* Fix aioredis wrappers for transactions * Refactor aioredis test setup * Expand aioredis tox testing * Add no harm tests for aioredis transactions Co-authored-by: Lalleh Rafeei <[email protected]>
1 parent 8f84692 commit 94d699a

File tree

12 files changed

+309
-180
lines changed

12 files changed

+309
-180
lines changed

newrelic/hooks/datastore_aioredis.py

Lines changed: 88 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,25 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from newrelic.api.datastore_trace import DatastoreTrace
15+
from newrelic.api.datastore_trace import DatastoreTrace, DatastoreTraceWrapper
1616
from newrelic.api.time_trace import current_trace
1717
from newrelic.api.transaction import current_transaction
18-
from newrelic.common.object_wrapper import wrap_function_wrapper
18+
from newrelic.common.object_wrapper import wrap_function_wrapper, function_wrapper, FunctionWrapper
1919
from newrelic.hooks.datastore_redis import (
2020
_redis_client_methods,
2121
_redis_multipart_commands,
2222
_redis_operation_re,
2323
)
2424

25+
from newrelic.common.async_wrapper import async_wrapper
26+
27+
import aioredis
28+
29+
try:
30+
AIOREDIS_VERSION = tuple(int(x) for x in getattr(aioredis, "__version__").split("."))
31+
except Exception:
32+
AIOREDIS_VERSION = (0, 0, 0)
33+
2534

2635
def _conn_attrs_to_dict(connection):
2736
host = getattr(connection, "host", None)
@@ -45,13 +54,36 @@ def _instance_info(kwargs):
4554

4655

4756
def _wrap_AioRedis_method_wrapper(module, instance_class_name, operation):
48-
async def _nr_wrapper_AioRedis_method_(wrapped, instance, args, kwargs):
57+
58+
@function_wrapper
59+
async def _nr_wrapper_AioRedis_async_method_(wrapped, instance, args, kwargs):
4960
transaction = current_transaction()
5061
if transaction is None:
5162
return await wrapped(*args, **kwargs)
5263

5364
with DatastoreTrace(product="Redis", target=None, operation=operation):
5465
return await wrapped(*args, **kwargs)
66+
67+
def _nr_wrapper_AioRedis_method_(wrapped, instance, args, kwargs):
68+
# Check for transaction and return early if found.
69+
# Method will return synchronously without executing,
70+
# it will be added to the command stack and run later.
71+
if AIOREDIS_VERSION < (2,):
72+
# AioRedis v1 uses a RedisBuffer instead of a real connection for queueing up pipeline commands
73+
from aioredis.commands.transaction import _RedisBuffer
74+
if isinstance(instance._pool_or_conn, _RedisBuffer):
75+
# Method will return synchronously without executing,
76+
# it will be added to the command stack and run later.
77+
return wrapped(*args, **kwargs)
78+
else:
79+
# AioRedis v2 uses a Pipeline object for a client and internally queues up pipeline commands
80+
from aioredis.client import Pipeline
81+
if isinstance(instance, Pipeline):
82+
return wrapped(*args, **kwargs)
83+
84+
# Method should be run when awaited, therefore we wrap in an async wrapper.
85+
return _nr_wrapper_AioRedis_async_method_(wrapped)(*args, **kwargs)
86+
5587

5688
name = "%s.%s" % (instance_class_name, operation)
5789
wrap_function_wrapper(module, name, _nr_wrapper_AioRedis_method_)
@@ -108,6 +140,58 @@ async def wrap_Connection_send_command(wrapped, instance, args, kwargs):
108140
return await wrapped(*args, **kwargs)
109141

110142

143+
def wrap_RedisConnection_execute(wrapped, instance, args, kwargs):
144+
# RedisConnection in aioredis v1 returns a future instead of using coroutines
145+
transaction = current_transaction()
146+
if not transaction:
147+
return wrapped(*args, **kwargs)
148+
149+
host, port_path_or_id, db = (None, None, None)
150+
151+
try:
152+
dt = transaction.settings.datastore_tracer
153+
if dt.instance_reporting.enabled or dt.database_name_reporting.enabled:
154+
conn_kwargs = _conn_attrs_to_dict(instance)
155+
host, port_path_or_id, db = _instance_info(conn_kwargs)
156+
except Exception:
157+
pass
158+
159+
# Older Redis clients would when sending multi part commands pass
160+
# them in as separate arguments to send_command(). Need to therefore
161+
# detect those and grab the next argument from the set of arguments.
162+
163+
operation = args[0].strip().lower()
164+
165+
# If it's not a multi part command, there's no need to trace it, so
166+
# we can return early.
167+
168+
if operation.split()[0] not in _redis_multipart_commands: # Set the datastore info on the DatastoreTrace containing this function call.
169+
trace = current_trace()
170+
171+
# Find DatastoreTrace no matter how many other traces are inbetween
172+
while trace is not None and not isinstance(trace, DatastoreTrace):
173+
trace = getattr(trace, "parent", None)
174+
175+
if trace is not None:
176+
trace.host = host
177+
trace.port_path_or_id = port_path_or_id
178+
trace.database_name = db
179+
180+
return wrapped(*args, **kwargs)
181+
182+
# Convert multi args to single arg string
183+
184+
if operation in _redis_multipart_commands and len(args) > 1:
185+
operation = "%s %s" % (operation, args[1].strip().lower())
186+
187+
operation = _redis_operation_re.sub("_", operation)
188+
189+
with DatastoreTrace(
190+
product="Redis", target=None, operation=operation, host=host, port_path_or_id=port_path_or_id, database_name=db
191+
):
192+
return wrapped(*args, **kwargs)
193+
194+
111195
def instrument_aioredis_client(module):
112196
# StrictRedis is just an alias of Redis, no need to wrap it as well.
113197
if hasattr(module, "Redis"):
@@ -124,4 +208,4 @@ def instrument_aioredis_connection(module):
124208

125209
if hasattr(module, "RedisConnection"):
126210
if hasattr(module.RedisConnection, "execute"):
127-
wrap_function_wrapper(module, "RedisConnection.execute", wrap_Connection_send_command)
211+
wrap_function_wrapper(module, "RedisConnection.execute", wrap_RedisConnection_execute)

tests/datastore_aioredis/conftest.py

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,22 @@
1313
# limitations under the License.
1414

1515
import aioredis
16-
import asyncio
1716
import pytest
1817

18+
from testing_support.db_settings import redis_settings
19+
20+
from testing_support.fixture.event_loop import event_loop as loop
1921
from testing_support.fixtures import ( # noqa: F401
2022
code_coverage_fixture,
2123
collector_agent_registration_fixture,
2224
collector_available_fixture,
2325
)
2426

2527
AIOREDIS_VERSION = tuple(int(x) for x in aioredis.__version__.split(".")[:2])
28+
SKIPIF_AIOREDIS_V1 = pytest.mark.skipif(AIOREDIS_VERSION < (2,), reason="Unsupported aioredis version.")
29+
SKIPIF_AIOREDIS_V2 = pytest.mark.skipif(AIOREDIS_VERSION >= (2,), reason="Unsupported aioredis version.")
30+
DB_SETTINGS = redis_settings()[0]
31+
2632

2733
_coverage_source = [
2834
"newrelic.hooks.datastore_aioredis",
@@ -45,10 +51,19 @@
4551
)
4652

4753

48-
event_loop = asyncio.get_event_loop()
49-
asyncio.set_event_loop(event_loop)
50-
51-
52-
@pytest.fixture()
53-
def loop():
54-
yield event_loop
54+
@pytest.fixture(params=("Redis", "StrictRedis"))
55+
def client(request, loop):
56+
if AIOREDIS_VERSION >= (2, 0):
57+
if request.param == "Redis":
58+
return aioredis.Redis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0)
59+
elif request.param == "StrictRedis":
60+
return aioredis.StrictRedis(host=DB_SETTINGS["host"], port=DB_SETTINGS["port"], db=0)
61+
else:
62+
raise NotImplementedError()
63+
else:
64+
if request.param == "Redis":
65+
return loop.run_until_complete(aioredis.create_redis("redis://%s:%d" % (DB_SETTINGS["host"], DB_SETTINGS["port"]), db=0))
66+
elif request.param == "StrictRedis":
67+
pytest.skip("StrictRedis not implemented.")
68+
else:
69+
raise NotImplementedError()

tests/datastore_aioredis/test_custom_conn_pool.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,6 @@
1717
will not result in an error.
1818
"""
1919

20-
import asyncio
21-
import pytest
22-
import aioredis
23-
24-
from conftest import event_loop, loop, AIOREDIS_VERSION
25-
2620
from newrelic.api.background_task import background_task
2721

2822
# from testing_support.fixture.event_loop import event_loop as loop
@@ -43,7 +37,7 @@ async def get_connection(self, name=None, *keys, **options):
4337
return self.connection
4438

4539
async def release(self, connection):
46-
self.connection.disconnect()
40+
await self.connection.disconnect()
4741

4842
async def execute(self, *args, **kwargs):
4943
return await self.connection.execute(*args, **kwargs)
@@ -105,18 +99,6 @@ async def exercise_redis(client):
10599
await client.execute("CLIENT", "LIST")
106100

107101

108-
if AIOREDIS_VERSION >= (2, 0):
109-
clients = [
110-
aioredis.Redis(host=DB_SETTINGS["host"], port=_port, db=0),
111-
aioredis.StrictRedis(host=DB_SETTINGS["host"], port=_port, db=0),
112-
]
113-
else:
114-
clients = [
115-
event_loop.run_until_complete(aioredis.create_redis("redis://%s:%d" % (DB_SETTINGS["host"], _port), db=0)),
116-
]
117-
118-
119-
@pytest.mark.parametrize("client", clients)
120102
@override_application_settings(_enable_instance_settings)
121103
@validate_transaction_metrics(
122104
"test_custom_conn_pool:test_fake_conn_pool_enable_instance",
@@ -125,7 +107,7 @@ async def exercise_redis(client):
125107
background_task=True,
126108
)
127109
@background_task()
128-
def test_fake_conn_pool_enable_instance(client, loop):
110+
def test_fake_conn_pool_enable_instance(client, loop, monkeypatch):
129111
# Get a real connection
130112
conn = getattr(client, "_pool_or_conn", None)
131113
if conn is None:
@@ -135,14 +117,13 @@ def test_fake_conn_pool_enable_instance(client, loop):
135117
# have the `connection_kwargs` attribute.
136118

137119
fake_pool = FakeConnectionPool(conn)
138-
client.connection_pool = fake_pool
139-
client._pool_or_conn = fake_pool
120+
monkeypatch.setattr(client, "connection_pool", fake_pool, raising=False)
121+
monkeypatch.setattr(client, "_pool_or_conn", fake_pool, raising=False)
140122
assert not hasattr(client.connection_pool, "connection_kwargs")
141123

142124
loop.run_until_complete(exercise_redis(client))
143125

144126

145-
@pytest.mark.parametrize("client", clients)
146127
@override_application_settings(_disable_instance_settings)
147128
@validate_transaction_metrics(
148129
"test_custom_conn_pool:test_fake_conn_pool_disable_instance",
@@ -151,15 +132,18 @@ def test_fake_conn_pool_enable_instance(client, loop):
151132
background_task=True,
152133
)
153134
@background_task()
154-
def test_fake_conn_pool_disable_instance(client, loop):
135+
def test_fake_conn_pool_disable_instance(client, loop, monkeypatch):
155136
# Get a real connection
156-
conn = loop.run_until_complete(client.connection_pool.get_connection("GET"))
137+
conn = getattr(client, "_pool_or_conn", None)
138+
if conn is None:
139+
conn = loop.run_until_complete(client.connection_pool.get_connection("GET"))
157140

158141
# Replace the original connection pool with one that doesn't
159142
# have the `connection_kwargs` attribute.
160143

161144
fake_pool = FakeConnectionPool(conn)
162-
client.connection_pool = fake_pool
145+
monkeypatch.setattr(client, "connection_pool", fake_pool, raising=False)
146+
monkeypatch.setattr(client, "_pool_or_conn", fake_pool, raising=False)
163147
assert not hasattr(client.connection_pool, "connection_kwargs")
164148

165149
loop.run_until_complete(exercise_redis(client))

tests/datastore_aioredis/test_execute_command.py

Lines changed: 1 addition & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@
1313
# limitations under the License.
1414

1515
import pytest
16-
import aioredis
1716
from newrelic.api.background_task import background_task
1817

1918
from testing_support.fixtures import validate_transaction_metrics, override_application_settings
20-
from conftest import event_loop, loop, AIOREDIS_VERSION
19+
from conftest import AIOREDIS_VERSION
2120
from testing_support.db_settings import redis_settings
2221
from testing_support.util import instance_hostname
2322

@@ -70,19 +69,7 @@ async def exercise_redis_single_arg(client):
7069
await client.execute_command("CLIENT LIST")
7170

7271

73-
if AIOREDIS_VERSION >= (2, 0):
74-
clients = [
75-
aioredis.Redis(host=DB_SETTINGS["host"], port=_port, db=0),
76-
aioredis.StrictRedis(host=DB_SETTINGS["host"], port=_port, db=0),
77-
]
78-
else:
79-
clients = [
80-
event_loop.run_until_complete(aioredis.create_redis("redis://%s:%d" % (DB_SETTINGS["host"], _port), db=0)),
81-
]
82-
83-
8472
@SKIP_IF_AIOREDIS_V1
85-
@pytest.mark.parametrize("client", clients)
8673
@override_application_settings(_enable_instance_settings)
8774
@validate_transaction_metrics(
8875
"test_execute_command:test_redis_execute_command_as_one_arg_enable",
@@ -96,7 +83,6 @@ def test_redis_execute_command_as_one_arg_enable(client, loop):
9683

9784

9885
@SKIP_IF_AIOREDIS_V1
99-
@pytest.mark.parametrize("client", clients)
10086
@override_application_settings(_disable_instance_settings)
10187
@validate_transaction_metrics(
10288
"test_execute_command:test_redis_execute_command_as_one_arg_disable",
@@ -109,7 +95,6 @@ def test_redis_execute_command_as_one_arg_disable(client, loop):
10995
loop.run_until_complete(exercise_redis_single_arg(client))
11096

11197

112-
@pytest.mark.parametrize("client", clients)
11398
@override_application_settings(_enable_instance_settings)
11499
@validate_transaction_metrics(
115100
"test_execute_command:test_redis_execute_command_as_two_args_enable",
@@ -122,7 +107,6 @@ def test_redis_execute_command_as_two_args_enable(client, loop):
122107
loop.run_until_complete(exercise_redis_multi_args(client))
123108

124109

125-
@pytest.mark.parametrize("client", clients)
126110
@override_application_settings(_disable_instance_settings)
127111
@validate_transaction_metrics(
128112
"test_execute_command:test_redis_execute_command_as_two_args_disable",

tests/datastore_aioredis/test_get_and_set.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import pytest
16-
import aioredis
17-
1815
from newrelic.api.background_task import background_task
1916

20-
from conftest import event_loop, loop, AIOREDIS_VERSION
2117
from testing_support.fixtures import validate_transaction_metrics, override_application_settings
2218
from testing_support.db_settings import redis_settings
2319
from testing_support.util import instance_hostname
@@ -64,23 +60,11 @@
6460
_disable_rollup_metrics.append((_instance_metric_name, None))
6561

6662

67-
if AIOREDIS_VERSION >= (2, 0):
68-
clients = [
69-
aioredis.Redis(host=DB_SETTINGS["host"], port=_port, db=0),
70-
aioredis.StrictRedis(host=DB_SETTINGS["host"], port=_port, db=0),
71-
]
72-
else:
73-
clients = [
74-
event_loop.run_until_complete(aioredis.create_redis("redis://%s:%d" % (DB_SETTINGS["host"], _port), db=0)),
75-
]
76-
77-
7863
async def exercise_redis(client):
7964
await client.set("key", "value")
8065
await client.get("key")
8166

8267

83-
@pytest.mark.parametrize("client", clients)
8468
@override_application_settings(_enable_instance_settings)
8569
@validate_transaction_metrics(
8670
"test_get_and_set:test_redis_client_operation_enable_instance",
@@ -93,7 +77,6 @@ def test_redis_client_operation_enable_instance(client, loop):
9377
loop.run_until_complete(exercise_redis(client))
9478

9579

96-
@pytest.mark.parametrize("client", clients)
9780
@override_application_settings(_disable_instance_settings)
9881
@validate_transaction_metrics(
9982
"test_get_and_set:test_redis_client_operation_disable_instance",

0 commit comments

Comments
 (0)