diff --git a/django_bulk_update/helper.py b/django_bulk_update/helper.py index ef45d90..1c1fb22 100644 --- a/django_bulk_update/helper.py +++ b/django_bulk_update/helper.py @@ -7,7 +7,7 @@ from collections import defaultdict from django.db import connections, models -from django.db.models.query import QuerySet +from django.db.models.signals import post_save, pre_save from django.db.models.sql import UpdateQuery @@ -219,6 +219,27 @@ def bulk_update(objs, meta=None, update_fields=None, exclude_fields=None, lenpks += n_pks + signal_kwargs = { + 'raw': False, + 'using': using, + 'update_fields': update_fields, + } + + for obj in objs_batch: + pre_save.send( + sender=obj.__class__, + instance=obj, + **signal_kwargs + ) + connection.cursor().execute(sql, parameters) + for obj in objs_batch: + post_save.send( + sender=obj.__class__, + instance=obj, + created=False, + **signal_kwargs + ) + return lenpks diff --git a/tests/requirements/requirements_base.txt b/tests/requirements/requirements_base.txt index 376c17c..f90e601 100644 --- a/tests/requirements/requirements_base.txt +++ b/tests/requirements/requirements_base.txt @@ -1,5 +1,5 @@ dj_database_url==0.4.2 jsonfield==2.0.1 pillow==4.1.1 -psycopg2==2.7.3 +psycopg2-binary==2.7.5 six==1.10.0 diff --git a/tests/tests.py b/tests/tests.py index a3964fb..e87ebf6 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -1,5 +1,4 @@ -import random - +import collections from datetime import date, time, timedelta from decimal import Decimal from unittest import skipUnless @@ -7,6 +6,8 @@ from django.conf import settings from django.db.models import F, Func, Value from django.db.models.functions import Concat +from django.db.models.signals import pre_save, post_save +from django.dispatch import receiver from django.test import TestCase from django.utils import timezone @@ -888,3 +889,32 @@ def test_validate_fields(self): exclude_fields = ['jobs'] self.assertRaises(TypeError, helper.get_fields, update_fields, exclude_fields, meta) + + def test_signals(self): + signal_calls = collections.defaultdict(int) + + @receiver(pre_save) + def pre_save_handler(sender, instance, raw, using, update_fields, **kwargs): + self.assertIs(sender, Person) + self.assertIsInstance(instance, Person) + self.assertFalse(raw) + self.assertEquals(using, 'default') + self.assertListEqual(update_fields, ['age']) + signal_calls['pre_save'] += 1 + + @receiver(post_save) + def post_save_handler(sender, instance, created, raw, using, update_fields, **kwargs): + self.assertIs(sender, Person) + self.assertIsInstance(instance, Person) + self.assertFalse(created) + self.assertFalse(raw) + self.assertEquals(using, 'default') + self.assertListEqual(update_fields, ['age']) + signal_calls['post_save'] += 1 + + people = Person.objects.all() + Person.objects.bulk_update(people, update_fields=['age']) + + self.assertGreater(len(people), 0) + self.assertEquals(signal_calls['pre_save'], len(people)) + self.assertEquals(signal_calls['post_save'], len(people))