6161import dataclasses
6262import dis
6363from enum import Enum
64+ import functools
6465import io
6566import itertools
6667import logging
100101
101102PYPY = platform .python_implementation () == "PyPy"
102103
104+
105+ def uuid_generator (_ ):
106+ return uuid .uuid4 ().hex
107+
108+
109+ @dataclasses .dataclass
110+ class CloudPickleConfig :
111+ """Configuration for cloudpickle behavior."""
112+ id_generator : typing .Optional [callable ] = uuid_generator
113+ skip_reset_dynamic_type_state : bool = False
114+
115+
116+ DEFAULT_CONFIG = CloudPickleConfig ()
117+
103118builtin_code_type = None
104119if PYPY :
105120 # builtin-code objects only exist in pypy
108123_extract_code_globals_cache = weakref .WeakKeyDictionary ()
109124
110125
111- def _get_or_create_tracker_id (class_def ):
126+ def _get_or_create_tracker_id (class_def , id_generator ):
112127 with _DYNAMIC_CLASS_TRACKER_LOCK :
113128 class_tracker_id = _DYNAMIC_CLASS_TRACKER_BY_CLASS .get (class_def )
114- if class_tracker_id is None :
115- class_tracker_id = uuid . uuid4 (). hex
129+ if class_tracker_id is None and id_generator is not None :
130+ class_tracker_id = id_generator ( class_def )
116131 _DYNAMIC_CLASS_TRACKER_BY_CLASS [class_def ] = class_tracker_id
117132 _DYNAMIC_CLASS_TRACKER_BY_ID [class_tracker_id ] = class_def
118133 return class_tracker_id
@@ -593,26 +608,26 @@ def _make_typevar(
593608 return _lookup_class_or_track (class_tracker_id , tv )
594609
595610
596- def _decompose_typevar (obj ):
611+ def _decompose_typevar (obj , config ):
597612 return (
598613 obj .__name__ ,
599614 obj .__bound__ ,
600615 obj .__constraints__ ,
601616 obj .__covariant__ ,
602617 obj .__contravariant__ ,
603- _get_or_create_tracker_id (obj ),
618+ _get_or_create_tracker_id (obj , config . id_generator ),
604619 )
605620
606621
607- def _typevar_reduce (obj ):
622+ def _typevar_reduce (obj , config ):
608623 # TypeVar instances require the module information hence why we
609624 # are not using the _should_pickle_by_reference directly
610625 module_and_name = _lookup_module_and_qualname (obj , name = obj .__name__ )
611626
612627 if module_and_name is None :
613- return (_make_typevar , _decompose_typevar (obj ))
628+ return (_make_typevar , _decompose_typevar (obj , config ))
614629 elif _is_registered_pickle_by_value (module_and_name [0 ]):
615- return (_make_typevar , _decompose_typevar (obj ))
630+ return (_make_typevar , _decompose_typevar (obj , config ))
616631
617632 return (getattr , module_and_name )
618633
@@ -656,7 +671,7 @@ def _make_dict_items(obj, is_ordered=False):
656671# -------------------------------------------------
657672
658673
659- def _class_getnewargs (obj ):
674+ def _class_getnewargs (obj , config ):
660675 type_kwargs = {}
661676 if "__module__" in obj .__dict__ :
662677 type_kwargs ["__module__" ] = obj .__module__
@@ -670,20 +685,20 @@ def _class_getnewargs(obj):
670685 obj .__name__ ,
671686 _get_bases (obj ),
672687 type_kwargs ,
673- _get_or_create_tracker_id (obj ),
688+ _get_or_create_tracker_id (obj , config . id_generator ),
674689 None ,
675690 )
676691
677692
678- def _enum_getnewargs (obj ):
693+ def _enum_getnewargs (obj , config ):
679694 members = {e .name : e .value for e in obj }
680695 return (
681696 obj .__bases__ ,
682697 obj .__name__ ,
683698 obj .__qualname__ ,
684699 members ,
685700 obj .__module__ ,
686- _get_or_create_tracker_id (obj ),
701+ _get_or_create_tracker_id (obj , config . id_generator ),
687702 None ,
688703 )
689704
@@ -1028,7 +1043,7 @@ def _weakset_reduce(obj):
10281043 return weakref .WeakSet , (list (obj ), )
10291044
10301045
1031- def _dynamic_class_reduce (obj ):
1046+ def _dynamic_class_reduce (obj , config ):
10321047 """Save a class that can't be referenced as a module attribute.
10331048
10341049 This method is used to serialize classes that are defined inside
@@ -1038,24 +1053,28 @@ def _dynamic_class_reduce(obj):
10381053 if Enum is not None and issubclass (obj , Enum ):
10391054 return (
10401055 _make_skeleton_enum ,
1041- _enum_getnewargs (obj ),
1056+ _enum_getnewargs (obj , config ),
10421057 _enum_getstate (obj ),
10431058 None ,
10441059 None ,
1045- _class_setstate ,
1060+ functools .partial (
1061+ _class_setstate ,
1062+ skip_reset_dynamic_type_state = config .skip_reset_dynamic_type_state ),
10461063 )
10471064 else :
10481065 return (
10491066 _make_skeleton_class ,
1050- _class_getnewargs (obj ),
1067+ _class_getnewargs (obj , config ),
10511068 _class_getstate (obj ),
10521069 None ,
10531070 None ,
1054- _class_setstate ,
1071+ functools .partial (
1072+ _class_setstate ,
1073+ skip_reset_dynamic_type_state = config .skip_reset_dynamic_type_state ),
10551074 )
10561075
10571076
1058- def _class_reduce (obj ):
1077+ def _class_reduce (obj , config ):
10591078 """Select the reducer depending on the dynamic nature of the class obj."""
10601079 if obj is type (None ): # noqa
10611080 return type , (None , )
@@ -1066,7 +1085,7 @@ def _class_reduce(obj):
10661085 elif obj in _BUILTIN_TYPE_NAMES :
10671086 return _builtin_type , (_BUILTIN_TYPE_NAMES [obj ], )
10681087 elif not _should_pickle_by_reference (obj ):
1069- return _dynamic_class_reduce (obj )
1088+ return _dynamic_class_reduce (obj , config )
10701089 return NotImplemented
10711090
10721091
@@ -1150,14 +1169,12 @@ def _function_setstate(obj, state):
11501169 setattr (obj , k , v )
11511170
11521171
1153- def _class_setstate (obj , state ):
1154- # This breaks the ability to modify the state of a dynamic type in the main
1155- # process wth the assumption that the type is updatable in the child process.
1172+ def _class_setstate (obj , state , skip_reset_dynamic_type_state ):
1173+ # Lock while potentially modifying class state.
11561174 with _DYNAMIC_CLASS_TRACKER_LOCK :
1157- if obj in _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS :
1175+ if skip_reset_dynamic_type_state and obj in _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS :
11581176 return obj
11591177 _DYNAMIC_CLASS_STATE_TRACKER_BY_CLASS [obj ] = True
1160-
11611178 state , slotstate = state
11621179 registry = None
11631180 for attrname , attr in state .items ():
@@ -1229,7 +1246,6 @@ class Pickler(pickle.Pickler):
12291246 _dispatch_table [types .MethodType ] = _method_reduce
12301247 _dispatch_table [types .MappingProxyType ] = _mappingproxy_reduce
12311248 _dispatch_table [weakref .WeakSet ] = _weakset_reduce
1232- _dispatch_table [typing .TypeVar ] = _typevar_reduce
12331249 _dispatch_table [_collections_abc .dict_keys ] = _dict_keys_reduce
12341250 _dispatch_table [_collections_abc .dict_values ] = _dict_values_reduce
12351251 _dispatch_table [_collections_abc .dict_items ] = _dict_items_reduce
@@ -1309,7 +1325,8 @@ def dump(self, obj):
13091325 else :
13101326 raise
13111327
1312- def __init__ (self , file , protocol = None , buffer_callback = None ):
1328+ def __init__ (
1329+ self , file , protocol = None , buffer_callback = None , config = DEFAULT_CONFIG ):
13131330 if protocol is None :
13141331 protocol = DEFAULT_PROTOCOL
13151332 super ().__init__ (file , protocol = protocol , buffer_callback = buffer_callback )
@@ -1318,6 +1335,7 @@ def __init__(self, file, protocol=None, buffer_callback=None):
13181335 # their global namespace at unpickling time.
13191336 self .globals_ref = {}
13201337 self .proto = int (protocol )
1338+ self .config = config
13211339
13221340 if not PYPY :
13231341 # pickle.Pickler is the C implementation of the CPython pickler and
@@ -1384,7 +1402,9 @@ def reducer_override(self, obj):
13841402 is_anyclass = False
13851403
13861404 if is_anyclass :
1387- return _class_reduce (obj )
1405+ return _class_reduce (obj , self .config )
1406+ elif isinstance (obj , typing .TypeVar ): # Add this check
1407+ return _typevar_reduce (obj , self .config )
13881408 elif isinstance (obj , types .FunctionType ):
13891409 return self ._function_reduce (obj )
13901410 else :
@@ -1454,12 +1474,20 @@ def save_global(self, obj, name=None, pack=struct.pack):
14541474 if name is not None :
14551475 super ().save_global (obj , name = name )
14561476 elif not _should_pickle_by_reference (obj , name = name ):
1457- self ._save_reduce_pickle5 (* _dynamic_class_reduce (obj ), obj = obj )
1477+ self ._save_reduce_pickle5 (
1478+ * _dynamic_class_reduce (obj , self .config ), obj = obj )
14581479 else :
14591480 super ().save_global (obj , name = name )
14601481
14611482 dispatch [type ] = save_global
14621483
1484+ def save_typevar (self , obj , name = None ):
1485+ """Handle TypeVar objects with access to config."""
1486+ return self ._save_reduce_pickle5 (
1487+ * _typevar_reduce (obj , self .config ), obj = obj )
1488+
1489+ dispatch [typing .TypeVar ] = save_typevar
1490+
14631491 def save_function (self , obj , name = None ):
14641492 """Registered with the dispatch to handle all function types.
14651493
@@ -1505,7 +1533,7 @@ def save_pypy_builtin_func(self, obj):
15051533# Shorthands similar to pickle.dump/pickle.dumps
15061534
15071535
1508- def dump (obj , file , protocol = None , buffer_callback = None ):
1536+ def dump (obj , file , protocol = None , buffer_callback = None , config = DEFAULT_CONFIG ):
15091537 """Serialize obj as bytes streamed into file
15101538
15111539 protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
@@ -1518,10 +1546,12 @@ def dump(obj, file, protocol=None, buffer_callback=None):
15181546 implementation details that can change from one Python version to the
15191547 next).
15201548 """
1521- Pickler (file , protocol = protocol , buffer_callback = buffer_callback ).dump (obj )
1549+ Pickler (
1550+ file , protocol = protocol , buffer_callback = buffer_callback ,
1551+ config = config ).dump (obj )
15221552
15231553
1524- def dumps (obj , protocol = None , buffer_callback = None ):
1554+ def dumps (obj , protocol = None , buffer_callback = None , config = DEFAULT_CONFIG ):
15251555 """Serialize obj as a string of bytes allocated in memory
15261556
15271557 protocol defaults to cloudpickle.DEFAULT_PROTOCOL which is an alias to
@@ -1535,7 +1565,8 @@ def dumps(obj, protocol=None, buffer_callback=None):
15351565 next).
15361566 """
15371567 with io .BytesIO () as file :
1538- cp = Pickler (file , protocol = protocol , buffer_callback = buffer_callback )
1568+ cp = Pickler (
1569+ file , protocol = protocol , buffer_callback = buffer_callback , config = config )
15391570 cp .dump (obj )
15401571 return file .getvalue ()
15411572
0 commit comments