Skip to content

Commit bdadf52

Browse files
authored
Merge pull request #305 from HackSoftware/update-model-m2m
Add support for m2m fields in `model_update`
2 parents d4204d1 + 39e3b02 commit bdadf52

File tree

5 files changed

+179
-28
lines changed

5 files changed

+179
-28
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import factory
2+
3+
from styleguide_example.common.models import RandomModel, SimpleModel
4+
from styleguide_example.utils.tests import faker
5+
6+
7+
class RandomModelFactory(factory.django.DjangoModelFactory):
8+
class Meta:
9+
model = RandomModel
10+
11+
end_date = factory.LazyAttribute(lambda self: faker.date_object())
12+
start_date = factory.LazyAttribute(lambda self: faker.date_object(end_datetime=self.end_date))
13+
14+
15+
class SimpleModelFactory(factory.django.DjangoModelFactory):
16+
class Meta:
17+
model = SimpleModel
18+
19+
name = factory.LazyAttribute(lambda self: faker.word())
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Generated by Django 4.1.3 on 2022-11-28 11:58
2+
3+
from django.db import migrations, models
4+
import django.utils.timezone
5+
6+
7+
class Migration(migrations.Migration):
8+
9+
dependencies = [
10+
('common', '0004_remove_randommodel_start_date_before_end_date_and_more'),
11+
]
12+
13+
operations = [
14+
migrations.CreateModel(
15+
name='SimpleModel',
16+
fields=[
17+
('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
18+
('created_at', models.DateTimeField(db_index=True, default=django.utils.timezone.now)),
19+
('updated_at', models.DateTimeField(auto_now=True)),
20+
('name', models.CharField(blank=True, max_length=255, null=True)),
21+
],
22+
options={
23+
'abstract': False,
24+
},
25+
),
26+
migrations.AddField(
27+
model_name='randommodel',
28+
name='simple_objects',
29+
field=models.ManyToManyField(blank=True, related_name='random_objects', to='common.simplemodel'),
30+
),
31+
]

styleguide_example/common/models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@ class Meta:
1111
abstract = True
1212

1313

14+
class SimpleModel(BaseModel):
15+
"""
16+
This is a basic model used to illustrate a many-to-many relationship
17+
with RandomModel.
18+
"""
19+
20+
name = models.CharField(max_length=255, blank=True, null=True)
21+
22+
1423
class RandomModel(BaseModel):
1524
"""
1625
This is an example model, to be used as reference in the Styleguide,
@@ -20,5 +29,7 @@ class RandomModel(BaseModel):
2029
start_date = models.DateField()
2130
end_date = models.DateField()
2231

32+
simple_objects = models.ManyToManyField(SimpleModel, blank=True, related_name="random_objects")
33+
2334
class Meta:
2435
constraints = [models.CheckConstraint(name="start_date_before_end_date", check=Q(start_date__lt=F("end_date")))]

styleguide_example/common/services.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from typing import Any, Dict, List, Tuple
22

3+
from django.db import models
4+
35
from styleguide_example.common.types import DjangoModelType
46

57

68
def model_update(*, instance: DjangoModelType, fields: List[str], data: Dict[str, Any]) -> Tuple[DjangoModelType, bool]:
79
"""
8-
Generic update service meant to be reused in local update services
10+
Generic update service meant to be reused in local update services.
911
1012
For example:
1113
@@ -18,26 +20,56 @@ def user_update(*, user: User, data) -> User:
1820
return user
1921
2022
Return value: Tuple with the following elements:
21-
1. The instance we updated
23+
1. The instance we updated.
2224
2. A boolean value representing whether we performed an update or not.
25+
26+
Some important notes:
27+
28+
- Only keys present in `fields` will be taken from `data`.
29+
- If something in present in `fields` but not present in `data`, we simply skip.
30+
- There's a strict assertion that all values in `fields` are actual fields in `instance`.
31+
- `fields` can support m2m fields, which are handled after the update on `instance`.
2332
"""
2433
has_updated = False
34+
m2m_data = {}
35+
update_fields = []
36+
37+
model_fields = {field.name: field for field in instance._meta.get_fields()}
2538

2639
for field in fields:
2740
# Skip if a field is not present in the actual data
2841
if field not in data:
2942
continue
3043

44+
# If field is not an actual model field, raise an error
45+
model_field = model_fields.get(field)
46+
47+
assert model_field is not None, f"{field} is not part of {instance.__class__.__name__} fields."
48+
49+
# If we have m2m field, handle differently
50+
if isinstance(model_field, models.ManyToManyField):
51+
m2m_data[field] = data[field]
52+
continue
53+
3154
if getattr(instance, field) != data[field]:
3255
has_updated = True
56+
update_fields.append(field)
3357
setattr(instance, field, data[field])
3458

35-
# Perform an update only if any of the fields was actually changed
59+
# Perform an update only if any of the fields were actually changed
3660
if has_updated:
3761
instance.full_clean()
3862
# Update only the fields that are meant to be updated.
3963
# Django docs reference:
4064
# https://docs.djangoproject.com/en/dev/ref/models/instances/#specifying-which-fields-to-save
41-
instance.save(update_fields=fields)
65+
instance.save(update_fields=update_fields)
66+
67+
for field_name, value in m2m_data.items():
68+
related_manager = getattr(instance, field_name)
69+
related_manager.set(value)
70+
71+
# Still not sure about this.
72+
# What if we only update m2m relations & nothing on the model? Is this still considered as updated?
73+
has_updated = True
4274

4375
return instance, has_updated
Lines changed: 82 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,101 @@
1-
import unittest
1+
from datetime import timedelta
22
from unittest.mock import Mock
33

4+
from django.db import connection
5+
from django.test import TestCase
6+
from django.test.utils import CaptureQueriesContext
7+
8+
from styleguide_example.common.factories import RandomModelFactory, SimpleModelFactory
49
from styleguide_example.common.services import model_update
510

611

7-
class ModelUpdateTests(unittest.TestCase):
12+
class ModelUpdateTests(TestCase):
813
def setUp(self):
14+
self.model_instance = RandomModelFactory()
15+
self.simple_object = SimpleModelFactory()
916
self.instance = Mock(field_a=None, field_b=None, field_c=None)
1017

11-
def test_model_update_does_not_update_if_none_of_the_fields_are_in_the_data(self):
12-
update_fields = ["non_existing_field"]
13-
data = {"field_a": "value_a"}
18+
def test_model_update_does_nothing(self):
19+
with self.subTest("when no fields are provided"):
20+
instance = RandomModelFactory()
21+
22+
updated_instance, has_updated = model_update(instance=instance, fields=[], data={})
1423

15-
updated_instance, has_updated = model_update(instance=self.instance, fields=update_fields, data=data)
24+
self.assertEqual(instance, updated_instance)
25+
self.assertFalse(has_updated)
26+
self.assertNumQueries(0)
1627

17-
self.assertEqual(updated_instance, self.instance)
18-
self.assertFalse(has_updated)
28+
with self.subTest("when non of the fields are in the data"):
29+
instance = RandomModelFactory()
1930

20-
self.assertIsNone(updated_instance.field_a)
21-
self.assertIsNone(updated_instance.field_b)
22-
self.assertIsNone(updated_instance.field_c)
31+
updated_instance, has_updated = model_update(instance=instance, fields=["start_date"], data={"foo": "bar"})
2332

24-
self.instance.full_clean.assert_not_called()
25-
self.instance.save.assert_not_called()
33+
self.assertEqual(instance, updated_instance)
34+
self.assertFalse(has_updated)
35+
self.assertNumQueries(0)
2636

2737
def test_model_update_updates_only_passed_fields_from_data(self):
28-
update_fields = ["field_a"]
29-
data = {"field_a": "value_a", "field_b": "value_b"}
38+
instance = RandomModelFactory()
39+
40+
update_fields = ["start_date"]
41+
data = {
42+
"field_a": "value_a",
43+
"start_date": instance.start_date - timedelta(days=1),
44+
"end_date": instance.end_date + timedelta(days=1),
45+
}
46+
47+
self.assertNotEqual(instance.start_date, data["start_date"])
48+
49+
update_query = None
3050

31-
updated_instance, has_updated = model_update(instance=self.instance, fields=update_fields, data=data)
51+
with CaptureQueriesContext(connection) as ctx:
52+
updated_instance, has_updated = model_update(instance=instance, fields=update_fields, data=data)
53+
update_query = ctx.captured_queries[-1]
3254

3355
self.assertTrue(has_updated)
56+
self.assertEqual(updated_instance.start_date, data["start_date"])
57+
self.assertNotEqual(updated_instance.end_date, data["end_date"])
58+
59+
self.assertFalse(hasattr(updated_instance, "field_a"))
60+
61+
self.assertNotIn("end_date", update_query)
62+
63+
def test_model_update_raises_error_when_called_with_non_existent_field(self):
64+
instance = RandomModelFactory()
65+
66+
update_fields = ["non_existing_field"]
67+
data = {"non_existing_field": "foo"}
68+
69+
with self.assertRaises(AssertionError):
70+
updated_instance, has_updated = model_update(instance=instance, fields=update_fields, data=data)
3471

35-
self.assertEqual(updated_instance.field_a, "value_a")
36-
# Even though `field_b` is passed in `data` - it does not get updated
37-
# because it is not present in the `fields` list.
38-
self.assertIsNone(updated_instance.field_b)
39-
# `field_c` remains `None`, because it is not passed anywhere.
40-
self.assertIsNone(updated_instance.field_c)
72+
def test_model_update_updates_many_to_many_fields(self):
73+
instance = RandomModelFactory()
74+
simple_obj = SimpleModelFactory()
4175

42-
self.instance.full_clean.assert_called_once()
43-
self.instance.save.assert_called_once_with(update_fields=update_fields)
76+
update_fields = ["simple_objects"]
77+
data = {"simple_objects": [simple_obj]}
78+
79+
self.assertNotIn(simple_obj, instance.simple_objects.all())
80+
81+
updated_instance, has_updated = model_update(instance=instance, fields=update_fields, data=data)
82+
83+
self.assertEqual(updated_instance, instance)
84+
self.assertTrue(has_updated)
85+
86+
self.assertIn(simple_obj, updated_instance.simple_objects.all())
87+
88+
def test_model_update_updates_standard_and_many_to_many_fields(self):
89+
instance = RandomModelFactory()
90+
simple_obj = SimpleModelFactory()
91+
92+
update_fields = ["start_date", "simple_objects"]
93+
data = {"start_date": instance.start_date - timedelta(days=1), "simple_objects": [simple_obj]}
94+
95+
self.assertNotIn(simple_obj, instance.simple_objects.all())
96+
97+
updated_instance, has_updated = model_update(instance=instance, fields=update_fields, data=data)
98+
99+
self.assertTrue(has_updated)
100+
self.assertEqual(updated_instance.start_date, data["start_date"])
101+
self.assertIn(simple_obj, updated_instance.simple_objects.all())

0 commit comments

Comments
 (0)