Skip to content

Commit 5c3e98f

Browse files
committed
Make sure that many-to-many fields get excluded from JavaScript as expected.
1 parent c4f560d commit 5c3e98f

File tree

1 file changed

+97
-51
lines changed

1 file changed

+97
-51
lines changed

django_unicorn/serializer.py

Lines changed: 97 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
from decimal import Decimal
33
from functools import lru_cache
4-
from typing import Any, Dict, List, Tuple
4+
from typing import Any, Dict, List, Optional, Tuple
55

66
from django.core.serializers import serialize
77
from django.db.models import (
@@ -82,6 +82,31 @@ def _parse_field_values_from_string(model: Model) -> None:
8282
setattr(model, field.attname, parse_duration(val))
8383

8484

85+
def _get_many_to_many_field_related_names(model: Model) -> List[str]:
86+
"""
87+
Get the many-to-many fields for a particular model. Returns either the automatically
88+
defined field name (i.e. something_set) or the related name.
89+
"""
90+
91+
# Use this internal method so that the fields can be cached
92+
@lru_cache(maxsize=128, typed=True)
93+
def _get_many_to_many_field_related_names_from_meta(meta):
94+
names = []
95+
96+
for field in meta.get_fields():
97+
if field.is_relation and field.many_to_many:
98+
related_name = field.name
99+
100+
if field.auto_created:
101+
related_name = field.related_name or f"{field.name}_set"
102+
103+
names.append(related_name)
104+
105+
return names
106+
107+
return _get_many_to_many_field_related_names_from_meta(model._meta)
108+
109+
85110
def _get_model_dict(model: Model) -> dict:
86111
"""
87112
Serializes Django models. Uses the built-in Django JSON serializer, but moves the data around to
@@ -97,32 +122,27 @@ def _get_model_dict(model: Model) -> dict:
97122
model_json = model_json.get("fields")
98123
model_json["pk"] = model_pk
99124

100-
for field in model._meta.get_fields():
101-
if field.is_relation and field.many_to_many:
102-
exclude_field_attributes = getattr(
103-
model, "__unicorn__exclude_field_attribute", []
104-
)
105-
related_name = field.name
106-
107-
if field.auto_created:
108-
related_name = field.related_name or f"{field.name}_set"
125+
exclude_field_related_names = getattr(
126+
model, "__unicorn__exclude_field_related_names", []
127+
)
109128

110-
if related_name in exclude_field_attributes:
111-
continue
129+
for related_name in _get_many_to_many_field_related_names(model):
130+
if related_name in exclude_field_related_names:
131+
continue
112132

113-
pks = []
133+
pks = []
114134

115-
try:
116-
related_descriptor = getattr(model, related_name)
135+
try:
136+
related_descriptor = getattr(model, related_name)
117137

118-
# Get `pk` from `all` because it will re-use the cached data if the m-2-m field is prefetched
119-
# Using `values_list("pk", flat=True)` or `only()` won't use the cached prefetched values
120-
pks = [m.pk for m in related_descriptor.all()]
121-
except ValueError:
122-
# ValueError is throuwn when the model doesn't have an id already set
123-
pass
138+
# Get `pk` from `all` because it will re-use the cached data if the m-2-m field is prefetched
139+
# Using `values_list("pk", flat=True)` or `only()` won't use the cached prefetched values
140+
pks = [m.pk for m in related_descriptor.all()]
141+
except ValueError:
142+
# ValueError is throuwn when the model doesn't have an id already set
143+
pass
124144

125-
model_json[related_name] = pks
145+
model_json[related_name] = pks
126146

127147
return model_json
128148

@@ -262,6 +282,58 @@ def _exclude_field_attributes(
262282
del dict_data[field_name][field_attr]
263283

264284

285+
def _handle_many_to_many_excluded_field_attributes(
286+
data: Dict, exclude_field_attributes: Optional[Tuple[str]]
287+
) -> Optional[Tuple[str]]:
288+
"""
289+
Explicitly handle excluding many-to-many fields on models with a semi-hacky private
290+
`__unicorn__exclude_field_related_names attribute that gets used later in `_get_model_json`.
291+
Since the many-to-many field won't be serialized, remove it from the list so it won't
292+
be tried to be removed in `_exclude_field_attributes`.
293+
"""
294+
295+
if exclude_field_attributes:
296+
many_to_many_field_attributes = set()
297+
298+
for field_attributes in exclude_field_attributes:
299+
if "." not in field_attributes:
300+
continue
301+
302+
(field_attribute, exclude_field_related_name, *_) = field_attributes.split(
303+
"."
304+
)
305+
306+
for key in data.keys():
307+
if isinstance(data[key], Model) and key == field_attribute:
308+
model = data[key]
309+
310+
many_to_many_related_names = _get_many_to_many_field_related_names(
311+
model
312+
)
313+
314+
if exclude_field_related_name in many_to_many_related_names:
315+
if hasattr(model, "__unicorn__exclude_field_related_names"):
316+
model.__unicorn__exclude_field_related_names.append(
317+
exclude_field_related_name
318+
)
319+
else:
320+
setattr(
321+
model,
322+
"__unicorn__exclude_field_related_names",
323+
[exclude_field_related_name],
324+
)
325+
326+
many_to_many_field_attributes.add(field_attributes)
327+
break
328+
329+
# Convert list to tuple again so it's hashable for `lru_cache`
330+
exclude_field_attributes = tuple(
331+
set(exclude_field_attributes) - many_to_many_field_attributes
332+
)
333+
334+
return exclude_field_attributes
335+
336+
265337
def dumps(
266338
data: Dict, fix_floats: bool = True, exclude_field_attributes: Tuple[str] = None
267339
) -> str:
@@ -284,35 +356,9 @@ def dumps(
284356
exclude_field_attributes
285357
), "exclude_field_attributes type needs to be a sequence"
286358

287-
# This explicitly handles excluding many-to-many fields with a hacky private `__unicorn__exclude_field_attribute`
288-
# attribute that gets used later in `_get_model_json`
289-
if exclude_field_attributes:
290-
updated_exclude_field_attributes = set()
291-
292-
for field_attributes in exclude_field_attributes:
293-
field_attributes_split = field_attributes.split(".")
294-
field_attribute = field_attributes_split[0]
295-
296-
for key in data.keys():
297-
if isinstance(data[key], Model) and key == field_attribute:
298-
remaining_field_attributes = field_attributes_split[1:]
299-
300-
if hasattr(data[key], "__unicorn__exclude_field_attribute"):
301-
data[key].__unicorn__exclude_field_attribute.append(
302-
remaining_field_attributes[0]
303-
)
304-
else:
305-
setattr(
306-
data[key],
307-
"__unicorn__exclude_field_attribute",
308-
remaining_field_attributes,
309-
)
310-
break
311-
else:
312-
updated_exclude_field_attributes.add(field_attributes)
313-
314-
# Convert list to tuple again so it's hashable for `lru_cache`
315-
exclude_field_attributes = tuple(updated_exclude_field_attributes)
359+
exclude_field_attributes = _handle_many_to_many_excluded_field_attributes(
360+
data, exclude_field_attributes
361+
)
316362

317363
serialized_data = orjson.dumps(data, default=_json_serializer)
318364

0 commit comments

Comments
 (0)