Skip to content

Commit 55da032

Browse files
committed
Detect cases of n+1 caused by Django deferred fields
1 parent 7045b8a commit 55da032

File tree

6 files changed

+73
-3
lines changed

6 files changed

+73
-3
lines changed

nplusone/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '1.0.0'
1+
__version__ = '1.1.0a1'

nplusone/ext/django/patch.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
create_reverse_many_to_one_manager,
1818
create_forward_many_to_many_manager,
1919
)
20+
from django.db.models.query_utils import DeferredAttribute
2021

2122

2223
def get_worker():
@@ -345,3 +346,37 @@ def getitem_queryset(self, index):
345346
)
346347
return original_getitem_queryset(self, index)
347348
query.QuerySet.__getitem__ = getitem_queryset
349+
350+
351+
def parse_refresh_from_db(instance, args, kwargs, context):
352+
# Instance passed via partial
353+
fields = kwargs.get('fields') or args[0]
354+
model = type(instance)
355+
return model, to_key(instance), fields[0]
356+
357+
358+
original_deferred_attribute_get = DeferredAttribute.__get__
359+
def deferred_attribute_get(self, instance, cls=None):
360+
"""
361+
DeferredAttribute.__get__() is called when a deferred
362+
field is accessed. It may or may not trigger a db query;
363+
if it does, it's going to be a refresh_from_db() call
364+
So we'll emit a `touch` from there
365+
"""
366+
if instance is None:
367+
return self
368+
# Refresh-from-db, intenally, calls QuerySet.get() on our
369+
# instance. Normally, this would make our instance immune
370+
# to further notifications. We don't want that to happen,
371+
# so we disable the ignore_load signal within refresh_from_db
372+
orig_refresh_from_db = instance.refresh_from_db
373+
def refresh_from_db(*args, **kwargs):
374+
with signals.ignore(signals.ignore_load):
375+
return orig_refresh_from_db(*args, **kwargs)
376+
instance.refresh_from_db = signals.signalify(
377+
signals.lazy_load,
378+
refresh_from_db,
379+
parser=functools.partial(parse_refresh_from_db, instance),
380+
)
381+
return original_deferred_attribute_get(self, instance, cls)
382+
DeferredAttribute.__get__ = deferred_attribute_get

tests/testapp/testapp/models.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,7 @@ class Address(models.Model):
2626

2727
class Hobby(models.Model):
2828
pass
29+
30+
31+
class Medicine(models.Model):
32+
name = models.CharField(max_length=20)

tests/testapp/testapp/tests.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from django.conf import settings
99
from django.http.request import HttpRequest
1010
from django.http.response import HttpResponse
11+
from django.test import override_settings
1112

1213
from nplusone.ext.django.patch import setup_state
1314
from nplusone.ext.django.middleware import NPlusOneMiddleware
@@ -32,6 +33,7 @@ def objects():
3233
address = models.Address.objects.create(user=user)
3334
hobby = models.Hobby.objects.create()
3435
user.hobbies.add(hobby)
36+
medicine = models.Medicine.objects.create(name="Allergix")
3537
return locals()
3638

3739

@@ -133,6 +135,23 @@ def test_many_to_many_reverse_prefetch(self, objects, calls):
133135
assert len(calls) == 0
134136

135137

138+
@pytest.mark.django_db
139+
class TestDeferred:
140+
141+
def test_deferred(self, objects, calls):
142+
medicine = list(models.Medicine.objects.defer('name'))[0]
143+
medicine.name
144+
assert len(calls) == 1
145+
call = calls[0]
146+
assert call.objects == (models.Medicine, 'Medicine:1', 'name')
147+
assert 'medicine.name' in ''.join(call.frame[4])
148+
149+
def test_non_deferred(self, objects, calls):
150+
medicine = list(models.Medicine.objects.all())[0]
151+
medicine.name
152+
assert len(calls) == 0
153+
154+
136155
@pytest.fixture
137156
def logger(monkeypatch):
138157
mock_logger = mock.Mock()
@@ -272,16 +291,22 @@ def test_select_nested_unused(self, objects, client, logger):
272291
assert any('Pet.user' in call[1] for call in calls)
273292
assert any('User.occupation' in call[1] for call in calls)
274293

294+
@override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.User'}])
275295
def test_many_to_many_whitelist(self, objects, client, logger):
276-
settings.NPLUSONE_WHITELIST = [{'model': 'testapp.User'}]
277296
client.get('/many_to_many/')
278297
assert not logger.log.called
279298

299+
@override_settings(NPLUSONE_WHITELIST=[{'model': 'testapp.*'}])
280300
def test_many_to_many_whitelist_wildcard(self, objects, client, logger):
281-
settings.NPLUSONE_WHITELIST = [{'model': 'testapp.*'}]
282301
client.get('/many_to_many/')
283302
assert not logger.log.called
284303

304+
def test_deferred(self, objects, client, logger):
305+
client.get('/deferred/')
306+
assert len(logger.log.call_args_list) == 1
307+
args = logger.log.call_args[0]
308+
assert 'Medicine.name' in args[1]
309+
285310

286311
@pytest.mark.django_db
287312
def test_values(objects, lazy_listener):

tests/testapp/testapp/urls.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@
2525
url(r'^prefetch_nested_unused/$', views.prefetch_nested_unused),
2626
url(r'^select_nested/$', views.select_nested),
2727
url(r'^select_nested_unused/$', views.select_nested_unused),
28+
url(r'^deferred/$', views.deferred),
2829
]

tests/testapp/testapp/views.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,8 @@ def select_nested(request):
127127
def select_nested_unused(request):
128128
pets = list(models.Pet.objects.all().select_related('user__occupation'))
129129
return HttpResponse(pets[0])
130+
131+
132+
def deferred(request):
133+
meds = list(models.Medicine.objects.defer('name'))
134+
return HttpResponse("; ".join(med.name for med in meds))

0 commit comments

Comments
 (0)