Skip to content

Commit 4acdde7

Browse files
committed
nix-required-monts: Refactor and test code, fix broken paths
1 parent cd17122 commit 4acdde7

File tree

3 files changed

+251
-43
lines changed

3 files changed

+251
-43
lines changed

pkgs/by-name/ni/nix-required-mounts/nix_required_mounts.py

Lines changed: 113 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,23 @@
22

33
import glob
44
import json
5+
import os
56
import subprocess
67
import textwrap
78
from argparse import ArgumentParser
89
from collections import deque
910
from itertools import chain
10-
from pathlib import Path
11-
from typing import Deque, Dict, List, Set, Tuple, TypeAlias, TypedDict
11+
from pathlib import Path, PurePath
12+
from typing import (
13+
Deque,
14+
Dict,
15+
List,
16+
Set,
17+
Tuple,
18+
TypeAlias,
19+
TypedDict,
20+
Iterable,
21+
)
1222
import logging
1323

1424
Glob: TypeAlias = str
@@ -51,13 +61,22 @@ class Pattern(TypedDict):
5161

5262
def symlink_parents(p: Path) -> List[Path]:
5363
out = []
54-
while p.is_symlink() and p not in out:
64+
while p.is_symlink():
5565
parent = p.readlink()
56-
if parent.is_relative_to("."):
57-
p = p / parent
58-
else:
66+
if parent.is_absolute():
5967
p = parent
68+
else:
69+
# we need to resolve paths before concatenation because of things like
70+
# $ ls -l /sys/dev/char/226:128/subsystem
71+
# ... /sys/dev/char/226:128/subsystem
72+
# -> ../../../../../../class/drm
73+
# Path(normpath(...)) to normalize `foo/../bar` to `bar`
74+
p = Path(os.path.normpath(p.parent.resolve() / parent))
75+
76+
if p in out:
77+
break
6078
out.append(p)
79+
6180
return out
6281

6382

@@ -70,20 +89,26 @@ def get_required_system_features(parsed_drv: dict) -> List[str]:
7089
# Older versions of Nix store structuredAttrs in the env as a JSON string.
7190
drv_env = parsed_drv.get("env", {})
7291
if "__json" in drv_env:
73-
return list(json.loads(drv_env["__json"]).get("requiredSystemFeatures", []))
92+
return list(
93+
json.loads(drv_env["__json"]).get("requiredSystemFeatures", [])
94+
)
7495

7596
# Without structuredAttrs, requiredSystemFeatures is a space-separated string in env.
7697
return drv_env.get("requiredSystemFeatures", "").split()
7798

7899

79-
def validate_mounts(pattern: Pattern) -> List[Tuple[PathString, PathString, bool]]:
100+
def validate_mounts(
101+
pattern: Pattern,
102+
) -> List[Tuple[PathString, PathString, bool]]:
80103
roots = []
81104
for mount in pattern["paths"]:
82105
if isinstance(mount, PathString):
83106
matches = glob.glob(mount)
84107
assert matches, f"Specified host paths do not exist: {mount}"
85108

86-
roots.extend((m, m, pattern["unsafeFollowSymlinks"]) for m in matches)
109+
roots.extend(
110+
(m, m, pattern["unsafeFollowSymlinks"]) for m in matches
111+
)
87112
else:
88113
assert isinstance(mount, dict) and "host" in mount, mount
89114
assert Path(
@@ -100,6 +125,73 @@ def validate_mounts(pattern: Pattern) -> List[Tuple[PathString, PathString, bool
100125
return roots
101126

102127

128+
def enumerate_patterns(
129+
allowed_patterns: any, required_features: List[str]
130+
) -> Iterable[Tuple[PathString, PathString, bool]]:
131+
patterns: List[Pattern] = list(
132+
pattern
133+
for pattern in allowed_patterns.values()
134+
if any(
135+
feature in required_features for feature in pattern["onFeatures"]
136+
)
137+
) # noqa: E501
138+
139+
return (mnt for pattern in patterns for mnt in validate_mounts(pattern))
140+
141+
142+
def discover_reachable_paths(
143+
inputs: Iterable[Tuple[PathString, PathString, bool]],
144+
) -> List[Tuple[PathString, PathString]]:
145+
queue: Deque[Tuple[PathString, PathString, bool]] = deque(inputs)
146+
unique_mounts: Set[Tuple[PathString, PathString]] = set()
147+
mounts: List[Tuple[PathString, PathString]] = []
148+
149+
while queue:
150+
guest_path_str, host_path_str, follow_symlinks = queue.popleft()
151+
if (guest_path_str, host_path_str) not in unique_mounts:
152+
mounts.append((guest_path_str, host_path_str))
153+
unique_mounts.add((guest_path_str, host_path_str))
154+
155+
if not follow_symlinks:
156+
continue
157+
158+
host_path = Path(host_path_str)
159+
if not (host_path.is_dir() or host_path.is_symlink()):
160+
continue
161+
162+
# assert host_path_str == guest_path_str, (host_path_str, guest_path_str)
163+
164+
paths = [host_path] + list(
165+
child for child in host_path.iterdir() if host_path.is_dir()
166+
)
167+
168+
for child in paths:
169+
for parent in symlink_parents(child):
170+
parent_str = parent.absolute().as_posix()
171+
if all(
172+
not parent.absolute().is_relative_to(existing_path)
173+
for existing_path, _ in unique_mounts
174+
):
175+
queue.append((parent_str, parent_str, follow_symlinks))
176+
return mounts
177+
178+
179+
def prune_paths(
180+
inputs: List[Tuple[PathString, PathString, bool]],
181+
) -> List[Tuple[PathString, PathString, bool]]:
182+
if len(inputs) == 0:
183+
return []
184+
185+
sorted_inputs = sorted(inputs)
186+
pruned = [sorted_inputs[0]]
187+
188+
for current in sorted_inputs[1:]:
189+
last_kept = Path(pruned[-1][0])
190+
if not Path(current[0]).is_relative_to(last_kept):
191+
pruned.append(current)
192+
return pruned
193+
194+
103195
def entrypoint():
104196
args = parser.parse_args()
105197

@@ -130,13 +222,17 @@ def entrypoint():
130222
)
131223
try:
132224
parsed_drv = json.loads(proc.stdout)
225+
if "derivations" in parsed_drv:
226+
parsed_drv = parsed_drv["derivations"]
133227
except json.JSONDecodeError:
134228
logging.error(
135229
"Couldn't parse the output of"
136230
"`nix show-derivation`"
137231
f". Expected JSON, observed: {proc.stdout}",
138232
)
139-
logging.error(textwrap.indent(proc.stdout.decode("utf8"), prefix=" " * 4))
233+
logging.error(
234+
textwrap.indent(proc.stdout.decode("utf8"), prefix=" " * 4)
235+
)
140236
logging.info("Exiting the nix-required-binds hook")
141237
return
142238
[canon_drv_path] = parsed_drv.keys()
@@ -149,41 +245,15 @@ def entrypoint():
149245

150246
parsed_drv = parsed_drv[canon_drv_path]
151247
required_features = get_required_system_features(parsed_drv)
152-
required_features = list(filter(known_features.__contains__, required_features))
153-
154-
patterns: List[Pattern] = list(
155-
pattern
156-
for pattern in allowed_patterns.values()
157-
for path in pattern["paths"]
158-
if any(feature in required_features for feature in pattern["onFeatures"])
159-
) # noqa: E501
160-
161-
queue: Deque[Tuple[PathString, PathString, bool]] = deque(
162-
(mnt for pattern in patterns for mnt in validate_mounts(pattern))
248+
required_features = list(
249+
filter(known_features.__contains__, required_features)
163250
)
164251

165-
unique_mounts: Set[Tuple[PathString, PathString]] = set()
166-
mounts: List[Tuple[PathString, PathString]] = []
167-
168-
while queue:
169-
guest_path_str, host_path_str, follow_symlinks = queue.popleft()
170-
if (guest_path_str, host_path_str) not in unique_mounts:
171-
mounts.append((guest_path_str, host_path_str))
172-
unique_mounts.add((guest_path_str, host_path_str))
173-
174-
if not follow_symlinks:
175-
continue
176-
177-
host_path = Path(host_path_str)
178-
if not (host_path.is_dir() or host_path.is_symlink()):
179-
continue
180-
181-
# assert host_path_str == guest_path_str, (host_path_str, guest_path_str)
182-
183-
for child in host_path.iterdir() if host_path.is_dir() else [host_path]:
184-
for parent in symlink_parents(child):
185-
parent_str = parent.absolute().as_posix()
186-
queue.append((parent_str, parent_str, follow_symlinks))
252+
mounts = prune_paths(
253+
discover_reachable_paths(
254+
enumerate_patterns(allowed_patterns, required_features)
255+
)
256+
)
187257

188258
# the pre-build-hook command
189259
if args.issue_command == "always" or (

pkgs/by-name/ni/nix-required-mounts/package.nix

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ python3Packages.buildPythonApplication {
4444
python3Packages.setuptools
4545
];
4646

47+
checkPhase = ''
48+
python3 test.py
49+
'';
50+
4751
postFixup = ''
4852
wrapProgram $out/bin/${pname} \
4953
--add-flags "--patterns ${allowedPatternsPath}" \
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import unittest
2+
import tempfile
3+
import shutil
4+
from pathlib import Path
5+
from nix_required_mounts import (
6+
symlink_parents,
7+
enumerate_patterns,
8+
prune_paths,
9+
discover_reachable_paths,
10+
)
11+
12+
13+
class TestStringMethods(unittest.TestCase):
14+
15+
def setUp(self):
16+
self.test_dir = tempfile.mkdtemp()
17+
self.root = Path(self.test_dir)
18+
19+
# 1. Multi-link setup
20+
self.ml_a = self.create_path("testfolder/multi-link/a")
21+
self.ml_b = self.create_symlink("testfolder/multi-link/b", "a")
22+
self.ml_c = self.create_symlink("testfolder/multi-link/c", "b")
23+
24+
# 2. Jump-out setup
25+
self.jo_target = self.create_path("testfolder/jump-out/c/d")
26+
self.jo_link = self.create_symlink("testfolder/jump-out/a/b", "../c/d")
27+
28+
# 3. Globbing setup
29+
self.glob_base = "testfolder/globbing"
30+
self.glob_a = self.create_path(f"{self.glob_base}/a")
31+
self.glob_c = self.create_path(f"{self.glob_base}/c")
32+
33+
def tearDown(self):
34+
shutil.rmtree(self.test_dir)
35+
36+
def create_path(self, path_str):
37+
"""Helper to create directories/files within the tmp folder."""
38+
path = self.root / path_str
39+
path.mkdir(parents=True, exist_ok=True)
40+
return path
41+
42+
def create_symlink(self, link_path_str, target_path_str):
43+
"""Helper to create symlinks within the tmp folder."""
44+
link_path = self.root / link_path_str
45+
target_path = self.root / target_path_str
46+
link_path.parent.mkdir(parents=True, exist_ok=True)
47+
link_path.symlink_to(target_path_str)
48+
return link_path
49+
50+
def test_symlink_parents_neighbors(self):
51+
self.assertEqual(symlink_parents(self.ml_a), [])
52+
self.assertEqual(symlink_parents(self.ml_b), [self.ml_a])
53+
self.assertEqual(symlink_parents(self.ml_c), [self.ml_b, self.ml_a])
54+
55+
def test_symlink_parents_jump_out(self):
56+
self.assertEqual(symlink_parents(self.jo_link), [self.jo_target])
57+
58+
def test_pattern_extraction(self):
59+
# Pull paths generated in setUp or resolve new ones
60+
a1 = str(self.ml_a)
61+
a2 = str(self.ml_b)
62+
b1 = str(self.root / "testfolder/jump-out/a")
63+
b2 = str(self.root / "testfolder/jump-out/c")
64+
65+
allowed_patterns = {
66+
"a": {
67+
"onFeatures": ["a", "a1"],
68+
"paths": [a1, a2],
69+
"unsafeFollowSymlinks": True,
70+
},
71+
"b": {
72+
"onFeatures": ["b", "b2"],
73+
"paths": [b1, b2],
74+
"unsafeFollowSymlinks": True,
75+
},
76+
}
77+
78+
self.assertEqual(list(enumerate_patterns(allowed_patterns, [])), [])
79+
self.assertEqual(
80+
list(enumerate_patterns(allowed_patterns, ["a"])),
81+
[(a1, a1, True), (a2, a2, True)],
82+
)
83+
self.assertEqual(
84+
list(enumerate_patterns(allowed_patterns, ["b"])),
85+
[(b1, b1, True), (b2, b2, True)],
86+
)
87+
88+
def test_pattern_globbing(self):
89+
full_base_path = str(self.root / self.glob_base)
90+
91+
allowed_patterns = {
92+
"a": {
93+
"onFeatures": ["a"],
94+
"paths": [f"{full_base_path}/*"],
95+
"unsafeFollowSymlinks": True,
96+
}
97+
}
98+
99+
results = list(enumerate_patterns(allowed_patterns, ["a"]))
100+
101+
# Note: glob order can vary, so we compare sets or sorted lists
102+
expected = [
103+
(str(self.glob_a), str(self.glob_a), True),
104+
(str(self.glob_c), str(self.glob_c), True),
105+
]
106+
self.assertEqual(sorted(results), sorted(expected))
107+
108+
def test_pruning(self):
109+
root = str(self.root / "testfolder/jump-out")
110+
a = str(self.root / "testfolder/jump-out/a")
111+
b = str(self.root / "testfolder/jump-out/b")
112+
113+
self.assertEqual(prune_paths([]), [])
114+
115+
self.assertEqual(
116+
prune_paths([(a, a, True), (b, b, True), (root, root, True)]),
117+
[(root, root, True)],
118+
)
119+
120+
def test_path_discovery(self):
121+
ma = str(self.root / "testfolder/multi-link/a")
122+
mb = str(self.root / "testfolder/multi-link/b")
123+
mc = str(self.root / "testfolder/multi-link/c")
124+
ja = str(self.root / "testfolder/jump-out/a")
125+
j_target = str(self.root / "testfolder/jump-out/c/d")
126+
self.maxDiff = None
127+
self.assertEqual(
128+
sorted(discover_reachable_paths([(mc, mc, True), (ja, ja, True)])),
129+
[(ja, ja), (j_target, j_target), (ma, ma), (mb, mb), (mc, mc)],
130+
)
131+
132+
133+
if __name__ == "__main__":
134+
unittest.main()

0 commit comments

Comments
 (0)