@@ -61,23 +61,59 @@ def reroot_path(
6161 return rel_fn , project_root .joinpath (rel_fn ).resolve ()
6262
6363
64- def get_files (root : Path , extensions : Container [str ]) -> Iterator [Path ]:
def is_relative_to(a: Path, b: Path) -> bool:
    """Return True if path *a* is *b* itself or lies underneath *b*.

    This is a purely lexical check built on ``Path.relative_to``; neither
    path is resolved here, so callers that care about symlinks must resolve
    both paths first."""
    try:
        a.relative_to(b)
    except ValueError:
        # relative_to() signals "not a sub-path" by raising ValueError.
        return False
    return True
70+
71+
def get_files(
    root: Path, extensions: Container[str], must_be_relative_to: Optional[Path] = None
) -> Iterator[Path]:
    """Recursively iterate over files underneath the given root, yielding
    only filenames with the given extensions. Symlinks are followed, but
    any given concrete directory is only scanned once.

    Every directory and file is checked against a "jail" after symlink
    resolution: anything whose resolved path falls outside the jail is
    skipped. By default the jail is the resolved root, so symlinks leading
    outside the root are ignored; pass must_be_relative_to to use a different
    prefix. The jail path is itself resolved before any comparison.

    Yielded paths are built from the walked (unresolved) directory names, so
    they stay underneath root even when reached through a symlink."""
    root_resolved = root.resolve()

    # Seed with the root itself so that a symlink pointing back at the root
    # does not cause the root to be scanned (and its files yielded) twice.
    seen: Set[Path] = {root_resolved}

    # All comparisons below are made against resolved paths, so the jail must
    # be resolved as well or a relative/symlinked jail would match nothing.
    if must_be_relative_to is None:
        jail = root_resolved
    else:
        jail = must_be_relative_to.resolve()

    for base, dirs, files in os.walk(root, followlinks=True):
        base_resolved = Path(base).resolve()
        if not is_relative_to(base_resolved, jail):
            # Prevent a race between our checking if a symlink is valid, and our
            # actually entering it.
            continue

        # Preserve both the actual resolved path and the directory name
        resolved_dirs = {base_resolved.joinpath(d).resolve(): d for d in dirs}

        # Only recurse into directories which have not been scanned yet and
        # which are within our prefix. Assigning to dirs[:] prunes os.walk
        # in place.
        dirs[:] = [
            d_name
            for d_path, d_name in resolved_dirs.items()
            if d_path not in seen and is_relative_to(d_path, jail)
        ]

        seen.update(resolved_dirs)

        for name in files:
            if os.path.splitext(name)[1] not in extensions:
                continue

            path = Path(os.path.join(base, name))
            # Detect and ignore symlinks outside of our jail
            if is_relative_to(path.resolve(), jail):
                yield path
81117
82118
83119def get_line (node : docutils .nodes .Node ) -> int :
0 commit comments