Skip to content

Commit 3a14cc6

Browse files
committed
Detect cases of n+1 caused by Django deferred fields
1 parent 7045b8a commit 3a14cc6

File tree

6 files changed

+106
-3
lines changed

6 files changed

+106
-3
lines changed

nplusone/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.0.0'
1+
__version__ = '1.1.0a1'

nplusone/ext/django/patch.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717
create_reverse_many_to_one_manager,
1818
create_forward_many_to_many_manager,
1919
)
20+
from django.db.models.query_utils import DeferredAttribute
21+
22+
NPLUSONE_WRAPPED = 'nplusone_wrapped'
2023

2124

2225
def get_worker():
@@ -345,3 +348,53 @@ def getitem_queryset(self, index):
345348
)
346349
return original_getitem_queryset(self, index)
347350
query.QuerySet.__getitem__ = getitem_queryset
351+
352+
353+
def parse_refresh_from_db(instance, fields, args, kwargs, context):
354+
# Instance & fields passed via partial
355+
model = type(instance)
356+
return model, to_key(instance), fields[0]
357+
358+
359+
original_deferred_attribute_get = DeferredAttribute.__get__
360+
def deferred_attribute_get(self, instance, cls=None):
361+
"""
362+
DeferredAttribute.__get__() is called when a deferred
363+
field is accessed. It may or may not trigger a db query;
364+
if it does, it's going to be a refresh_from_db() call
365+
So we'll emit a `touch` from there
366+
"""
367+
if instance is None:
368+
return self
369+
# Refresh-from-db, intenally, calls QuerySet.get() on our
370+
# instance. Normally, this would make our instance immune
371+
# to further notifications. We don't want that to happen,
372+
# so we disable the ignore_load signal within refresh_from_db
373+
ensure_wrapped_refresh_from_db(instance)
374+
return original_deferred_attribute_get(self, instance, cls)
375+
DeferredAttribute.__get__ = deferred_attribute_get
376+
377+
378+
def ensure_wrapped_refresh_from_db(instance):
379+
orig_refresh_from_db = instance.refresh_from_db
380+
if getattr(orig_refresh_from_db, NPLUSONE_WRAPPED, False):
381+
return
382+
@functools.wraps(orig_refresh_from_db)
383+
def refresh_from_db(fields=None, *args, **kwargs):
384+
with signals.ignore(signals.ignore_load):
385+
ret = orig_refresh_from_db(fields=fields, **kwargs)
386+
# and now, if the refresh_from_db was called for specific fields,
387+
# then it's a lazy load
388+
if fields:
389+
parser = functools.partial(parse_refresh_from_db, instance, fields)
390+
signals.lazy_load.send(
391+
get_worker(),
392+
args=args,
393+
kwargs=kwargs,
394+
ret=ret,
395+
context={},
396+
parser=parser,
397+
)
398+
return ret
399+
setattr(refresh_from_db, NPLUSONE_WRAPPED, True)
400+
instance.refresh_from_db = refresh_from_db

tests/testapp/testapp/models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,8 @@ class Address(models.Model):
2626

2727
class Hobby(models.Model):
2828
pass
29+
30+
31+
class Medicine(models.Model):
32+
name = models.CharField(max_length=20)
33+
prescription = models.BooleanField(default=False)

tests/testapp/testapp/tests.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from django.conf import settings
99
from django.http.request import HttpRequest
1010
from django.http.response import HttpResponse
11+
from django.test import override_settings
1112

1213
from nplusone.ext.django.patch import setup_state
1314
from nplusone.ext.django.middleware import NPlusOneMiddleware
@@ -32,6 +33,7 @@ def objects():
3233
address = models.Address.objects.create(user=user)
3334
hobby = models.Hobby.objects.create()
3435
user.hobbies.add(hobby)
36+
medicine = models.Medicine.objects.create(name="Allergix")
3537
return locals()
3638

3739

@@ -133,6 +135,23 @@ def test_many_to_many_reverse_prefetch(self, objects, calls):
133135
assert len(calls) == 0
134136

135137

138+
@pytest.mark.django_db
139+
class TestDeferred:
140+
141+
def test_deferred(self, objects, calls):
142+
medicine = list(models.Medicine.objects.defer('name'))[0]
143+
medicine.name
144+
assert len(calls) == 1
145+
call = calls[0]
146+
assert call.objects == (models.Medicine, 'Medicine:1', 'name')
147+
assert 'medicine.name' in ''.join(call.frame[4])
148+
149+
def test_non_deferred(self, objects, calls):
150+
medicine = list(models.Medicine.objects.all())[0]
151+
medicine.name
152+
assert len(calls) == 0
153+
154+
136155
@pytest.fixture
137156
def logger(monkeypatch):
138157
mock_logger = mock.Mock()
@@ -272,16 +291,29 @@ def test_select_nested_unused(self, objects, client, logger):
272291
assert any('Pet.user' in call[1] for call in calls)
273292
assert any('User.occupation' in call[1] for call in calls)
274293

294+
@override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.User'}])
275295
def test_many_to_many_whitelist(self, objects, client, logger):
276-
settings.NPLUSONE_WHITELIST = [{'model': 'testapp.User'}]
277296
client.get('/many_to_many/')
278297
assert not logger.log.called
279298

299+
@override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.*'}])
280300
def test_many_to_many_whitelist_wildcard(self, objects, client, logger):
281-
settings.NPLUSONE_WHITELIST = [{'model': 'testapp.*'}]
282301
client.get('/many_to_many/')
283302
assert not logger.log.called
284303

304+
def test_deferred(self, objects, client, logger):
305+
client.get('/deferred/')
306+
assert len(logger.log.call_args_list) == 1
307+
args = logger.log.call_args[0]
308+
assert 'Medicine.name' in args[1]
309+
310+
def test_double_deferred(self, objects, client, logger):
311+
client.get('/double_deferred/')
312+
assert len(logger.log.call_args_list) == 2
313+
messages = sorted({args[0][1] for args in logger.log.call_args_list})
314+
assert 'Medicine.name' in messages[0]
315+
assert 'Medicine.prescription' in messages[1]
316+
285317

286318
@pytest.mark.django_db
287319
def test_values(objects, lazy_listener):

tests/testapp/testapp/urls.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,6 @@
2525
url(r'^prefetch_nested_unused/$', views.prefetch_nested_unused),
2626
url(r'^select_nested/$', views.select_nested),
2727
url(r'^select_nested_unused/$', views.select_nested_unused),
28+
url(r'^deferred/$', views.deferred),
29+
url(r'^double_deferred/$', views.double_deferred),
2830
]

tests/testapp/testapp/views.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,14 @@ def select_nested(request):
127127
def select_nested_unused(request):
128128
pets = list(models.Pet.objects.all().select_related('user__occupation'))
129129
return HttpResponse(pets[0])
130+
131+
132+
def deferred(request):
133+
meds = list(models.Medicine.objects.defer('name'))
134+
return HttpResponse("; ".join(med.name for med in meds))
135+
136+
def double_deferred(request):
137+
meds = list(models.Medicine.objects.only('id'))
138+
return HttpResponse("; ".join(
139+
med.name + (' *' if med.prescription else '') for med in meds
140+
))

0 commit comments

Comments
 (0)