Skip to content

Commit f2797ce

Browse files
committed
nix-required-monts: Refactor and test code, fix broken paths
1 parent cd17122 commit f2797ce

File tree

4 files changed

+290
-47
lines changed

4 files changed

+290
-47
lines changed

pkgs/by-name/ni/nix-required-mounts/pyproject.toml renamed to pkgs/by-name/ni/nix-required-mounts/nix-required-mounts/pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ nix-required-mounts = "nix_required_mounts:entrypoint"
1818

1919
[tool.black]
2020
line-length = 79
21+
22+
[tool.pytest.ini_options]
23+
addopts = ["--doctest-modules"]

pkgs/by-name/ni/nix-required-mounts/nix_required_mounts.py renamed to pkgs/by-name/ni/nix-required-mounts/nix-required-mounts/src/nix_required_mounts.py

Lines changed: 121 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,23 @@
22

33
import glob
44
import json
5+
import os
56
import subprocess
67
import textwrap
78
from argparse import ArgumentParser
89
from collections import deque
910
from itertools import chain
10-
from pathlib import Path
11-
from typing import Deque, Dict, List, Set, Tuple, TypeAlias, TypedDict
11+
from pathlib import Path, PurePath
12+
from typing import (
13+
Deque,
14+
Dict,
15+
List,
16+
Set,
17+
Tuple,
18+
TypeAlias,
19+
TypedDict,
20+
Iterable,
21+
)
1222
import logging
1323

1424
Glob: TypeAlias = str
@@ -49,15 +59,28 @@ class Pattern(TypedDict):
4959
parser.add_argument("-v", "--verbose", action="count", default=0)
5060

5161

52-
def symlink_parents(p: Path) -> List[Path]:
62+
def symlink_paths_closure(p: Path) -> List[Path]:
63+
"""Traverses a chain of symlinks to collect every intermediate path up to the final destination."""
64+
5365
out = []
54-
while p.is_symlink() and p not in out:
66+
while p.is_symlink():
5567
parent = p.readlink()
56-
if parent.is_relative_to("."):
57-
p = p / parent
58-
else:
68+
if parent.is_absolute():
5969
p = parent
70+
else:
71+
# we need to resolve paths before concatenation because of things like
72+
# $ ls -l /sys/dev/char/226:128/subsystem
73+
# ... /sys/dev/char/226:128/subsystem
74+
# -> ../../../../../../class/drm
75+
# see also test_path_dsicovery_resolve_rel_links
76+
#
77+
# Path(normpath(...)) needed to normalize `foo/../bar` to `bar`
78+
p = Path(os.path.normpath(p.parent.resolve() / parent))
79+
80+
if p in out:
81+
break
6082
out.append(p)
83+
6184
return out
6285

6386

@@ -70,20 +93,26 @@ def get_required_system_features(parsed_drv: dict) -> List[str]:
7093
# Older versions of Nix store structuredAttrs in the env as a JSON string.
7194
drv_env = parsed_drv.get("env", {})
7295
if "__json" in drv_env:
73-
return list(json.loads(drv_env["__json"]).get("requiredSystemFeatures", []))
96+
return list(
97+
json.loads(drv_env["__json"]).get("requiredSystemFeatures", [])
98+
)
7499

75100
# Without structuredAttrs, requiredSystemFeatures is a space-separated string in env.
76101
return drv_env.get("requiredSystemFeatures", "").split()
77102

78103

79-
def validate_mounts(pattern: Pattern) -> List[Tuple[PathString, PathString, bool]]:
80-
roots = []
104+
def validate_mounts(
105+
pattern: Pattern,
106+
) -> List[Tuple[PathString, PathString, bool]]:
107+
roots: List[Tuple[PathString, PathString, bool]] = []
81108
for mount in pattern["paths"]:
82109
if isinstance(mount, PathString):
83110
matches = glob.glob(mount)
84111
assert matches, f"Specified host paths do not exist: {mount}"
85112

86-
roots.extend((m, m, pattern["unsafeFollowSymlinks"]) for m in matches)
113+
roots.extend(
114+
(m, m, pattern["unsafeFollowSymlinks"]) for m in matches
115+
)
87116
else:
88117
assert isinstance(mount, dict) and "host" in mount, mount
89118
assert Path(
@@ -100,6 +129,73 @@ def validate_mounts(pattern: Pattern) -> List[Tuple[PathString, PathString, bool
100129
return roots
101130

102131

132+
def enumerate_patterns(
133+
allowed_patterns: AllowedPatterns, required_features: List[str]
134+
) -> Iterable[Tuple[PathString, PathString, bool]]:
135+
patterns: List[Pattern] = [
136+
pattern
137+
for pattern in allowed_patterns.values()
138+
if any(
139+
feature in required_features for feature in pattern["onFeatures"]
140+
)
141+
] # noqa: E501
142+
143+
return (mnt for pattern in patterns for mnt in validate_mounts(pattern))
144+
145+
146+
def discover_reachable_paths(
147+
inputs: Iterable[Tuple[PathString, PathString, bool]],
148+
) -> List[Tuple[PathString, PathString]]:
149+
queue: Deque[Tuple[PathString, PathString, bool]] = deque(inputs)
150+
unique_mounts: Set[Tuple[PathString, PathString]] = set()
151+
mounts: List[Tuple[PathString, PathString]] = []
152+
153+
while queue:
154+
guest_path_str, host_path_str, follow_symlinks = queue.popleft()
155+
if (guest_path_str, host_path_str) not in unique_mounts:
156+
mounts.append((guest_path_str, host_path_str))
157+
unique_mounts.add((guest_path_str, host_path_str))
158+
159+
if not follow_symlinks:
160+
continue
161+
162+
host_path = Path(host_path_str)
163+
if not (host_path.is_dir() or host_path.is_symlink()):
164+
continue
165+
166+
paths = [host_path] + [
167+
child for child in host_path.iterdir() if host_path.is_dir()
168+
]
169+
170+
for child in paths:
171+
for parent in symlink_paths_closure(child):
172+
parent_str = parent.absolute().as_posix()
173+
if all(
174+
not parent.absolute().is_relative_to(existing_path)
175+
for existing_path, _ in unique_mounts
176+
):
177+
queue.append((parent_str, parent_str, follow_symlinks))
178+
return mounts
179+
180+
181+
def prune_paths(
182+
inputs: List[Tuple[PathString, PathString]],
183+
) -> List[Tuple[PathString, PathString]]:
184+
if len(inputs) < 2:
185+
return inputs
186+
187+
sorted_inputs = sorted(inputs)
188+
pruned = [sorted_inputs[0]]
189+
190+
last_kept = Path(pruned[0][0])
191+
for current in sorted_inputs[1:]:
192+
if not Path(current[0]).is_relative_to(last_kept):
193+
pruned.append(current)
194+
last_kept = Path(current[0])
195+
196+
return pruned
197+
198+
103199
def entrypoint():
104200
args = parser.parse_args()
105201

@@ -130,13 +226,19 @@ def entrypoint():
130226
)
131227
try:
132228
parsed_drv = json.loads(proc.stdout)
229+
230+
# compabitility: https://github.com/NixOS/nix/pull/14770
231+
if "derivations" in parsed_drv:
232+
parsed_drv = parsed_drv["derivations"]
133233
except json.JSONDecodeError:
134234
logging.error(
135235
"Couldn't parse the output of"
136236
"`nix show-derivation`"
137237
f". Expected JSON, observed: {proc.stdout}",
138238
)
139-
logging.error(textwrap.indent(proc.stdout.decode("utf8"), prefix=" " * 4))
239+
logging.error(
240+
textwrap.indent(proc.stdout.decode("utf8"), prefix=" " * 4)
241+
)
140242
logging.info("Exiting the nix-required-binds hook")
141243
return
142244
[canon_drv_path] = parsed_drv.keys()
@@ -149,41 +251,15 @@ def entrypoint():
149251

150252
parsed_drv = parsed_drv[canon_drv_path]
151253
required_features = get_required_system_features(parsed_drv)
152-
required_features = list(filter(known_features.__contains__, required_features))
153-
154-
patterns: List[Pattern] = list(
155-
pattern
156-
for pattern in allowed_patterns.values()
157-
for path in pattern["paths"]
158-
if any(feature in required_features for feature in pattern["onFeatures"])
159-
) # noqa: E501
160-
161-
queue: Deque[Tuple[PathString, PathString, bool]] = deque(
162-
(mnt for pattern in patterns for mnt in validate_mounts(pattern))
254+
required_features = list(
255+
filter(known_features.__contains__, required_features)
163256
)
164257

165-
unique_mounts: Set[Tuple[PathString, PathString]] = set()
166-
mounts: List[Tuple[PathString, PathString]] = []
167-
168-
while queue:
169-
guest_path_str, host_path_str, follow_symlinks = queue.popleft()
170-
if (guest_path_str, host_path_str) not in unique_mounts:
171-
mounts.append((guest_path_str, host_path_str))
172-
unique_mounts.add((guest_path_str, host_path_str))
173-
174-
if not follow_symlinks:
175-
continue
176-
177-
host_path = Path(host_path_str)
178-
if not (host_path.is_dir() or host_path.is_symlink()):
179-
continue
180-
181-
# assert host_path_str == guest_path_str, (host_path_str, guest_path_str)
182-
183-
for child in host_path.iterdir() if host_path.is_dir() else [host_path]:
184-
for parent in symlink_parents(child):
185-
parent_str = parent.absolute().as_posix()
186-
queue.append((parent_str, parent_str, follow_symlinks))
258+
mounts = prune_paths(
259+
discover_reachable_paths(
260+
enumerate_patterns(allowed_patterns, required_features)
261+
)
262+
)
187263

188264
# the pre-build-hook command
189265
if args.issue_command == "always" or (
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import unittest
2+
import tempfile
3+
import shutil
4+
from pathlib import Path
5+
from nix_required_mounts import (
6+
symlink_paths_closure,
7+
enumerate_patterns,
8+
prune_paths,
9+
discover_reachable_paths,
10+
)
11+
12+
13+
class TestNixRequiredMountsMethods(unittest.TestCase):
14+
15+
def setUp(self):
16+
self.test_dir = tempfile.mkdtemp()
17+
self.root = Path(self.test_dir)
18+
19+
# 1. Multi-link setup
20+
self.ml_a = self.create_path("multi-link/a")
21+
self.ml_b = self.create_symlink("multi-link/b", "a")
22+
self.ml_c = self.create_symlink("multi-link/c", "b")
23+
24+
# 2. Jump-out setup
25+
self.jo_target = self.create_path("jump-out/c/d")
26+
self.jo_link = self.create_symlink("jump-out/a/b", "../c/d")
27+
28+
# 3. Globbing setup
29+
self.glob_base = "globbing"
30+
self.glob_a = self.create_path(f"{self.glob_base}/a")
31+
self.glob_c = self.create_path(f"{self.glob_base}/c")
32+
33+
# 4. far-up rel symlink
34+
self.create_path("far-up/a")
35+
self.create_path("far-up/c/d/e/f/g/h/i/j/k/l/m")
36+
self.far_up_target2 = self.create_path("far-up/o/p")
37+
self.far_up_root = self.create_symlink(
38+
"far-up/a/b", "../c/d/e/f/g/h/i/j/k/l/m"
39+
)
40+
self.far_up_target1 = self.create_symlink(
41+
"far-up/c/d/e/f/g/h/i/j/k/l/m/n",
42+
"../../../../../../../../../../../o/p",
43+
)
44+
45+
def tearDown(self):
46+
shutil.rmtree(self.test_dir)
47+
48+
def create_path(self, path_str):
49+
"""Helper to create directories within the tmp folder."""
50+
path = self.root / path_str
51+
path.mkdir(parents=True, exist_ok=True)
52+
return path
53+
54+
def create_symlink(self, link_path_str, target_path_str):
55+
"""Helper to create symlinks within the tmp folder."""
56+
link_path = self.root / link_path_str
57+
target_path = self.root / target_path_str
58+
link_path.parent.mkdir(parents=True, exist_ok=True)
59+
link_path.symlink_to(target_path_str)
60+
return link_path
61+
62+
def test_symlink_paths_closure_neighbors(self):
63+
self.assertEqual(symlink_paths_closure(self.ml_a), [])
64+
self.assertEqual(symlink_paths_closure(self.ml_b), [self.ml_a])
65+
self.assertEqual(
66+
symlink_paths_closure(self.ml_c), [self.ml_b, self.ml_a]
67+
)
68+
69+
def test_symlink_paths_closure_jump_out(self):
70+
self.assertEqual(symlink_paths_closure(self.jo_link), [self.jo_target])
71+
72+
def test_pattern_extraction(self):
73+
a1 = str(self.ml_a)
74+
a2 = str(self.ml_b)
75+
b1 = str(self.root / "jump-out/a")
76+
b2 = str(self.root / "jump-out/c")
77+
78+
allowed_patterns = {
79+
"a": {
80+
"onFeatures": ["a", "a1"],
81+
"paths": [a1, a2],
82+
"unsafeFollowSymlinks": True,
83+
},
84+
"b": {
85+
"onFeatures": ["b", "b2"],
86+
"paths": [b1, b2],
87+
"unsafeFollowSymlinks": True,
88+
},
89+
}
90+
91+
self.assertEqual(list(enumerate_patterns(allowed_patterns, [])), [])
92+
self.assertEqual(
93+
list(enumerate_patterns(allowed_patterns, ["a"])),
94+
[(a1, a1, True), (a2, a2, True)],
95+
)
96+
self.assertEqual(
97+
list(enumerate_patterns(allowed_patterns, ["b"])),
98+
[(b1, b1, True), (b2, b2, True)],
99+
)
100+
101+
def test_pattern_globbing(self):
102+
full_base_path = str(self.root / self.glob_base)
103+
104+
allowed_patterns = {
105+
"a": {
106+
"onFeatures": ["a"],
107+
"paths": [f"{full_base_path}/*"],
108+
"unsafeFollowSymlinks": True,
109+
}
110+
}
111+
112+
results = list(enumerate_patterns(allowed_patterns, ["a"]))
113+
114+
expected = [
115+
(str(self.glob_a), str(self.glob_a), True),
116+
(str(self.glob_c), str(self.glob_c), True),
117+
]
118+
self.assertEqual(sorted(results), sorted(expected))
119+
120+
def test_pruning(self):
121+
root = str(self.root / "jump-out")
122+
a = str(self.root / "jump-out/a")
123+
b = str(self.root / "jump-out/b")
124+
125+
self.assertEqual(prune_paths([]), [])
126+
127+
self.assertEqual(
128+
prune_paths([(a, a, True), (b, b, True), (root, root, True)]),
129+
[(root, root, True)],
130+
)
131+
132+
def test_path_discovery(self):
133+
ma = str(self.root / "multi-link/a")
134+
mb = str(self.root / "multi-link/b")
135+
mc = str(self.root / "multi-link/c")
136+
ja = str(self.root / "jump-out/a")
137+
j_target = str(self.root / "jump-out/c/d")
138+
self.maxDiff = None
139+
self.assertEqual(
140+
sorted(discover_reachable_paths([(mc, mc, True), (ja, ja, True)])),
141+
[(ja, ja), (j_target, j_target), (ma, ma), (mb, mb), (mc, mc)],
142+
)
143+
144+
def test_path_discovery_resolve_rel_links(self):
145+
self.maxDiff = None
146+
root = str(self.far_up_root)
147+
t1 = str(self.far_up_target1.parent)
148+
t2 = str(self.far_up_target2)
149+
150+
self.assertEqual(
151+
discover_reachable_paths([(root, root, True)]),
152+
[(root, root), (t1, t1), (t2, t2)],
153+
)
154+
155+
156+
if __name__ == "__main__":
157+
unittest.main()

0 commit comments

Comments
 (0)