|
1 |
| -import unittest |
| 1 | +from datetime import timedelta |
2 | 2 | from unittest.mock import Mock
|
3 | 3 |
|
| 4 | +from django.db import connection |
| 5 | +from django.test import TestCase |
| 6 | +from django.test.utils import CaptureQueriesContext |
| 7 | + |
| 8 | +from styleguide_example.common.factories import RandomModelFactory, SimpleModelFactory |
4 | 9 | from styleguide_example.common.services import model_update
|
5 | 10 |
|
6 | 11 |
|
7 |
| -class ModelUpdateTests(unittest.TestCase): |
| 12 | +class ModelUpdateTests(TestCase): |
8 | 13 | def setUp(self):
|
| 14 | + self.model_instance = RandomModelFactory() |
| 15 | + self.simple_object = SimpleModelFactory() |
9 | 16 | self.instance = Mock(field_a=None, field_b=None, field_c=None)
|
10 | 17 |
|
11 |
| - def test_model_update_does_not_update_if_none_of_the_fields_are_in_the_data(self): |
12 |
| - update_fields = ["non_existing_field"] |
13 |
| - data = {"field_a": "value_a"} |
| 18 | + def test_model_update_does_nothing(self): |
| 19 | + with self.subTest("when no fields are provided"): |
| 20 | + instance = RandomModelFactory() |
| 21 | + |
| 22 | + updated_instance, has_updated = model_update(instance=instance, fields=[], data={}) |
14 | 23 |
|
15 |
| - updated_instance, has_updated = model_update(instance=self.instance, fields=update_fields, data=data) |
| 24 | + self.assertEqual(instance, updated_instance) |
| 25 | + self.assertFalse(has_updated) |
| 26 | + self.assertNumQueries(0) |
16 | 27 |
|
17 |
| - self.assertEqual(updated_instance, self.instance) |
18 |
| - self.assertFalse(has_updated) |
| 28 | + with self.subTest("when non of the fields are in the data"): |
| 29 | + instance = RandomModelFactory() |
19 | 30 |
|
20 |
| - self.assertIsNone(updated_instance.field_a) |
21 |
| - self.assertIsNone(updated_instance.field_b) |
22 |
| - self.assertIsNone(updated_instance.field_c) |
| 31 | + updated_instance, has_updated = model_update(instance=instance, fields=["start_date"], data={"foo": "bar"}) |
23 | 32 |
|
24 |
| - self.instance.full_clean.assert_not_called() |
25 |
| - self.instance.save.assert_not_called() |
| 33 | + self.assertEqual(instance, updated_instance) |
| 34 | + self.assertFalse(has_updated) |
| 35 | + self.assertNumQueries(0) |
26 | 36 |
|
27 | 37 | def test_model_update_updates_only_passed_fields_from_data(self):
|
28 |
| - update_fields = ["field_a"] |
29 |
| - data = {"field_a": "value_a", "field_b": "value_b"} |
| 38 | + instance = RandomModelFactory() |
| 39 | + |
| 40 | + update_fields = ["start_date"] |
| 41 | + data = { |
| 42 | + "field_a": "value_a", |
| 43 | + "start_date": instance.start_date - timedelta(days=1), |
| 44 | + "end_date": instance.end_date + timedelta(days=1), |
| 45 | + } |
| 46 | + |
| 47 | + self.assertNotEqual(instance.start_date, data["start_date"]) |
| 48 | + |
| 49 | + update_query = None |
30 | 50 |
|
31 |
| - updated_instance, has_updated = model_update(instance=self.instance, fields=update_fields, data=data) |
| 51 | + with CaptureQueriesContext(connection) as ctx: |
| 52 | + updated_instance, has_updated = model_update(instance=instance, fields=update_fields, data=data) |
| 53 | + update_query = ctx.captured_queries[-1] |
32 | 54 |
|
33 | 55 | self.assertTrue(has_updated)
|
| 56 | + self.assertEqual(updated_instance.start_date, data["start_date"]) |
| 57 | + self.assertNotEqual(updated_instance.end_date, data["end_date"]) |
| 58 | + |
| 59 | + self.assertFalse(hasattr(updated_instance, "field_a")) |
| 60 | + |
| 61 | + self.assertNotIn("end_date", update_query) |
| 62 | + |
| 63 | + def test_model_update_raises_error_when_called_with_non_existent_field(self): |
| 64 | + instance = RandomModelFactory() |
| 65 | + |
| 66 | + update_fields = ["non_existing_field"] |
| 67 | + data = {"non_existing_field": "foo"} |
| 68 | + |
| 69 | + with self.assertRaises(AssertionError): |
| 70 | + updated_instance, has_updated = model_update(instance=instance, fields=update_fields, data=data) |
34 | 71 |
|
35 |
| - self.assertEqual(updated_instance.field_a, "value_a") |
36 |
| - # Even though `field_b` is passed in `data` - it does not get updated |
37 |
| - # because it is not present in the `fields` list. |
38 |
| - self.assertIsNone(updated_instance.field_b) |
39 |
| - # `field_c` remains `None`, because it is not passed anywhere. |
40 |
| - self.assertIsNone(updated_instance.field_c) |
| 72 | + def test_model_update_updates_many_to_many_fields(self): |
| 73 | + instance = RandomModelFactory() |
| 74 | + simple_obj = SimpleModelFactory() |
41 | 75 |
|
42 |
| - self.instance.full_clean.assert_called_once() |
43 |
| - self.instance.save.assert_called_once_with(update_fields=update_fields) |
| 76 | + update_fields = ["simple_objects"] |
| 77 | + data = {"simple_objects": [simple_obj]} |
| 78 | + |
| 79 | + self.assertNotIn(simple_obj, instance.simple_objects.all()) |
| 80 | + |
| 81 | + updated_instance, has_updated = model_update(instance=instance, fields=update_fields, data=data) |
| 82 | + |
| 83 | + self.assertEqual(updated_instance, instance) |
| 84 | + self.assertTrue(has_updated) |
| 85 | + |
| 86 | + self.assertIn(simple_obj, updated_instance.simple_objects.all()) |
| 87 | + |
| 88 | + def test_model_update_updates_standard_and_many_to_many_fields(self): |
| 89 | + instance = RandomModelFactory() |
| 90 | + simple_obj = SimpleModelFactory() |
| 91 | + |
| 92 | + update_fields = ["start_date", "simple_objects"] |
| 93 | + data = {"start_date": instance.start_date - timedelta(days=1), "simple_objects": [simple_obj]} |
| 94 | + |
| 95 | + self.assertNotIn(simple_obj, instance.simple_objects.all()) |
| 96 | + |
| 97 | + updated_instance, has_updated = model_update(instance=instance, fields=update_fields, data=data) |
| 98 | + |
| 99 | + self.assertTrue(has_updated) |
| 100 | + self.assertEqual(updated_instance.start_date, data["start_date"]) |
| 101 | + self.assertIn(simple_obj, updated_instance.simple_objects.all()) |
0 commit comments