22
33import glob
44import json
5+ import os
56import subprocess
67import textwrap
78from argparse import ArgumentParser
89from collections import deque
910from itertools import chain
10- from pathlib import Path
11- from typing import Deque , Dict , List , Set , Tuple , TypeAlias , TypedDict
11+ from pathlib import Path , PurePath
12+ from typing import (
13+ Deque ,
14+ Dict ,
15+ List ,
16+ Set ,
17+ Tuple ,
18+ TypeAlias ,
19+ TypedDict ,
20+ Iterable ,
21+ )
1222import logging
1323
1424Glob : TypeAlias = str
@@ -51,13 +61,22 @@ class Pattern(TypedDict):
5161
5262def symlink_parents (p : Path ) -> List [Path ]:
5363 out = []
54- while p .is_symlink () and p not in out :
64+ while p .is_symlink ():
5565 parent = p .readlink ()
56- if parent .is_relative_to ("." ):
57- p = p / parent
58- else :
66+ if parent .is_absolute ():
5967 p = parent
68+ else :
69+ # we need to resolve paths before concatenation because of things like
70+ # $ ls -l /sys/dev/char/226:128/subsystem
71+ # ... /sys/dev/char/226:128/subsystem
72+ # -> ../../../../../../class/drm
73+ # Path(normpath(...)) to normalize `foo/../bar` to `bar`
74+ p = Path (os .path .normpath (p .parent .resolve () / parent ))
75+
76+ if p in out :
77+ break
6078 out .append (p )
79+
6180 return out
6281
6382
@@ -70,20 +89,26 @@ def get_required_system_features(parsed_drv: dict) -> List[str]:
7089 # Older versions of Nix store structuredAttrs in the env as a JSON string.
7190 drv_env = parsed_drv .get ("env" , {})
7291 if "__json" in drv_env :
73- return list (json .loads (drv_env ["__json" ]).get ("requiredSystemFeatures" , []))
92+ return list (
93+ json .loads (drv_env ["__json" ]).get ("requiredSystemFeatures" , [])
94+ )
7495
7596 # Without structuredAttrs, requiredSystemFeatures is a space-separated string in env.
7697 return drv_env .get ("requiredSystemFeatures" , "" ).split ()
7798
7899
79- def validate_mounts (pattern : Pattern ) -> List [Tuple [PathString , PathString , bool ]]:
100+ def validate_mounts (
101+ pattern : Pattern ,
102+ ) -> List [Tuple [PathString , PathString , bool ]]:
80103 roots = []
81104 for mount in pattern ["paths" ]:
82105 if isinstance (mount , PathString ):
83106 matches = glob .glob (mount )
84107 assert matches , f"Specified host paths do not exist: { mount } "
85108
86- roots .extend ((m , m , pattern ["unsafeFollowSymlinks" ]) for m in matches )
109+ roots .extend (
110+ (m , m , pattern ["unsafeFollowSymlinks" ]) for m in matches
111+ )
87112 else :
88113 assert isinstance (mount , dict ) and "host" in mount , mount
89114 assert Path (
@@ -100,6 +125,73 @@ def validate_mounts(pattern: Pattern) -> List[Tuple[PathString, PathString, bool
100125 return roots
101126
102127
128+ def enumerate_patterns (
129+ allowed_patterns : any , required_features : List [str ]
130+ ) -> Iterable [Tuple [PathString , PathString , bool ]]:
131+ patterns : List [Pattern ] = list (
132+ pattern
133+ for pattern in allowed_patterns .values ()
134+ if any (
135+ feature in required_features for feature in pattern ["onFeatures" ]
136+ )
137+ ) # noqa: E501
138+
139+ return (mnt for pattern in patterns for mnt in validate_mounts (pattern ))
140+
141+
142+ def discover_reachable_paths (
143+ inputs : Iterable [Tuple [PathString , PathString , bool ]],
144+ ) -> List [Tuple [PathString , PathString ]]:
145+ queue : Deque [Tuple [PathString , PathString , bool ]] = deque (inputs )
146+ unique_mounts : Set [Tuple [PathString , PathString ]] = set ()
147+ mounts : List [Tuple [PathString , PathString ]] = []
148+
149+ while queue :
150+ guest_path_str , host_path_str , follow_symlinks = queue .popleft ()
151+ if (guest_path_str , host_path_str ) not in unique_mounts :
152+ mounts .append ((guest_path_str , host_path_str ))
153+ unique_mounts .add ((guest_path_str , host_path_str ))
154+
155+ if not follow_symlinks :
156+ continue
157+
158+ host_path = Path (host_path_str )
159+ if not (host_path .is_dir () or host_path .is_symlink ()):
160+ continue
161+
162+ # assert host_path_str == guest_path_str, (host_path_str, guest_path_str)
163+
164+ paths = [host_path ] + list (
165+ child for child in host_path .iterdir () if host_path .is_dir ()
166+ )
167+
168+ for child in paths :
169+ for parent in symlink_parents (child ):
170+ parent_str = parent .absolute ().as_posix ()
171+ if all (
172+ not parent .absolute ().is_relative_to (existing_path )
173+ for existing_path , _ in unique_mounts
174+ ):
175+ queue .append ((parent_str , parent_str , follow_symlinks ))
176+ return mounts
177+
178+
179+ def prune_paths (
180+ inputs : List [Tuple [PathString , PathString , bool ]],
181+ ) -> List [Tuple [PathString , PathString , bool ]]:
182+ if len (inputs ) == 0 :
183+ return []
184+
185+ sorted_inputs = sorted (inputs )
186+ pruned = [sorted_inputs [0 ]]
187+
188+ for current in sorted_inputs [1 :]:
189+ last_kept = Path (pruned [- 1 ][0 ])
190+ if not Path (current [0 ]).is_relative_to (last_kept ):
191+ pruned .append (current )
192+ return pruned
193+
194+
103195def entrypoint ():
104196 args = parser .parse_args ()
105197
@@ -130,13 +222,17 @@ def entrypoint():
130222 )
131223 try :
132224 parsed_drv = json .loads (proc .stdout )
225+ if "derivations" in parsed_drv :
226+ parsed_drv = parsed_drv ["derivations" ]
133227 except json .JSONDecodeError :
134228 logging .error (
135229 "Couldn't parse the output of"
136230 "`nix show-derivation`"
137231 f". Expected JSON, observed: { proc .stdout } " ,
138232 )
139- logging .error (textwrap .indent (proc .stdout .decode ("utf8" ), prefix = " " * 4 ))
233+ logging .error (
234+ textwrap .indent (proc .stdout .decode ("utf8" ), prefix = " " * 4 )
235+ )
140236 logging .info ("Exiting the nix-required-binds hook" )
141237 return
142238 [canon_drv_path ] = parsed_drv .keys ()
@@ -149,41 +245,15 @@ def entrypoint():
149245
150246 parsed_drv = parsed_drv [canon_drv_path ]
151247 required_features = get_required_system_features (parsed_drv )
152- required_features = list (filter (known_features .__contains__ , required_features ))
153-
154- patterns : List [Pattern ] = list (
155- pattern
156- for pattern in allowed_patterns .values ()
157- for path in pattern ["paths" ]
158- if any (feature in required_features for feature in pattern ["onFeatures" ])
159- ) # noqa: E501
160-
161- queue : Deque [Tuple [PathString , PathString , bool ]] = deque (
162- (mnt for pattern in patterns for mnt in validate_mounts (pattern ))
248+ required_features = list (
249+ filter (known_features .__contains__ , required_features )
163250 )
164251
165- unique_mounts : Set [Tuple [PathString , PathString ]] = set ()
166- mounts : List [Tuple [PathString , PathString ]] = []
167-
168- while queue :
169- guest_path_str , host_path_str , follow_symlinks = queue .popleft ()
170- if (guest_path_str , host_path_str ) not in unique_mounts :
171- mounts .append ((guest_path_str , host_path_str ))
172- unique_mounts .add ((guest_path_str , host_path_str ))
173-
174- if not follow_symlinks :
175- continue
176-
177- host_path = Path (host_path_str )
178- if not (host_path .is_dir () or host_path .is_symlink ()):
179- continue
180-
181- # assert host_path_str == guest_path_str, (host_path_str, guest_path_str)
182-
183- for child in host_path .iterdir () if host_path .is_dir () else [host_path ]:
184- for parent in symlink_parents (child ):
185- parent_str = parent .absolute ().as_posix ()
186- queue .append ((parent_str , parent_str , follow_symlinks ))
252+ mounts = prune_paths (
253+ discover_reachable_paths (
254+ enumerate_patterns (allowed_patterns , required_features )
255+ )
256+ )
187257
188258 # the pre-build-hook command
189259 if args .issue_command == "always" or (
0 commit comments