102102 PY2 = False
103103 from importlib ._bootstrap import _find_spec
104104
105+ _extract_code_globals_cache = weakref .WeakKeyDictionary ()
106+
105107
106108def _ensure_tracking (class_def ):
107109 with _DYNAMIC_CLASS_TRACKER_LOCK :
@@ -195,6 +197,78 @@ def _is_global(obj, name=None):
195197 return obj2 is obj
196198
197199
def _extract_code_globals(co):
    """
    Return the set of global names read or written by code object ``co``.

    Names referenced by code objects nested in ``co.co_consts`` are included
    as well. The result is memoized in ``_extract_code_globals_cache``; the
    cached set object itself is returned, so callers should treat it as
    read-only.
    """
    cached_names = _extract_code_globals_cache.get(co)
    if cached_names is not None:
        return cached_names

    co_names = co.co_names
    global_names = {co_names[oparg] for _, oparg in _walk_global_ops(co)}

    # Declaring a function inside another one with the "def ..." syntax
    # generates a constant code object corresponding to the nested function.
    # As the nested function may itself need global variables, we introspect
    # its code (recursively, through its own co_consts attribute) and merge
    # the globals it references into the result.
    for nested in co.co_consts or ():
        if isinstance(nested, types.CodeType):
            global_names.update(_extract_code_globals(nested))

    _extract_code_globals_cache[co] = global_names
    return global_names
223+
224+
def _find_imported_submodules(code, top_level_dependencies):
    """
    Find currently imported submodules used by a function.

    Submodules used by a function need to be detected and referenced for the
    function to work correctly at depickling time. Because submodules can be
    referenced as attribute of their parent package (``package.submodule``), we
    need a special introspection technique that does not rely on GLOBAL-related
    opcodes to find references of them in a code object.

    Example:
    ```
    import concurrent.futures
    import cloudpickle
    def func():
        x = concurrent.futures.ThreadPoolExecutor
    if __name__ == '__main__':
        cloudpickle.dumps(func)
    ```
    The globals extracted by cloudpickle in the function's state include the
    concurrent package, but not its submodule (here, concurrent.futures), which
    is the module used by func. _find_imported_submodules will detect the usage
    of concurrent.futures. Saving this module alongside func will ensure that
    calling func once depickled does not fail due to concurrent.futures not
    being imported.

    Parameters
    ----------
    code : code object
        Code of the function being pickled; only its ``co_names`` is read.
    top_level_dependencies : iterable
        Objects the function depends on. Entries that are not modules with a
        non-empty ``__package__`` are ignored.

    Returns
    -------
    list
        The ``sys.modules`` entries whose dotted name starts with the name of
        one of the dependencies and whose remaining name components all
        appear in ``code.co_names``.
    """
    # Names the code object can address. This does not depend on the module
    # being examined, so it is computed once rather than per candidate.
    addressable_names = set(code.co_names)

    subimports = []
    # check if any known dependency is an imported package
    for x in top_level_dependencies:
        if not (isinstance(x, types.ModuleType) and
                getattr(x, '__package__', None)):
            continue
        # check if the package has any currently loaded sub-imports
        prefix = x.__name__ + '.'
        # A concurrent thread could mutate sys.modules: work on a copy of the
        # whole mapping so that neither the iteration nor the module lookup
        # can fail if an entry is added or removed in the meantime.
        for name, module in sys.modules.copy().items():
            # Older versions of pytest will add a "None" module to
            # sys.modules.
            if name is None or not name.startswith(prefix):
                continue
            # check whether the function can address the sub-module
            tokens = set(name[len(prefix):].split('.'))
            if tokens <= addressable_names:
                subimports.append(module)
    return subimports
270+
271+
198272def _make_cell_set_template_code ():
199273 """Get the Python compiler to emit LOAD_FAST(arg); STORE_DEREF
200274
@@ -493,54 +567,6 @@ def save_pypy_builtin_func(self, obj):
493567 obj .__dict__ )
494568 self .save_reduce (* rv , obj = obj )
495569
496-
def _save_subimports(self, code, top_level_dependencies):
    """
    Save submodules used by a function but not listed in its globals.

    In the example below:

    ```
    import concurrent.futures
    import cloudpickle


    def func():
        x = concurrent.futures.ThreadPoolExecutor


    if __name__ == '__main__':
        cloudpickle.dumps(func)
    ```

    the globals extracted by cloudpickle in the function's state include
    the concurrent module, but not its submodule (here,
    concurrent.futures), which is the module used by func.

    To ensure that calling the depickled function does not raise an
    AttributeError, this function looks for any currently loaded submodule
    that the function uses and whose parent is present in the function
    globals, and saves it before saving the function.

    ``code`` is the code object of the function (only ``co_names`` is read)
    and ``top_level_dependencies`` is an iterable of objects the function
    depends on; entries that are not modules with a non-empty
    ``__package__`` are ignored. Nothing is returned: each matching module
    is passed to ``self.save`` and then ``pickle.POP`` is written.
    """
    # Names the code object can address. This does not depend on the module
    # being examined, so it is computed once rather than per candidate.
    addressable_names = set(code.co_names)

    # check if any known dependency is an imported package
    for x in top_level_dependencies:
        if not (isinstance(x, types.ModuleType) and
                getattr(x, '__package__', None)):
            continue
        # check if the package has any currently loaded sub-imports
        prefix = x.__name__ + '.'
        # A concurrent thread could mutate sys.modules: work on a copy of the
        # whole mapping so that neither the iteration nor the module lookup
        # can fail if an entry is added or removed in the meantime.
        for name, module in sys.modules.copy().items():
            # Older versions of pytest will add a "None" module to
            # sys.modules.
            if name is None or not name.startswith(prefix):
                continue
            # check whether the function can address the sub-module
            tokens = set(name[len(prefix):].split('.'))
            if tokens <= addressable_names:
                # ensure unpickler executes this import
                self.save(module)
                # then discards the reference to it
                self.write(pickle.POP)
543-
544570 def _save_dynamic_enum (self , obj , clsdict ):
545571 """Special handling for dynamic Enum subclasses
546572
@@ -676,7 +702,12 @@ def save_function_tuple(self, func):
676702 save (_fill_function ) # skeleton function updater
677703 write (pickle .MARK ) # beginning of tuple that _fill_function expects
678704
679- self ._save_subimports (
705+ # Extract currently-imported submodules used by func. Storing these
706+ # modules in a placeholder _cloudpickle_submodules entry of the function's
707+ # state will trigger the side effect of importing these modules at
708+ # unpickling time (which is necessary for func to work correctly once
709+ # depickled)
710+ submodules = _find_imported_submodules (
680711 code ,
681712 itertools .chain (f_globals .values (), closure_values or ()),
682713 )
@@ -700,6 +731,7 @@ def save_function_tuple(self, func):
700731 'module' : func .__module__ ,
701732 'name' : func .__name__ ,
702733 'doc' : func .__doc__ ,
734+ '_cloudpickle_submodules' : submodules
703735 }
704736 if hasattr (func , '__annotations__' ) and sys .version_info >= (3 , 4 ):
705737 state ['annotations' ] = func .__annotations__
@@ -711,28 +743,6 @@ def save_function_tuple(self, func):
711743 write (pickle .TUPLE )
712744 write (pickle .REDUCE ) # applies _fill_function on the tuple
713745
714- _extract_code_globals_cache = weakref .WeakKeyDictionary ()
715-
@classmethod
def extract_code_globals(cls, co):
    """
    Return the set of global names read or written by code object ``co``.

    Names referenced by code objects nested in ``co.co_consts`` are included
    as well. The result is memoized in ``cls._extract_code_globals_cache``;
    the cached set object itself is returned.
    """
    cache = cls._extract_code_globals_cache
    cached_names = cache.get(co)
    if cached_names is not None:
        return cached_names

    co_names = co.co_names
    global_names = {co_names[oparg] for _, oparg in _walk_global_ops(co)}

    # see if nested functions have any global refs
    for nested in co.co_consts or ():
        if isinstance(nested, types.CodeType):
            global_names.update(cls.extract_code_globals(nested))

    cache[co] = global_names
    return global_names
735-
736746 def extract_func_data (self , func ):
737747 """
738748 Turn the function into a tuple of data necessary to recreate it:
@@ -741,7 +751,7 @@ def extract_func_data(self, func):
741751 code = func .__code__
742752
743753 # extract all global ref's
744- func_global_refs = self . extract_code_globals (code )
754+ func_global_refs = _extract_code_globals (code )
745755
746756 # process all variables referenced by global environment
747757 f_globals = {}
@@ -1202,6 +1212,13 @@ def _fill_function(*args):
12021212 func .__qualname__ = state ['qualname' ]
12031213 if 'kwdefaults' in state :
12041214 func .__kwdefaults__ = state ['kwdefaults' ]
1215+ # _cloudpickle_submodules is a list of submodules that must be loaded for
1216+ # the pickled function to work correctly at unpickling time. Now that these
1217+ # submodules are depickled (hence imported), they can be removed from the
1218+ # object's state (the object state only served as a reference holder to
1219+ # these submodules)
1220+ if '_cloudpickle_submodules' in state :
1221+ state .pop ('_cloudpickle_submodules' )
12051222
12061223 cells = func .__closure__
12071224 if cells is not None :
0 commit comments