22
33import glob
44import json
5+ import os
56import subprocess
67import textwrap
78from argparse import ArgumentParser
89from collections import deque
910from itertools import chain
10- from pathlib import Path
11- from typing import Deque , Dict , List , Set , Tuple , TypeAlias , TypedDict
11+ from pathlib import Path , PurePath
12+ from typing import (
13+ Deque ,
14+ Dict ,
15+ List ,
16+ Set ,
17+ Tuple ,
18+ TypeAlias ,
19+ TypedDict ,
20+ Iterable ,
21+ )
1222import logging
1323
1424Glob : TypeAlias = str
@@ -49,15 +59,28 @@ class Pattern(TypedDict):
4959parser .add_argument ("-v" , "--verbose" , action = "count" , default = 0 )
5060
5161
52- def symlink_parents (p : Path ) -> List [Path ]:
62+ def symlink_paths_closure (p : Path ) -> List [Path ]:
63+ """Traverses a chain of symlinks to collect every intermediate path up to the final destination."""
64+
5365 out = []
54- while p .is_symlink () and p not in out :
66+ while p .is_symlink ():
5567 parent = p .readlink ()
56- if parent .is_relative_to ("." ):
57- p = p / parent
58- else :
68+ if parent .is_absolute ():
5969 p = parent
70+ else :
71+ # we need to resolve paths before concatenation because of things like
72+ # $ ls -l /sys/dev/char/226:128/subsystem
73+ # ... /sys/dev/char/226:128/subsystem
74+ # -> ../../../../../../class/drm
75+ # see also test_path_dsicovery_resolve_rel_links
76+ #
77+ # Path(normpath(...)) needed to normalize `foo/../bar` to `bar`
78+ p = Path (os .path .normpath (p .parent .resolve () / parent ))
79+
80+ if p in out :
81+ break
6082 out .append (p )
83+
6184 return out
6285
6386
@@ -70,20 +93,26 @@ def get_required_system_features(parsed_drv: dict) -> List[str]:
7093 # Older versions of Nix store structuredAttrs in the env as a JSON string.
7194 drv_env = parsed_drv .get ("env" , {})
7295 if "__json" in drv_env :
73- return list (json .loads (drv_env ["__json" ]).get ("requiredSystemFeatures" , []))
96+ return list (
97+ json .loads (drv_env ["__json" ]).get ("requiredSystemFeatures" , [])
98+ )
7499
75100 # Without structuredAttrs, requiredSystemFeatures is a space-separated string in env.
76101 return drv_env .get ("requiredSystemFeatures" , "" ).split ()
77102
78103
79- def validate_mounts (pattern : Pattern ) -> List [Tuple [PathString , PathString , bool ]]:
80- roots = []
104+ def validate_mounts (
105+ pattern : Pattern ,
106+ ) -> List [Tuple [PathString , PathString , bool ]]:
107+ roots : List [Tuple [PathString , PathString , bool ]] = []
81108 for mount in pattern ["paths" ]:
82109 if isinstance (mount , PathString ):
83110 matches = glob .glob (mount )
84111 assert matches , f"Specified host paths do not exist: { mount } "
85112
86- roots .extend ((m , m , pattern ["unsafeFollowSymlinks" ]) for m in matches )
113+ roots .extend (
114+ (m , m , pattern ["unsafeFollowSymlinks" ]) for m in matches
115+ )
87116 else :
88117 assert isinstance (mount , dict ) and "host" in mount , mount
89118 assert Path (
@@ -100,6 +129,73 @@ def validate_mounts(pattern: Pattern) -> List[Tuple[PathString, PathString, bool
100129 return roots
101130
102131
132+ def enumerate_patterns (
133+ allowed_patterns : AllowedPatterns , required_features : List [str ]
134+ ) -> Iterable [Tuple [PathString , PathString , bool ]]:
135+ patterns : List [Pattern ] = [
136+ pattern
137+ for pattern in allowed_patterns .values ()
138+ if any (
139+ feature in required_features for feature in pattern ["onFeatures" ]
140+ )
141+ ] # noqa: E501
142+
143+ return (mnt for pattern in patterns for mnt in validate_mounts (pattern ))
144+
145+
146+ def discover_reachable_paths (
147+ inputs : Iterable [Tuple [PathString , PathString , bool ]],
148+ ) -> List [Tuple [PathString , PathString ]]:
149+ queue : Deque [Tuple [PathString , PathString , bool ]] = deque (inputs )
150+ unique_mounts : Set [Tuple [PathString , PathString ]] = set ()
151+ mounts : List [Tuple [PathString , PathString ]] = []
152+
153+ while queue :
154+ guest_path_str , host_path_str , follow_symlinks = queue .popleft ()
155+ if (guest_path_str , host_path_str ) not in unique_mounts :
156+ mounts .append ((guest_path_str , host_path_str ))
157+ unique_mounts .add ((guest_path_str , host_path_str ))
158+
159+ if not follow_symlinks :
160+ continue
161+
162+ host_path = Path (host_path_str )
163+ if not (host_path .is_dir () or host_path .is_symlink ()):
164+ continue
165+
166+ paths = [host_path ] + [
167+ child for child in host_path .iterdir () if host_path .is_dir ()
168+ ]
169+
170+ for child in paths :
171+ for parent in symlink_paths_closure (child ):
172+ parent_str = parent .absolute ().as_posix ()
173+ if all (
174+ not parent .absolute ().is_relative_to (existing_path )
175+ for existing_path , _ in unique_mounts
176+ ):
177+ queue .append ((parent_str , parent_str , follow_symlinks ))
178+ return mounts
179+
180+
181+ def prune_paths (
182+ inputs : List [Tuple [PathString , PathString ]],
183+ ) -> List [Tuple [PathString , PathString ]]:
184+ if len (inputs ) < 2 :
185+ return inputs
186+
187+ sorted_inputs = sorted (inputs )
188+ pruned = [sorted_inputs [0 ]]
189+
190+ last_kept = Path (pruned [0 ][0 ])
191+ for current in sorted_inputs [1 :]:
192+ if not Path (current [0 ]).is_relative_to (last_kept ):
193+ pruned .append (current )
194+ last_kept = Path (current [0 ])
195+
196+ return pruned
197+
198+
103199def entrypoint ():
104200 args = parser .parse_args ()
105201
@@ -130,13 +226,19 @@ def entrypoint():
130226 )
131227 try :
132228 parsed_drv = json .loads (proc .stdout )
229+
230+ # compabitility: https://github.com/NixOS/nix/pull/14770
231+ if "derivations" in parsed_drv :
232+ parsed_drv = parsed_drv ["derivations" ]
133233 except json .JSONDecodeError :
134234 logging .error (
135235 "Couldn't parse the output of"
136236 "`nix show-derivation`"
137237 f". Expected JSON, observed: { proc .stdout } " ,
138238 )
139- logging .error (textwrap .indent (proc .stdout .decode ("utf8" ), prefix = " " * 4 ))
239+ logging .error (
240+ textwrap .indent (proc .stdout .decode ("utf8" ), prefix = " " * 4 )
241+ )
140242 logging .info ("Exiting the nix-required-binds hook" )
141243 return
142244 [canon_drv_path ] = parsed_drv .keys ()
@@ -149,41 +251,15 @@ def entrypoint():
149251
150252 parsed_drv = parsed_drv [canon_drv_path ]
151253 required_features = get_required_system_features (parsed_drv )
152- required_features = list (filter (known_features .__contains__ , required_features ))
153-
154- patterns : List [Pattern ] = list (
155- pattern
156- for pattern in allowed_patterns .values ()
157- for path in pattern ["paths" ]
158- if any (feature in required_features for feature in pattern ["onFeatures" ])
159- ) # noqa: E501
160-
161- queue : Deque [Tuple [PathString , PathString , bool ]] = deque (
162- (mnt for pattern in patterns for mnt in validate_mounts (pattern ))
254+ required_features = list (
255+ filter (known_features .__contains__ , required_features )
163256 )
164257
165- unique_mounts : Set [Tuple [PathString , PathString ]] = set ()
166- mounts : List [Tuple [PathString , PathString ]] = []
167-
168- while queue :
169- guest_path_str , host_path_str , follow_symlinks = queue .popleft ()
170- if (guest_path_str , host_path_str ) not in unique_mounts :
171- mounts .append ((guest_path_str , host_path_str ))
172- unique_mounts .add ((guest_path_str , host_path_str ))
173-
174- if not follow_symlinks :
175- continue
176-
177- host_path = Path (host_path_str )
178- if not (host_path .is_dir () or host_path .is_symlink ()):
179- continue
180-
181- # assert host_path_str == guest_path_str, (host_path_str, guest_path_str)
182-
183- for child in host_path .iterdir () if host_path .is_dir () else [host_path ]:
184- for parent in symlink_parents (child ):
185- parent_str = parent .absolute ().as_posix ()
186- queue .append ((parent_str , parent_str , follow_symlinks ))
258+ mounts = prune_paths (
259+ discover_reachable_paths (
260+ enumerate_patterns (allowed_patterns , required_features )
261+ )
262+ )
187263
188264 # the pre-build-hook command
189265 if args .issue_command == "always" or (
0 commit comments