Skip to content

Commit 57fdc10

Browse files
adamchainznessita
andcommitted
Refs #36383 -- Added extra tests for serializing functools.partial in tests/migrations/test_writer.py.
This includes a test helper to better assert over the expected output. Co-authored-by: Natalia <[email protected]>
1 parent 4647e2b commit 57fdc10

File tree

1 file changed

+60
-8
lines changed

1 file changed

+60
-8
lines changed

tests/migrations/test_writer.py

Lines changed: 60 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,18 @@ def assertSerializedFieldEqual(self, value):
300300
self.assertEqual(value.null, new_value.null)
301301
self.assertEqual(value.unique, new_value.unique)
302302

303+
def assertSerializedFunctoolsPartialEqual(
304+
self, value, expected_string, expected_imports
305+
):
306+
string, imports = MigrationWriter.serialize(value)
307+
self.assertEqual(string, expected_string)
308+
self.assertEqual(imports, expected_imports)
309+
result = self.serialize_round_trip(value)
310+
self.assertEqual(result.func, value.func)
311+
self.assertEqual(result.args, value.args)
312+
self.assertEqual(result.keywords, value.keywords)
313+
return result
314+
303315
def test_serialize_numbers(self):
304316
self.assertSerializedEqual(1)
305317
self.assertSerializedEqual(1.2)
@@ -895,19 +907,59 @@ def test_serialize_timedelta(self):
895907
self.assertSerializedEqual(datetime.timedelta(minutes=42))
896908

897909
def test_serialize_functools_partial(self):
910+
value = functools.partial(datetime.timedelta)
911+
string, imports = MigrationWriter.serialize(value)
912+
self.assertSerializedFunctoolsPartialEqual(
913+
value,
914+
"functools.partial(datetime.timedelta, *(), **{})",
915+
{"import datetime", "import functools"},
916+
)
917+
918+
def test_serialize_functools_partial_posarg(self):
919+
value = functools.partial(datetime.timedelta, 1)
920+
string, imports = MigrationWriter.serialize(value)
921+
self.assertSerializedFunctoolsPartialEqual(
922+
value,
923+
"functools.partial(datetime.timedelta, *(1,), **{})",
924+
{"import datetime", "import functools"},
925+
)
926+
927+
def test_serialize_functools_partial_kwarg(self):
928+
value = functools.partial(datetime.timedelta, seconds=2)
929+
string, imports = MigrationWriter.serialize(value)
930+
self.assertSerializedFunctoolsPartialEqual(
931+
value,
932+
"functools.partial(datetime.timedelta, *(), **{'seconds': 2})",
933+
{"import datetime", "import functools"},
934+
)
935+
936+
def test_serialize_functools_partial_mixed(self):
898937
value = functools.partial(datetime.timedelta, 1, seconds=2)
899-
result = self.serialize_round_trip(value)
900-
self.assertEqual(result.func, value.func)
901-
self.assertEqual(result.args, value.args)
902-
self.assertEqual(result.keywords, value.keywords)
938+
string, imports = MigrationWriter.serialize(value)
939+
self.assertSerializedFunctoolsPartialEqual(
940+
value,
941+
"functools.partial(datetime.timedelta, *(1,), **{'seconds': 2})",
942+
{"import datetime", "import functools"},
943+
)
944+
945+
def test_serialize_functools_partial_non_identifier_keyword(self):
946+
value = functools.partial(datetime.timedelta, **{"kebab-case": 1})
947+
string, imports = MigrationWriter.serialize(value)
948+
self.assertSerializedFunctoolsPartialEqual(
949+
value,
950+
"functools.partial(datetime.timedelta, *(), **{'kebab-case': 1})",
951+
{"import datetime", "import functools"},
952+
)
903953

904954
def test_serialize_functools_partialmethod(self):
905955
value = functools.partialmethod(datetime.timedelta, 1, seconds=2)
906-
result = self.serialize_round_trip(value)
956+
string, imports = MigrationWriter.serialize(value)
957+
result = self.assertSerializedFunctoolsPartialEqual(
958+
value,
959+
"functools.partialmethod(datetime.timedelta, *(1,), **{'seconds': 2})",
960+
{"import datetime", "import functools"},
961+
)
907962
self.assertIsInstance(result, functools.partialmethod)
908-
self.assertEqual(result.func, value.func)
909-
self.assertEqual(result.args, value.args)
910-
self.assertEqual(result.keywords, value.keywords)
911963

912964
def test_serialize_type_none(self):
913965
self.assertSerializedEqual(NoneType)

0 commit comments

Comments
 (0)