1717# Standard
1818from concurrent .futures import ThreadPoolExecutor
1919from types import ModuleType
20- from typing import List , Optional , Set
20+ from typing import Dict , List , Optional , Set , Union
2121import argparse
2222import cmath
2323import importlib
24+ import inspect
2425import json
2526import logging
2627import os
@@ -54,9 +55,10 @@ def _get_import_parent_path(mod) -> str:
5455 return parent_path
5556
5657
57- def _get_non_std_modules (mod_names : Set [str ]) -> Set [str ]:
58+ def _get_non_std_modules (mod_names : Union [ Set [str ], Dict [ str , List [ dict ]] ]) -> Set [str ]:
5859 """Take a snapshot of the non-standard modules currently imported"""
59- return {
60+ # Determine the names from the list that are non-standard
61+ non_std_mods = {
6062 mod_name .split ("." )[0 ]
6163 for mod_name , mod in sys .modules .items ()
6264 if mod_name in mod_names
@@ -66,6 +68,17 @@ def _get_non_std_modules(mod_names: Set[str]) -> Set[str]:
6668 and mod_name .split ("." )[0 ] != THIS_PACKAGE
6769 }
6870
71+ # If this is a set, just return it directly
72+ if isinstance (mod_names , set ):
73+ return non_std_mods
74+
75+ # If it's a dict, limit to the non standard names
76+ return {
77+ mod_name : mod_vals
78+ for mod_name , mod_vals in mod_names .items ()
79+ if mod_name in non_std_mods
80+ }
81+
6982
7083class _DeferredModule (ModuleType ):
7184 """A _DeferredModule is a module subclass that wraps another module but imports
@@ -197,17 +210,16 @@ def exec_module(self, *_, **__):
197210class ImportTrackerMetaFinder (importlib .abc .MetaPathFinder ):
198211 """The ImportTrackerMetaFinder is a meta finder that is intended to be used
199212 at the front of the sys.meta_path to automatically track the imports for a
200- given library. It does this by looking at the call stack when a given import
201- is requested and tracking the upstream for each import made inside of the
202- target package.
203-
204- NOTE: Since a stack trace is traversed on every import, this is very slow
205- and is intended only for a static build-time operation and should not be
206- used during the import phase of a library at runtime!
213+ given library. It does this by deferring all imports which occur before the
214+ target module has been seen, then collecting all imports seen until the
215+ target import has completed.
207216 """
208217
209218 def __init__ (
210- self , tracked_module : str , side_effect_modules : Optional [List [str ]] = None
219+ self ,
220+ tracked_module : str ,
221+ side_effect_modules : Optional [List [str ]] = None ,
222+ track_import_stack : bool = False ,
211223 ):
212224 """Initialize with the name of the package being tracked
213225
@@ -219,6 +231,12 @@ def __init__(
219231 to perform required import tasks (e.g. global singleton
220232 registries). These modules will be allowed to import regardless
221233 of where they fall relative to the targeted module.
234+ track_import_stack: bool
235+ If true, when imports are allowed through, their stack trace is
236+ captured.
237+ NOTE: This will cause a stack trace to be computed for every
238+ import in the tracked set, so it will be very slow and
239+ should only be used as a debugging tool on targeted imports.
222240 """
223241 self ._tracked_module = tracked_module
224242 self ._side_effect_modules = side_effect_modules or []
@@ -228,9 +246,14 @@ def __init__(
228246 log .debug2 ("Starting modules: %s" , self ._starting_modules )
229247 self ._ending_modules = None
230248 self ._deferred_modules = set ()
249+ self ._track_import_stack = track_import_stack
250+ self ._import_stacks = {}
231251
232252 def find_spec (
233- self , fullname : str , * args , ** kwargs
253+ self ,
254+ fullname : str ,
255+ * args ,
256+ ** kwargs ,
234257 ) -> Optional [importlib .machinery .ModuleSpec ]:
235258 """The find_spec implementation for this finder tracks the source of the
236259 import call for the given module and determines if it is on the critical
@@ -249,6 +272,59 @@ def find_spec(
249272 import is on the critical path, None will be returned to defer
250273 to the rest of the "real" finders.
251274 """
275+ # Do the main tracking logic
276+ result = self ._find_spec (fullname , * args , ** kwargs )
277+
278+ # If this module is deferred, return it
279+ if result is not None :
280+ log .debug2 ("Returning deferred module for [%s]" , fullname )
281+ return result
282+
283+ # If this module is part of the set of modules belonging to the tracked
284+ # module and stack tracing is enabled, grab all frames in the stack that
285+ # come from the tracked module's package.
286+ log .debug2 (
287+ "Stack tracking? %s, Ending modules set? %s" ,
288+ self ._track_import_stack ,
289+ self ._ending_modules is not None ,
290+ )
291+ if (
292+ self ._track_import_stack
293+ and fullname != self ._tracked_module
294+ and not self ._enabled
295+ ):
296+ stack = inspect .stack ()
297+ stack_info = []
298+ for frame in stack :
299+ frame_module_name = frame .frame .f_globals ["__name__" ].split ("." )[0 ]
300+ if frame_module_name == self ._tracked_module_parts [0 ]:
301+ stack_info .append (
302+ {
303+ "filename" : frame .filename ,
304+ "lineno" : frame .lineno ,
305+ "code_context" : [
306+ line .strip ("\n " ) for line in frame .code_context
307+ ],
308+ }
309+ )
310+
311+ # NOTE: Under certain _strange_ cases, you can end up overwriting a
312+ # previous import stack here. I've only ever seen this happen with
313+ # pytest internals. Also, in this case the best we can do is just
314+ # keep the latest one.
315+ log .debug2 ("Found %d stack frames for [%s]" , len (stack_info ), fullname )
316+ self ._import_stacks [fullname ] = stack_info
317+
318+ # Let the module pass through
319+ return None
320+
321+ def _find_spec (
322+ self , fullname : str , * args , ** kwargs
323+ ) -> Optional [importlib .machinery .ModuleSpec ]:
324+ """This implements the core logic of find_spec. It is wrapped by the
325+ public find_spec so that when an import is allowed, the stack can be
326+ optionally tracked.
327+ """
252328
253329 # If this module fullname is one of the modules with known side-effects,
254330 # let it fall through
@@ -309,11 +385,17 @@ def get_all_new_modules(self) -> Set[str]:
309385 assert self ._starting_modules is not None , f"Target module never impoted!"
310386 if self ._ending_modules is None :
311387 self ._set_ending_modules ()
312- return {
388+ mod_names = {
313389 mod
314390 for mod in self ._ending_modules - self ._starting_modules
315391 if not self ._is_parent_module (mod )
316392 }
393+ if self ._track_import_stack :
394+ return {
395+ mod_name : self ._import_stacks .get (mod_name , [])
396+ for mod_name in mod_names
397+ }
398+ return mod_names
317399
318400 ## Implementation Details ##
319401
@@ -417,6 +499,13 @@ def main():
417499 default = None ,
418500 help = "Modules with known import-time side effect which should always be allowed to import" ,
419501 )
502+ parser .add_argument (
503+ "--track_import_stack" ,
504+ "-t" ,
505+ action = "store_true" ,
506+ default = False ,
507+ help = "Store the stack trace of imports belonging to the tracked module" ,
508+ )
420509 args = parser .parse_args ()
421510
422511 # Validate sets of args
@@ -442,7 +531,11 @@ def main():
442531 full_module_name = f"{ args .package } { args .name } "
443532
444533 # Create the tracking meta finder
445- tracker_finder = ImportTrackerMetaFinder (full_module_name , args .side_effect_modules )
534+ tracker_finder = ImportTrackerMetaFinder (
535+ tracked_module = full_module_name ,
536+ side_effect_modules = args .side_effect_modules ,
537+ track_import_stack = args .track_import_stack ,
538+ )
446539 sys .meta_path = [tracker_finder ] + sys .meta_path
447540
448541 # Do the import
@@ -480,6 +573,14 @@ def main():
480573 ]
481574 log .debug ("Recursing on: %s" , recursive_internals )
482575
576+ # Set up the kwargs for recursing
577+ recursive_kwargs = dict (
578+ log_level = log_level ,
579+ recursive = False ,
580+ side_effect_modules = args .side_effect_modules ,
581+ track_import_stack = args .track_import_stack ,
582+ )
583+
483584 # Create the thread pool to manage the subprocesses
484585 if args .num_jobs > 0 :
485586 pool = ThreadPoolExecutor (max_workers = args .num_jobs )
@@ -489,9 +590,7 @@ def main():
489590 pool .submit (
490591 track_module ,
491592 module_name = internal_downstream ,
492- log_level = log_level ,
493- recursive = False ,
494- side_effect_modules = args .side_effect_modules ,
593+ ** recursive_kwargs ,
495594 )
496595 )
497596
@@ -507,12 +606,10 @@ def main():
507606 )
508607 downstream_mapping .update (
509608 track_module (
510- module_name = internal_downstream ,
511- log_level = log_level ,
512- recursive = False ,
513- side_effect_modules = args .side_effect_modules ,
609+ module_name = internal_downstream , ** recursive_kwargs
514610 )
515611 )
612+
516613 # This is useful for catching errors caused by unexpected corner
517614 # cases. If it's triggered, it's a sign of a bug in the library,
518615 # so we don't have ways to explicitly exercise this in tests.
@@ -527,13 +624,19 @@ def main():
527624 # Get all of the downstreams for the module in question, including internals
528625 log .debug ("Downstream Mapping: %s" , downstream_mapping )
529626
627+ # Set up the output dict depending on whether or not the stack info is being
628+ # tracked
629+ if args .track_import_stack :
630+ output_dict = {
631+ key : dict (sorted (val .items ())) for key , val in downstream_mapping .items ()
632+ }
633+ else :
634+ output_dict = {
635+ key : sorted (list (val )) for key , val in downstream_mapping .items ()
636+ }
637+
530638 # Print out the json dump
531- print (
532- json .dumps (
533- {key : sorted (list (val )) for key , val in downstream_mapping .items ()},
534- indent = args .indent ,
535- ),
536- )
639+ print (json .dumps (output_dict , indent = args .indent ))
537640
538641
539642if __name__ == "__main__" :
0 commit comments