|
1 | 1 | import erp |
| 2 | +import json |
| 3 | +from collections import Counter |
2 | 4 | from django.test import TestCase |
| 5 | + |
| 6 | +import reversion |
| 7 | +from reversion.errors import RevertError |
| 8 | +from reversion.models import Version |
| 9 | +from django.contrib.contenttypes.models import ContentType |
| 10 | + |
3 | 11 | from api.models import Event, FieldReport, Region, Country, DisasterType, ERPGUID |
4 | 12 | from main.mock import erp_request_side_effect_mock |
| 13 | +from main.utils import DjangoReversionDataFixHelper |
5 | 14 | from unittest.mock import patch |
6 | 15 |
|
| 16 | +from per.models import Overview as PerOverview |
| 17 | +from per.factories import OverviewFactory as PerOverviewFactory |
7 | 18 | from api.factories import disaster_type as dtFactory |
8 | 19 | from api.factories import country as countryFactory |
9 | 20 | from api.factories import event as eventFactory |
@@ -44,3 +55,85 @@ def test_successful(self, erp_request_side_effect_mock): |
44 | 55 | self.assertEqual(ERP.api_guid, 'FindThisGUID') |
45 | 56 | self.assertEqual(ERP.field_report_id, report.id) |
46 | 57 | self.assertEqual(erp_request_side_effect_mock.called, True) |
| 58 | + |
| 59 | + |
| 60 | +class DjangoReversionDataFixHelperTest(TestCase): |
| 61 | + def get_version_qs(self): |
| 62 | + return Version.objects.get_for_model(PerOverview) |
| 63 | + |
| 64 | + def update_serialized_data(self, raw_data, new_value): |
| 65 | + new_data = json.loads(raw_data) |
| 66 | + new_data[0]['fields'][self.field_name] = new_value |
| 67 | + return json.dumps(new_data) |
| 68 | + |
| 69 | + def get_version_data_snapshot(self, field_name): |
| 70 | + # Version data snapshot excluding self.field_name |
| 71 | + version_data_snapshot = [] |
| 72 | + for _id, data_raw in self.get_version_qs().values_list('id', 'serialized_data').order_by('id'): |
| 73 | + data = json.loads(data_raw) |
| 74 | + data[0]['fields'].pop(field_name) |
| 75 | + version_data_snapshot.append({ |
| 76 | + 'id': _id, |
| 77 | + 'data': data, |
| 78 | + }) |
| 79 | + return version_data_snapshot |
| 80 | + |
| 81 | + def assert_values(self, values: dict): |
| 82 | + # Make sure other values are not changed |
| 83 | + assert self.version_data_snapshot == self.get_version_data_snapshot(self.field_name) |
| 84 | + # Check count |
| 85 | + assert self.get_version_qs().count() == sum(values.values()) |
| 86 | + # Check count by value |
| 87 | + assert dict(Counter([ |
| 88 | + json.loads(data)[0]['fields'][self.field_name] |
| 89 | + for data in self.get_version_qs().values_list('serialized_data', flat=True) |
| 90 | + ])) == values |
| 91 | + |
| 92 | + def confirm_version_data_serialization(self): |
| 93 | + for version in self.get_version_qs().all(): |
| 94 | + version._local_field_dict |
| 95 | + |
| 96 | + def setUp(self): |
| 97 | + super().setUp() |
| 98 | + self.field_name = 'date_of_assessment' |
| 99 | + for _ in range(95): |
| 100 | + reversion.create_revision()(PerOverviewFactory.create)() |
| 101 | + |
| 102 | + versions = self.get_version_qs().all() |
| 103 | + # Create dataset with different formats |
| 104 | + for index, version in enumerate(versions): |
| 105 | + new_value = '2022-01-01' |
| 106 | + if (index % 2) == 0: |
| 107 | + new_value = '2022-01-01T00:00:00' |
| 108 | + version.serialized_data = self.update_serialized_data( |
| 109 | + version.serialized_data, |
| 110 | + new_value, |
| 111 | + ) |
| 112 | + version.save() |
| 113 | + Version.objects.bulk_update(versions, fields=('serialized_data',)) |
| 114 | + # Version data snapshot excluding self.field_name |
| 115 | + self.version_data_snapshot = self.get_version_data_snapshot(self.field_name) |
| 116 | + self.assert_values({'2022-01-01': 47, '2022-01-01T00:00:00': 48}) |
| 117 | + |
| 118 | + def test_date_fields_to_datetime(self): |
| 119 | + self.assert_values({'2022-01-01': 47, '2022-01-01T00:00:00': 48}) |
| 120 | + DjangoReversionDataFixHelper.date_fields_to_datetime( |
| 121 | + ContentType, |
| 122 | + Version, |
| 123 | + PerOverview, |
| 124 | + [self.field_name] |
| 125 | + ) |
| 126 | + self.assert_values({'2022-01-01T00:00:00': 95}) |
| 127 | + |
| 128 | + def test_datetime_fields_to_date(self): |
| 129 | + with self.assertRaises(RevertError): |
| 130 | + self.confirm_version_data_serialization() |
| 131 | + self.assert_values({'2022-01-01': 47, '2022-01-01T00:00:00': 48}) |
| 132 | + DjangoReversionDataFixHelper.datetime_fields_to_date( |
| 133 | + ContentType, |
| 134 | + Version, |
| 135 | + PerOverview, |
| 136 | + [self.field_name] |
| 137 | + ) |
| 138 | + self.assert_values({'2022-01-01': 95}) |
| 139 | + self.confirm_version_data_serialization() |
0 commit comments