Skip to content

Commit bde45c8

Browse files
authored
Use to_python to make sure keys are correct type (#22)
* pass values to prep functions * test that values are prepped * formatting * always return identity func if no 'to_python' method
1 parent 96d8001 commit bde45c8

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

bulk_sync/__init__.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from collections import OrderedDict
1+
from collections import OrderedDict, defaultdict
22
import logging
3+
import functools
34

45
from django.db import transaction, router
56
from django.core.exceptions import FieldDoesNotExist
@@ -50,7 +51,8 @@ def bulk_sync(
5051

5152
if db_class is None:
5253
raise RuntimeError(
53-
"Unable to identify model to sync. Need to provide at least one object in `new_models`, provide `db_class`, or set `new_models` with a queryset like `db_class.objects.none()`."
54+
"Unable to identify model to sync. Need to provide at least one object in `new_models`, provide "
55+
"`db_class`, or set `new_models` with a queryset like `db_class.objects.none()`."
5456
)
5557

5658
if fields is None:
@@ -79,15 +81,25 @@ def bulk_sync(
7981
objs = objs.filter(filters)
8082
objs = objs.only("pk", *key_fields).select_for_update()
8183

82-
def get_key(obj):
83-
return tuple(getattr(obj, k) for k in key_fields)
84+
prep_functions = defaultdict(lambda: lambda x: x)
85+
prep_functions.update({
86+
field.name: functools.partial(field.to_python)
87+
for field in (db_class._meta.get_field(k) for k in key_fields)
88+
if hasattr(field, 'to_python')
89+
})
90+
91+
def get_key(obj, prep_values=False):
92+
return tuple(
93+
prep_functions[k](getattr(obj, k)) if prep_values else getattr(obj, k)
94+
for k in key_fields
95+
)
8496

8597
obj_dict = {get_key(obj): obj for obj in objs}
8698

8799
new_objs = []
88100
existing_objs = []
89101
for new_obj in new_models:
90-
old_obj = obj_dict.pop(get_key(new_obj), None)
102+
old_obj = obj_dict.pop(get_key(new_obj, prep_values=True), None)
91103
if old_obj is None:
92104
# This is a new object, so create it.
93105
new_objs.append(new_obj)

tests/tests.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,28 @@ def test_empty_new_models_class_detection_works(self):
219219
ret = bulk_sync(new_models=[], filters=None, key_fields=("name",), db_class=Employee)
220220
ret = bulk_sync(new_models=Employee.objects.none(), filters=None, key_fields=("name",))
221221

222+
def test_new_objs_with_unprepped_field_values_are_processed_correctly(self):
223+
c1 = Company.objects.create(name="Foo Products, Ltd.")
224+
c2 = Company.objects.create(name="Bar Microcontrollers, Inc.")
225+
e1 = Employee.objects.create(name="Scott", age=40, company=c1)
226+
227+
new_objs = [Employee(name="Scott", age="40", company=c2)]
228+
229+
ret = bulk_sync(
230+
new_models=new_objs,
231+
filters=None,
232+
key_fields=("name", "age"),
233+
)
234+
235+
# we should should update e1's company to c2
236+
self.assertEqual(1, ret["stats"]["updated"])
237+
self.assertEqual(c2, Employee.objects.get(pk=e1.pk).company)
238+
239+
# we should not create or delete anything
240+
self.assertEqual(0, ret["stats"]["created"])
241+
self.assertEqual(0, ret["stats"]["deleted"])
242+
self.assertEqual(1, Employee.objects.count())
243+
222244

223245
class BulkCompareTests(TestCase):
224246
""" Test `bulk_compare` method """

0 commit comments

Comments
 (0)