Skip to content

Commit d39e296

Browse files
committed
Handle excluding many-to-many fields with javascript_exclude.
1 parent 4efce0d commit d39e296

File tree

3 files changed

+193
-13
lines changed

3 files changed

+193
-13
lines changed

django_unicorn/serializer.py

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,17 @@ def _get_model_dict(model: Model) -> dict:
9999

100100
for field in model._meta.get_fields():
101101
if field.is_relation and field.many_to_many:
102+
exclude_field_attributes = getattr(
103+
model, "__unicorn__exclude_field_attribute", []
104+
)
102105
related_name = field.name
103106

104107
if field.auto_created:
105108
related_name = field.related_name or f"{field.name}_set"
106109

110+
if related_name in exclude_field_attributes:
111+
continue
112+
107113
pks = []
108114

109115
try:
@@ -231,10 +237,6 @@ def _exclude_field_attributes(
231237
_exclude_field_attributes({"1": {"2": {"3": "4"}}}, ("1.2.3",)) == {"1": {"2": {}}}
232238
"""
233239

234-
assert exclude_field_attributes is None or is_non_string_sequence(
235-
exclude_field_attributes
236-
), "exclude_field_attributes type needs to be a sequence"
237-
238240
if exclude_field_attributes:
239241
for field in exclude_field_attributes:
240242
field_splits = field.split(".")
@@ -278,6 +280,40 @@ def dumps(
278280
Returns a `str` instead of `bytes` (which deviates from `orjson.dumps`), but seems more useful.
279281
"""
280282

283+
assert exclude_field_attributes is None or is_non_string_sequence(
284+
exclude_field_attributes
285+
), "exclude_field_attributes type needs to be a sequence"
286+
287+
# This explicitly handles excluding many-to-many fields with a hacky private `__unicorn__exclude_field_attribute`
288+
# attribute that gets used later in `_get_model_json`
289+
if exclude_field_attributes:
290+
updated_exclude_field_attributes = set()
291+
292+
for field_attributes in exclude_field_attributes:
293+
field_attributes_split = field_attributes.split(".")
294+
field_attribute = field_attributes_split[0]
295+
296+
for key in data.keys():
297+
if isinstance(data[key], Model) and key == field_attribute:
298+
remaining_field_attributes = field_attributes_split[1:]
299+
300+
if hasattr(data[key], "__unicorn__exclude_field_attribute"):
301+
data[key].__unicorn__exclude_field_attribute.append(
302+
remaining_field_attributes[0]
303+
)
304+
else:
305+
setattr(
306+
data[key],
307+
"__unicorn__exclude_field_attribute",
308+
remaining_field_attributes,
309+
)
310+
break
311+
else:
312+
updated_exclude_field_attributes.add(field_attributes)
313+
314+
# Convert list to tuple again so it's hashable for `lru_cache`
315+
exclude_field_attributes = tuple(updated_exclude_field_attributes)
316+
281317
serialized_data = orjson.dumps(data, default=_json_serializer)
282318

283319
if fix_floats:

tests/components/test_component.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -202,11 +202,14 @@ class TestComponent(UnicornView):
202202
another = {"Neutral Milk Hotel": {"album": {"On Avery Island": 1996}}}
203203

204204
class Meta:
205-
javascript_exclude = ("name.Universe", "another.Neutral Milk Hotel.album")
205+
javascript_exclude = ("another.Neutral Milk Hotel.album",)
206206

207-
expected = '{"another":{"Neutral Milk Hotel":{}},"name":{}}'
207+
expected = (
208+
'{"another":{"Neutral Milk Hotel":{}},"name":{"Universe":{"World":"Earth"}}}'
209+
)
208210
component = TestComponent(component_id="asdf1234", component_name="hello-world")
209-
assert expected == component.get_frontend_context_variables()
211+
actual = component.get_frontend_context_variables()
212+
assert expected == actual
210213

211214

212215
def test_meta_javascript_exclude_nested_with_list():

tests/serializer/test_dumps.py

Lines changed: 147 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,90 @@ def test_model_foreign_key_recursive_parent():
196196
assert expected == actual
197197

198198

199+
@pytest.mark.django_db
200+
def test_model_many_to_many(django_assert_num_queries):
201+
flavor_one = Flavor(name="name1", label="label1")
202+
flavor_one.save()
203+
204+
taste1 = Taste(name="Bitter1")
205+
taste1.save()
206+
taste2 = Taste(name="Bitter2")
207+
taste2.save()
208+
taste3 = Taste(name="Bitter3")
209+
taste3.save()
210+
211+
flavor_one.taste_set.add(taste1)
212+
flavor_one.taste_set.add(taste2)
213+
flavor_one.taste_set.add(taste3)
214+
215+
with django_assert_num_queries(2):
216+
actual = serializer.dumps(flavor_one)
217+
218+
expected = {
219+
"name": "name1",
220+
"label": "label1",
221+
"parent": None,
222+
"float_value": None,
223+
"decimal_value": None,
224+
"uuid": str(flavor_one.uuid),
225+
"datetime": None,
226+
"date": None,
227+
"time": None,
228+
"duration": None,
229+
"pk": 1,
230+
"taste_set": [taste1.pk, taste2.pk, taste3.pk],
231+
"origins": [],
232+
}
233+
234+
assert expected == json.loads(actual)
235+
236+
237+
@pytest.mark.django_db
238+
def test_model_many_to_many_with_excludes(django_assert_num_queries):
239+
flavor_one = Flavor(name="name1", label="label1")
240+
flavor_one.save()
241+
242+
taste1 = Taste(name="Bitter1")
243+
taste1.save()
244+
taste2 = Taste(name="Bitter2")
245+
taste2.save()
246+
taste3 = Taste(name="Bitter3")
247+
taste3.save()
248+
249+
flavor_one.taste_set.add(taste1)
250+
flavor_one.taste_set.add(taste2)
251+
flavor_one.taste_set.add(taste3)
252+
253+
# This shouldn't make any database calls because the many-to-manys are excluded and
254+
# all of the other data is already set
255+
with django_assert_num_queries(0):
256+
actual = serializer.dumps(
257+
{"flavor": flavor_one},
258+
exclude_field_attributes=(
259+
"flavor.taste_set",
260+
"flavor.origins",
261+
),
262+
)
263+
264+
expected = {
265+
"flavor": {
266+
"name": "name1",
267+
"label": "label1",
268+
"parent": None,
269+
"float_value": None,
270+
"decimal_value": None,
271+
"uuid": str(flavor_one.uuid),
272+
"datetime": None,
273+
"date": None,
274+
"time": None,
275+
"duration": None,
276+
"pk": 1,
277+
}
278+
}
279+
280+
assert expected == json.loads(actual)
281+
282+
199283
@pytest.mark.django_db
200284
def test_dumps_queryset(db):
201285
flavor_one = Flavor(name="name1", label="label1")
@@ -269,13 +353,62 @@ def test_get_model_dict():
269353

270354

271355
@pytest.mark.django_db
272-
def test_get_model_dict_many_to_many_is_referenced():
273-
taste = Taste(name="Bitter")
274-
taste.save()
356+
def test_get_model_dict_many_to_many_is_referenced(django_assert_num_queries):
275357
flavor_one = Flavor(name="name1", label="label1")
276358
flavor_one.save()
277-
flavor_one.taste_set.add(taste)
278-
actual = serializer._get_model_dict(flavor_one)
359+
360+
taste1 = Taste(name="Bitter")
361+
taste1.save()
362+
taste2 = Taste(name="Bitter2")
363+
taste2.save()
364+
taste3 = Taste(name="Bitter3")
365+
taste3.save()
366+
367+
flavor_one.taste_set.add(taste1)
368+
flavor_one.taste_set.add(taste2)
369+
flavor_one.taste_set.add(taste3)
370+
371+
expected = {
372+
"pk": 1,
373+
"name": "name1",
374+
"label": "label1",
375+
"parent": None,
376+
"decimal_value": None,
377+
"float_value": None,
378+
"uuid": str(flavor_one.uuid),
379+
"date": None,
380+
"datetime": None,
381+
"time": None,
382+
"duration": None,
383+
"taste_set": [taste1.pk, taste2.pk, taste3.pk],
384+
"origins": [],
385+
}
386+
387+
flavor_one = Flavor.objects.filter(id=flavor_one.id).first()
388+
389+
with django_assert_num_queries(2):
390+
actual = serializer._get_model_dict(flavor_one)
391+
392+
assert expected == actual
393+
394+
395+
@pytest.mark.django_db
396+
def test_get_model_dict_many_to_many_is_referenced_prefetched(
397+
django_assert_num_queries,
398+
):
399+
flavor_one = Flavor(name="name1", label="label1")
400+
flavor_one.save()
401+
402+
taste1 = Taste(name="Bitter")
403+
taste1.save()
404+
taste2 = Taste(name="Bitter2")
405+
taste2.save()
406+
taste3 = Taste(name="Bitter3")
407+
taste3.save()
408+
409+
flavor_one.taste_set.add(taste1)
410+
flavor_one.taste_set.add(taste2)
411+
flavor_one.taste_set.add(taste3)
279412

280413
expected = {
281414
"pk": 1,
@@ -289,10 +422,18 @@ def test_get_model_dict_many_to_many_is_referenced():
289422
"datetime": None,
290423
"time": None,
291424
"duration": None,
292-
"taste_set": [taste.pk],
425+
"taste_set": [taste1.pk, taste2.pk, taste3.pk],
293426
"origins": [],
294427
}
295428

429+
flavor_one = (
430+
Flavor.objects.prefetch_related("taste_set").filter(id=flavor_one.id).first()
431+
)
432+
433+
# prefetch_related should reduce the database calls
434+
with django_assert_num_queries(1):
435+
actual = serializer._get_model_dict(flavor_one)
436+
296437
assert expected == actual
297438

298439

0 commit comments

Comments
 (0)