Skip to content

Commit 276828f

Browse files
committed
[Performance] Optimized UserEvent.decay to resolve N+1 queries. Closes #3460
1 parent 5619186 commit 276828f

File tree

2 files changed

+35
-12
lines changed

2 files changed

+35
-12
lines changed

api_app/user_events_manager/queryset.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,7 @@
1212
class UserEventQuerySet(QuerySet):
1313
def decay(self):
1414
from collections import defaultdict
15-
1615
from django.db import transaction
17-
1816
objects = (
1917
self.exclude(decay_progression=DecayProgressionEnum.FIXED.value)
2018
.exclude(next_decay__isnull=True)
@@ -29,22 +27,17 @@ def decay(self):
2927
objects = objects.select_related("data_model")
3028
else:
3129
objects = objects.prefetch_related("data_model")
32-
3330
# Load into memory so we can mutate fields and bulk-write back.
3431
events = list(objects)
3532
if not events:
3633
return 0
37-
3834
# Group by concrete class since bulk_update works per-table.
3935
data_models_by_class = defaultdict(list)
40-
4136
for obj in events:
4237
obj.decay_times += 1
4338
data_model = obj.data_model
44-
4539
if data_model is not None:
4640
data_model.reliability -= 1
47-
4841
if data_model is None or data_model.reliability == 0:
4942
obj.next_decay = None
5043
else:
@@ -54,17 +47,15 @@ def decay(self):
5447
obj.next_decay += datetime.timedelta(
5548
days=obj.decay_timedelta_days ** (obj.decay_times + 1)
5649
)
57-
5850
if data_model is not None:
5951
data_models_by_class[data_model.__class__].append(data_model)
60-
6152
# Bulk-write instead of per-object .save() to avoid O(N) queries.
6253
# Atomic so partial failures don't leave inconsistent state.
6354
with transaction.atomic():
6455
for model_class, models_list in data_models_by_class.items():
65-
model_class.objects.bulk_update(models_list, ["reliability"])
66-
self.model.objects.bulk_update(events, ["decay_times", "next_decay"])
67-
56+
unique_models = {m.pk: m for m in models_list}.values()
57+
model_class.objects.bulk_update(unique_models, ["reliability"], batch_size=1000)
58+
self.model.objects.bulk_update(events, ["decay_times", "next_decay"], batch_size=1000)
6859
return len(events)
6960

7061
def visible_for_user(self, user):

tests/api_app/user_events_manager/test_queryset.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,38 @@ def test_decay_reliability_reaches_zero(self):
151151
ua.delete()
152152
an.delete()
153153

154+
def test_decay_performance(self):
    """Check that ``decay()`` handles several events with a bounded query count.

    Creates ``N`` user events, each attached to its own analyzable and with
    ``next_decay`` set one day in the past, then runs ``decay()`` under
    ``CaptureQueriesContext``. The test asserts that ``decay()`` reports
    ``N`` processed events and that at most 10 queries were executed.
    """
    from django.db import connection
    from django.test.utils import CaptureQueriesContext

    N = 5
    try:
        for i in range(N):
            an_i = Analyzable.objects.create(
                name=f"test_perf{i}.com",
                classification=Classification.DOMAIN,
            )
            ue_ser = UserAnalyzableEventSerializer(
                data={
                    "analyzable": {"name": an_i.name},
                    "decay_progression": 0,
                    "decay_timedelta_days": 1,
                    "data_model_content": {
                        "evaluation": "malicious",
                        "reliability": 8,
                    },
                },
                context={"request": MockUpRequest(self.user)},
            )
            # Surface validation problems with the serializer's own error
            # details instead of an opaque failure inside save().
            self.assertTrue(ue_ser.is_valid(), ue_ser.errors)
            ua = ue_ser.save()
            # Backdate next_decay; presumably decay() only selects events
            # that are already due (the filter is not visible here).
            ua.next_decay = now() - datetime.timedelta(days=1)
            ua.save()

        with CaptureQueriesContext(connection) as queries:
            number = UserAnalyzableEvent.objects.decay()

        self.assertEqual(number, N)
        self.assertLessEqual(len(queries), 10)
    finally:
        # Always remove the fixtures, even when an assertion above fails,
        # so leftovers cannot leak into other tests.
        Analyzable.objects.filter(name__startswith="test_perf").delete()
185+
154186
def test_decay_multiple_events(self):
155187
analyzables = []
156188
events = []

0 commit comments

Comments
 (0)