Skip to content

Commit a818e82

Browse files
authored
Merge pull request #30 from IBM/TrackImportStack-27
Track import stack 27
2 parents d02a300 + 1ce19f5 commit a818e82

File tree

3 files changed

+200
-27
lines changed

3 files changed

+200
-27
lines changed

import_tracker/__main__.py

Lines changed: 130 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,11 @@
1717
# Standard
1818
from concurrent.futures import ThreadPoolExecutor
1919
from types import ModuleType
20-
from typing import List, Optional, Set
20+
from typing import Dict, List, Optional, Set, Union
2121
import argparse
2222
import cmath
2323
import importlib
24+
import inspect
2425
import json
2526
import logging
2627
import os
@@ -54,9 +55,10 @@ def _get_import_parent_path(mod) -> str:
5455
return parent_path
5556

5657

57-
def _get_non_std_modules(mod_names: Set[str]) -> Set[str]:
58+
def _get_non_std_modules(mod_names: Union[Set[str], Dict[str, List[dict]]]) -> Set[str]:
5859
"""Take a snapshot of the non-standard modules currently imported"""
59-
return {
60+
# Determine the names from the list that are non-standard
61+
non_std_mods = {
6062
mod_name.split(".")[0]
6163
for mod_name, mod in sys.modules.items()
6264
if mod_name in mod_names
@@ -66,6 +68,17 @@ def _get_non_std_modules(mod_names: Set[str]) -> Set[str]:
6668
and mod_name.split(".")[0] != THIS_PACKAGE
6769
}
6870

71+
# If this is a set, just return it directly
72+
if isinstance(mod_names, set):
73+
return non_std_mods
74+
75+
# If it's a dict, limit to the non standard names
76+
return {
77+
mod_name: mod_vals
78+
for mod_name, mod_vals in mod_names.items()
79+
if mod_name in non_std_mods
80+
}
81+
6982

7083
class _DeferredModule(ModuleType):
7184
"""A _DeferredModule is a module subclass that wraps another module but imports
@@ -197,17 +210,16 @@ def exec_module(self, *_, **__):
197210
class ImportTrackerMetaFinder(importlib.abc.MetaPathFinder):
198211
"""The ImportTrackerMetaFinder is a meta finder that is intended to be used
199212
at the front of the sys.meta_path to automatically track the imports for a
200-
given library. It does this by looking at the call stack when a given import
201-
is requested and tracking the upstream for each import made inside of the
202-
target package.
203-
204-
NOTE: Since a stack trace is traversed on every import, this is very slow
205-
and is intended only for a static build-time operation and should not be
206-
used during the import phase of a library at runtime!
213+
given library. It does this by deferring all imports which occur before the
214+
target module has been seen, then collecting all imports seen until the
215+
target import has completed.
207216
"""
208217

209218
def __init__(
210-
self, tracked_module: str, side_effect_modules: Optional[List[str]] = None
219+
self,
220+
tracked_module: str,
221+
side_effect_modules: Optional[List[str]] = None,
222+
track_import_stack: bool = False,
211223
):
212224
"""Initialize with the name of the package being tracked
213225
@@ -219,6 +231,12 @@ def __init__(
219231
to perform required import tasks (e.g. global singleton
220232
registries). These modules will be allowed to import regardless
221233
of where they fall relative to the targeted module.
234+
track_import_stack: bool
235+
If true, when imports are allowed through, their stack trace is
236+
captured.
237+
NOTE: This will cause a stack trace to be computed for every
238+
import in the tracked set, so it will be very slow and
239+
should only be used as a debugging tool on targeted imports.
222240
"""
223241
self._tracked_module = tracked_module
224242
self._side_effect_modules = side_effect_modules or []
@@ -228,9 +246,14 @@ def __init__(
228246
log.debug2("Starting modules: %s", self._starting_modules)
229247
self._ending_modules = None
230248
self._deferred_modules = set()
249+
self._track_import_stack = track_import_stack
250+
self._import_stacks = {}
231251

232252
def find_spec(
233-
self, fullname: str, *args, **kwargs
253+
self,
254+
fullname: str,
255+
*args,
256+
**kwargs,
234257
) -> Optional[importlib.machinery.ModuleSpec]:
235258
"""The find_spec implementation for this finder tracks the source of the
236259
import call for the given module and determines if it is on the critical
@@ -249,6 +272,59 @@ def find_spec(
249272
import is on the critical path, None will be returned to defer
250273
to the rest of the "real" finders.
251274
"""
275+
# Do the main tracking logic
276+
result = self._find_spec(fullname, *args, **kwargs)
277+
278+
# If this module is deferred, return it
279+
if result is not None:
280+
log.debug2("Returning deferred module for [%s]", fullname)
281+
return result
282+
283+
# If this module is part of the set of modules belonging to the tracked
284+
# module and stack tracing is enabled, grab all frames in the stack that
285+
# come from the tracked module's package.
286+
log.debug2(
287+
"Stack tracking? %s, Ending modules set? %s",
288+
self._track_import_stack,
289+
self._ending_modules is not None,
290+
)
291+
if (
292+
self._track_import_stack
293+
and fullname != self._tracked_module
294+
and not self._enabled
295+
):
296+
stack = inspect.stack()
297+
stack_info = []
298+
for frame in stack:
299+
frame_module_name = frame.frame.f_globals["__name__"].split(".")[0]
300+
if frame_module_name == self._tracked_module_parts[0]:
301+
stack_info.append(
302+
{
303+
"filename": frame.filename,
304+
"lineno": frame.lineno,
305+
"code_context": [
306+
line.strip("\n") for line in frame.code_context
307+
],
308+
}
309+
)
310+
311+
# NOTE: Under certain _strange_ cases, you can end up overwriting a
312+
# previous import stack here. I've only ever seen this happen with
313+
# pytest internals. Also, in this case the best we can do is just
314+
# keep the latest one.
315+
log.debug2("Found %d stack frames for [%s]", len(stack_info), fullname)
316+
self._import_stacks[fullname] = stack_info
317+
318+
# Let the module pass through
319+
return None
320+
321+
def _find_spec(
322+
self, fullname: str, *args, **kwargs
323+
) -> Optional[importlib.machinery.ModuleSpec]:
324+
"""This implements the core logic of find_spec. It is wrapped by the
325+
public find_spec so that when an import is allowed, the stack can be
326+
optionally tracked.
327+
"""
252328

253329
# If this module fullname is one of the modules with known side-effects,
254330
# let it fall through
@@ -309,11 +385,17 @@ def get_all_new_modules(self) -> Set[str]:
309385
assert self._starting_modules is not None, f"Target module never impoted!"
310386
if self._ending_modules is None:
311387
self._set_ending_modules()
312-
return {
388+
mod_names = {
313389
mod
314390
for mod in self._ending_modules - self._starting_modules
315391
if not self._is_parent_module(mod)
316392
}
393+
if self._track_import_stack:
394+
return {
395+
mod_name: self._import_stacks.get(mod_name, [])
396+
for mod_name in mod_names
397+
}
398+
return mod_names
317399

318400
## Implementation Details ##
319401

@@ -417,6 +499,13 @@ def main():
417499
default=None,
418500
help="Modules with known import-time side effect which should always be allowed to import",
419501
)
502+
parser.add_argument(
503+
"--track_import_stack",
504+
"-t",
505+
action="store_true",
506+
default=False,
507+
help="Store the stack trace of imports belonging to the tracked module",
508+
)
420509
args = parser.parse_args()
421510

422511
# Validate sets of args
@@ -442,7 +531,11 @@ def main():
442531
full_module_name = f"{args.package}{args.name}"
443532

444533
# Create the tracking meta finder
445-
tracker_finder = ImportTrackerMetaFinder(full_module_name, args.side_effect_modules)
534+
tracker_finder = ImportTrackerMetaFinder(
535+
tracked_module=full_module_name,
536+
side_effect_modules=args.side_effect_modules,
537+
track_import_stack=args.track_import_stack,
538+
)
446539
sys.meta_path = [tracker_finder] + sys.meta_path
447540

448541
# Do the import
@@ -480,6 +573,14 @@ def main():
480573
]
481574
log.debug("Recursing on: %s", recursive_internals)
482575

576+
# Set up the kwargs for recursing
577+
recursive_kwargs = dict(
578+
log_level=log_level,
579+
recursive=False,
580+
side_effect_modules=args.side_effect_modules,
581+
track_import_stack=args.track_import_stack,
582+
)
583+
483584
# Create the thread pool to manage the subprocesses
484585
if args.num_jobs > 0:
485586
pool = ThreadPoolExecutor(max_workers=args.num_jobs)
@@ -489,9 +590,7 @@ def main():
489590
pool.submit(
490591
track_module,
491592
module_name=internal_downstream,
492-
log_level=log_level,
493-
recursive=False,
494-
side_effect_modules=args.side_effect_modules,
593+
**recursive_kwargs,
495594
)
496595
)
497596

@@ -507,12 +606,10 @@ def main():
507606
)
508607
downstream_mapping.update(
509608
track_module(
510-
module_name=internal_downstream,
511-
log_level=log_level,
512-
recursive=False,
513-
side_effect_modules=args.side_effect_modules,
609+
module_name=internal_downstream, **recursive_kwargs
514610
)
515611
)
612+
516613
# This is useful for catching errors caused by unexpected corner
517614
# cases. If it's triggered, it's a sign of a bug in the library,
518615
# so we don't have ways to explicitly exercise this in tests.
@@ -527,13 +624,19 @@ def main():
527624
# Get all of the downstreams for the module in question, including internals
528625
log.debug("Downstream Mapping: %s", downstream_mapping)
529626

627+
# Set up the output dict depending on whether or not the stack info is being
628+
# tracked
629+
if args.track_import_stack:
630+
output_dict = {
631+
key: dict(sorted(val.items())) for key, val in downstream_mapping.items()
632+
}
633+
else:
634+
output_dict = {
635+
key: sorted(list(val)) for key, val in downstream_mapping.items()
636+
}
637+
530638
# Print out the json dump
531-
print(
532-
json.dumps(
533-
{key: sorted(list(val)) for key, val in downstream_mapping.items()},
534-
indent=args.indent,
535-
),
536-
)
639+
print(json.dumps(output_dict, indent=args.indent))
537640

538641

539642
if __name__ == "__main__":

import_tracker/import_tracker.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def track_module(
2828
num_jobs: int = 0,
2929
side_effect_modules: Optional[List[str]] = None,
3030
submodules: Optional[List[str]] = None,
31+
track_import_stack: bool = False,
3132
) -> Dict[str, List[str]]:
3233
"""This function executes the tracking of a single module by launching a
3334
subprocess to execute this module against the target module. The
@@ -54,6 +55,8 @@ def track_module(
5455
fall relative to the targeted module.
5556
submodules: Optional[List[str]]
5657
List of sub-modules to recurse on (only used when recursive set)
58+
track_import_stack: bool
59+
Store the stack trace of imports belonging to the tracked module
5760
5861
Returns:
5962
import_mapping: Dict[str, List[str]]
@@ -87,6 +90,8 @@ def track_module(
8790
cmd += " --side_effect_modules " + " ".join(side_effect_modules)
8891
if submodules:
8992
cmd += " --submodules " + " ".join(submodules)
93+
if track_import_stack:
94+
cmd += " --track_import_stack"
9095

9196
# Launch the process
9297
proc = subprocess.Popen(shlex.split(cmd), stdout=subprocess.PIPE, env=env)

test/test_main.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from contextlib import contextmanager
77
import json
88
import logging
9+
import os
910
import sys
1011

1112
# Third Party
@@ -198,3 +199,67 @@ def test_error_submodules_without_recursive():
198199
):
199200
with pytest.raises(ValueError):
200201
main()
202+
203+
204+
def test_import_stack_tracking(capsys):
205+
"""Make sure that tracking the import stack works as expected"""
206+
with cli_args(
207+
"--name",
208+
"inter_mod_deps",
209+
"--recursive",
210+
"--track_import_stack",
211+
):
212+
main()
213+
captured = capsys.readouterr()
214+
assert captured.out
215+
parsed_out = json.loads(captured.out)
216+
217+
assert set(parsed_out.keys()) == {
218+
"inter_mod_deps",
219+
"inter_mod_deps.submod1",
220+
"inter_mod_deps.submod2",
221+
"inter_mod_deps.submod2.foo",
222+
"inter_mod_deps.submod2.bar",
223+
"inter_mod_deps.submod3",
224+
"inter_mod_deps.submod4",
225+
"inter_mod_deps.submod5",
226+
}
227+
228+
# Check one of the stacks to make sure it's correct
229+
test_lib_dir = os.path.realpath(
230+
os.path.join(
231+
os.path.dirname(__file__),
232+
"sample_libs",
233+
"inter_mod_deps",
234+
)
235+
)
236+
assert parsed_out["inter_mod_deps.submod2"] == {
237+
"alog": [
238+
{
239+
"filename": f"{test_lib_dir}/submod1/__init__.py",
240+
"lineno": 6,
241+
"code_context": ["import alog"],
242+
},
243+
{
244+
"filename": f"{test_lib_dir}/__init__.py",
245+
"lineno": 17,
246+
"code_context": [
247+
"from . import submod1, submod2, submod3, submod4, submod5"
248+
],
249+
},
250+
],
251+
"yaml": [
252+
{
253+
"filename": f"{test_lib_dir}/submod2/__init__.py",
254+
"lineno": 6,
255+
"code_context": ["import yaml"],
256+
},
257+
{
258+
"filename": f"{test_lib_dir}/__init__.py",
259+
"lineno": 17,
260+
"code_context": [
261+
"from . import submod1, submod2, submod3, submod4, submod5"
262+
],
263+
},
264+
],
265+
}

0 commit comments

Comments
 (0)