@@ -300,6 +300,18 @@ def assertSerializedFieldEqual(self, value):
300300 self .assertEqual (value .null , new_value .null )
301301 self .assertEqual (value .unique , new_value .unique )
302302
303+ def assertSerializedFunctoolsPartialEqual (
304+ self , value , expected_string , expected_imports
305+ ):
306+ string , imports = MigrationWriter .serialize (value )
307+ self .assertEqual (string , expected_string )
308+ self .assertEqual (imports , expected_imports )
309+ result = self .serialize_round_trip (value )
310+ self .assertEqual (result .func , value .func )
311+ self .assertEqual (result .args , value .args )
312+ self .assertEqual (result .keywords , value .keywords )
313+ return result
314+
303315 def test_serialize_numbers (self ):
304316 self .assertSerializedEqual (1 )
305317 self .assertSerializedEqual (1.2 )
@@ -895,19 +907,59 @@ def test_serialize_timedelta(self):
895907 self .assertSerializedEqual (datetime .timedelta (minutes = 42 ))
896908
897909 def test_serialize_functools_partial (self ):
910+ value = functools .partial (datetime .timedelta )
911+ string , imports = MigrationWriter .serialize (value )
912+ self .assertSerializedFunctoolsPartialEqual (
913+ value ,
914+ "functools.partial(datetime.timedelta, *(), **{})" ,
915+ {"import datetime" , "import functools" },
916+ )
917+
918+ def test_serialize_functools_partial_posarg (self ):
919+ value = functools .partial (datetime .timedelta , 1 )
920+ string , imports = MigrationWriter .serialize (value )
921+ self .assertSerializedFunctoolsPartialEqual (
922+ value ,
923+ "functools.partial(datetime.timedelta, *(1,), **{})" ,
924+ {"import datetime" , "import functools" },
925+ )
926+
927+ def test_serialize_functools_partial_kwarg (self ):
928+ value = functools .partial (datetime .timedelta , seconds = 2 )
929+ string , imports = MigrationWriter .serialize (value )
930+ self .assertSerializedFunctoolsPartialEqual (
931+ value ,
932+ "functools.partial(datetime.timedelta, *(), **{'seconds': 2})" ,
933+ {"import datetime" , "import functools" },
934+ )
935+
936+ def test_serialize_functools_partial_mixed (self ):
898937 value = functools .partial (datetime .timedelta , 1 , seconds = 2 )
899- result = self .serialize_round_trip (value )
900- self .assertEqual (result .func , value .func )
901- self .assertEqual (result .args , value .args )
902- self .assertEqual (result .keywords , value .keywords )
938+ string , imports = MigrationWriter .serialize (value )
939+ self .assertSerializedFunctoolsPartialEqual (
940+ value ,
941+ "functools.partial(datetime.timedelta, *(1,), **{'seconds': 2})" ,
942+ {"import datetime" , "import functools" },
943+ )
944+
945+ def test_serialize_functools_partial_non_identifier_keyword (self ):
946+ value = functools .partial (datetime .timedelta , ** {"kebab-case" : 1 })
947+ string , imports = MigrationWriter .serialize (value )
948+ self .assertSerializedFunctoolsPartialEqual (
949+ value ,
950+ "functools.partial(datetime.timedelta, *(), **{'kebab-case': 1})" ,
951+ {"import datetime" , "import functools" },
952+ )
903953
904954 def test_serialize_functools_partialmethod (self ):
905955 value = functools .partialmethod (datetime .timedelta , 1 , seconds = 2 )
906- result = self .serialize_round_trip (value )
956+ string , imports = MigrationWriter .serialize (value )
957+ result = self .assertSerializedFunctoolsPartialEqual (
958+ value ,
959+ "functools.partialmethod(datetime.timedelta, *(1,), **{'seconds': 2})" ,
960+ {"import datetime" , "import functools" },
961+ )
907962 self .assertIsInstance (result , functools .partialmethod )
908- self .assertEqual (result .func , value .func )
909- self .assertEqual (result .args , value .args )
910- self .assertEqual (result .keywords , value .keywords )
911963
912964 def test_serialize_type_none (self ):
913965 self .assertSerializedEqual (NoneType )
0 commit comments