Skip to content

Commit 6f71a8a

Browse files
committed
WIP: Re-work chaining parametrize arguments for ease of use
1 parent 702cb41 commit 6f71a8a

File tree

1 file changed

+80
-30
lines changed

1 file changed

+80
-30
lines changed

tests/test_env_var_chaining.py

Lines changed: 80 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,56 +1,106 @@
1+
from __future__ import annotations
2+
3+
import copy
14
import json
25
import re
36
import subprocess
7+
import sys
8+
from dataclasses import dataclass, field
49
from pathlib import Path
510

611
import pytest
712

813

9-
def generate_reference_tests():
10-
# For each test .json file, create a test for each supported shell type.
11-
reference_dir = Path(__file__).parent / "env_chaining"
12-
references = []
13-
ids = []
14-
for ref in reference_dir.glob("*.json"):
15-
plat = ref.name.split("-")[0]
16-
if plat == "windows":
17-
shells = ["bash_win", "bat", "ps1"]
18-
else:
19-
shells = ["bash_linux"]
20-
for shell in shells:
21-
references.append((shell, ref))
14+
@dataclass
15+
class FileRef:
16+
path: Path
17+
name: str = field(init=False)
18+
shell: str | None = field(default=None)
19+
platform: str = field(init=False)
20+
distros: list = field(init=False)
21+
alias: str = field(init=False)
22+
_parser = re.compile(
23+
r"(?P<platform>[^-]+)-(?P<distros>[^-]+)-(?P<alias>[^.]+).json"
24+
)
2225

23-
# Ensure consistent test sorting across all platforms
24-
references = sorted(references)
25-
for a, b in references:
26-
ids.append(f"{a},{b.name}")
27-
return references, ids
26+
def __post_init__(self) -> None:
27+
if not hasattr(self, "name"):
28+
self.name = self.path.stem
29+
match = self._parser.match(self.path.name)
30+
if not match:
31+
raise ValueError(f"Invalid filename: {self.path}")
32+
kwargs = match.groupdict()
33+
self.platform = kwargs["platform"]
34+
self.distros = kwargs["distros"].split(",")
35+
self.alias = kwargs["alias"]
2836

37+
@classmethod
38+
def from_glob(cls, path, glob_str="*.json"):
39+
ret = []
40+
for filename in path.glob(glob_str):
41+
ret.append(cls(filename))
42+
return ret
2943

30-
references, ids = generate_reference_tests()
44+
@classmethod
45+
def shell_matrix(cls, file_refs, shells: dict[str, list]) -> list:
46+
ret = []
47+
for file_ref in file_refs:
48+
plat_shells = shells.get(file_ref.platform, [])
49+
for shell in plat_shells:
50+
ref = copy.copy(file_ref)
51+
ref.shell = shell
52+
ret.append(ref)
53+
ret.sort(key=lambda i: repr(i))
54+
return ret
3155

56+
def __repr__(self) -> str:
57+
if self.shell:
58+
return f"{self.shell},{self.name}"
59+
return self.name
3260

33-
@pytest.mark.parametrize("shell,reference", references, ids=ids)
34-
def test_chaining(shell, reference, config_root, tmp_path, run_hab):
3561

36-
match = re.match(
37-
r"(?P<platform>[^-]+)-(?P<distros>[^-]+)-(?P<alias>[^.]+).json", reference.name
38-
)
39-
kwargs = match.groupdict()
62+
references = FileRef.from_glob(Path(__file__).parent / "env_chaining")
63+
shell_references = FileRef.shell_matrix(
64+
references, {"windows": ["bash_win", "bat", "ps1"], "linux": ["bash_linux"]}
65+
)
4066

67+
68+
@pytest.mark.parametrize("reference", shell_references, ids=lambda f: repr(f))
69+
def test_chaining(reference, config_root, tmp_path, run_hab):
4170
# Skip tests that will not run on the current platform
42-
run_hab.skip_wrong_platform(shell)
71+
run_hab.skip_wrong_platform(reference.shell)
4372

4473
runner = run_hab(config_root, tmp_path, stderr=subprocess.PIPE)
4574
sub_cmd = []
46-
for d in kwargs["distros"].split(","):
75+
for d in reference.distros:
4776
sub_cmd.extend(["-r", f"var-chain-{d}"])
48-
sub_cmd += ["launch", ",", kwargs["alias"]]
49-
proc = runner.run_in_shell(shell, sub_cmd)
77+
sub_cmd += ["launch", ",", reference.alias]
78+
proc = runner.run_in_shell(reference.shell, sub_cmd)
5079

5180
# Check that the env vars were set as expected
5281
assert proc.returncode == 0
5382

54-
check = json.load(reference.open())
83+
check = json.load(reference.path.open())
5584
result = json.loads(proc.stdout)
5685
assert result == check
86+
87+
88+
@pytest.mark.parametrize("reference", references, ids=lambda f: repr(f))
89+
def test_launch(habcached_resolver, reference):
90+
if sys.platform.startswith("linux"):
91+
if reference.platform != "linux":
92+
raise pytest.skip("test doesn't apply on linux")
93+
elif sys.platform == "win32":
94+
if reference.platform != "windows":
95+
raise pytest.skip("test doesn't apply on windows")
96+
97+
forced_requirements = []
98+
for d in reference.distros:
99+
forced_requirements.extend([f"var-chain-{d}"])
100+
cfg = habcached_resolver.resolve(",", forced_requirements=forced_requirements)
101+
102+
proc = cfg.launch(reference.alias, args=None, blocking=True, stderr=subprocess.PIPE)
103+
assert proc.returncode == 0
104+
check = json.load(reference.path.open())
105+
result = json.loads(proc.output_stdout)
106+
assert result == check

0 commit comments

Comments
 (0)