|
1 |
| -from collections import OrderedDict |
| 1 | +from collections import OrderedDict, defaultdict |
2 | 2 | import logging
|
| 3 | +import functools |
3 | 4 |
|
4 | 5 | from django.db import transaction, router
|
5 | 6 | from django.core.exceptions import FieldDoesNotExist
|
@@ -50,7 +51,8 @@ def bulk_sync(
|
50 | 51 |
|
51 | 52 | if db_class is None:
|
52 | 53 | raise RuntimeError(
|
53 |
| - "Unable to identify model to sync. Need to provide at least one object in `new_models`, provide `db_class`, or set `new_models` with a queryset like `db_class.objects.none()`." |
| 54 | + "Unable to identify model to sync. Need to provide at least one object in `new_models`, provide " |
| 55 | + "`db_class`, or set `new_models` with a queryset like `db_class.objects.none()`." |
54 | 56 | )
|
55 | 57 |
|
56 | 58 | if fields is None:
|
@@ -79,15 +81,25 @@ def bulk_sync(
|
79 | 81 | objs = objs.filter(filters)
|
80 | 82 | objs = objs.only("pk", *key_fields).select_for_update()
|
81 | 83 |
|
82 |
| - def get_key(obj): |
83 |
| - return tuple(getattr(obj, k) for k in key_fields) |
| 84 | + prep_functions = defaultdict(lambda: lambda x: x) |
| 85 | + prep_functions.update({ |
| 86 | + field.name: functools.partial(field.to_python) |
| 87 | + for field in (db_class._meta.get_field(k) for k in key_fields) |
| 88 | + if hasattr(field, 'to_python') |
| 89 | + }) |
| 90 | + |
| 91 | + def get_key(obj, prep_values=False): |
| 92 | + return tuple( |
| 93 | + prep_functions[k](getattr(obj, k)) if prep_values else getattr(obj, k) |
| 94 | + for k in key_fields |
| 95 | + ) |
84 | 96 |
|
85 | 97 | obj_dict = {get_key(obj): obj for obj in objs}
|
86 | 98 |
|
87 | 99 | new_objs = []
|
88 | 100 | existing_objs = []
|
89 | 101 | for new_obj in new_models:
|
90 |
| - old_obj = obj_dict.pop(get_key(new_obj), None) |
| 102 | + old_obj = obj_dict.pop(get_key(new_obj, prep_values=True), None) |
91 | 103 | if old_obj is None:
|
92 | 104 | # This is a new object, so create it.
|
93 | 105 | new_objs.append(new_obj)
|
|
0 commit comments