|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import copy |
1 | 4 | import json |
2 | 5 | import re |
3 | 6 | import subprocess |
| 7 | +import sys |
| 8 | +from dataclasses import dataclass, field |
4 | 9 | from pathlib import Path |
5 | 10 |
|
6 | 11 | import pytest |
7 | 12 |
|
8 | 13 |
|
9 | | -def generate_reference_tests(): |
10 | | - # For each test .json file, create a test for each supported shell type. |
11 | | - reference_dir = Path(__file__).parent / "env_chaining" |
12 | | - references = [] |
13 | | - ids = [] |
14 | | - for ref in reference_dir.glob("*.json"): |
15 | | - plat = ref.name.split("-")[0] |
16 | | - if plat == "windows": |
17 | | - shells = ["bash_win", "bat", "ps1"] |
18 | | - else: |
19 | | - shells = ["bash_linux"] |
20 | | - for shell in shells: |
21 | | - references.append((shell, ref)) |
| 14 | +@dataclass |
| 15 | +class FileRef: |
| 16 | + path: Path |
| 17 | + name: str = field(init=False) |
| 18 | + shell: str | None = field(default=None) |
| 19 | + platform: str = field(init=False) |
| 20 | + distros: list = field(init=False) |
| 21 | + alias: str = field(init=False) |
| 22 | + _parser = re.compile( |
| 23 | + r"(?P<platform>[^-]+)-(?P<distros>[^-]+)-(?P<alias>[^.]+).json" |
| 24 | + ) |
22 | 25 |
|
23 | | - # Ensure consistent test sorting across all platforms |
24 | | - references = sorted(references) |
25 | | - for a, b in references: |
26 | | - ids.append(f"{a},{b.name}") |
27 | | - return references, ids |
| 26 | + def __post_init__(self) -> None: |
| 27 | + if not hasattr(self, "name"): |
| 28 | + self.name = self.path.stem |
| 29 | + match = self._parser.match(self.path.name) |
| 30 | + if not match: |
| 31 | + raise ValueError(f"Invalid filename: {self.path}") |
| 32 | + kwargs = match.groupdict() |
| 33 | + self.platform = kwargs["platform"] |
| 34 | + self.distros = kwargs["distros"].split(",") |
| 35 | + self.alias = kwargs["alias"] |
28 | 36 |
|
| 37 | + @classmethod |
| 38 | + def from_glob(cls, path, glob_str="*.json"): |
| 39 | + ret = [] |
| 40 | + for filename in path.glob(glob_str): |
| 41 | + ret.append(cls(filename)) |
| 42 | + return ret |
29 | 43 |
|
30 | | -references, ids = generate_reference_tests() |
| 44 | + @classmethod |
| 45 | + def shell_matrix(cls, file_refs, shells: dict[str, list]) -> list: |
| 46 | + ret = [] |
| 47 | + for file_ref in file_refs: |
| 48 | + plat_shells = shells.get(file_ref.platform, []) |
| 49 | + for shell in plat_shells: |
| 50 | + ref = copy.copy(file_ref) |
| 51 | + ref.shell = shell |
| 52 | + ret.append(ref) |
| 53 | + ret.sort(key=lambda i: repr(i)) |
| 54 | + return ret |
31 | 55 |
|
| 56 | + def __repr__(self) -> str: |
| 57 | + if self.shell: |
| 58 | + return f"{self.shell},{self.name}" |
| 59 | + return self.name |
32 | 60 |
|
33 | | -@pytest.mark.parametrize("shell,reference", references, ids=ids) |
34 | | -def test_chaining(shell, reference, config_root, tmp_path, run_hab): |
35 | 61 |
|
36 | | - match = re.match( |
37 | | - r"(?P<platform>[^-]+)-(?P<distros>[^-]+)-(?P<alias>[^.]+).json", reference.name |
38 | | - ) |
39 | | - kwargs = match.groupdict() |
| 62 | +references = FileRef.from_glob(Path(__file__).parent / "env_chaining") |
| 63 | +shell_references = FileRef.shell_matrix( |
| 64 | + references, {"windows": ["bash_win", "bat", "ps1"], "linux": ["bash_linux"]} |
| 65 | +) |
40 | 66 |
|
| 67 | + |
| 68 | +@pytest.mark.parametrize("reference", shell_references, ids=lambda f: repr(f)) |
| 69 | +def test_chaining(reference, config_root, tmp_path, run_hab): |
41 | 70 | # Skip tests that will not run on the current platform |
42 | | - run_hab.skip_wrong_platform(shell) |
| 71 | + run_hab.skip_wrong_platform(reference.shell) |
43 | 72 |
|
44 | 73 | runner = run_hab(config_root, tmp_path, stderr=subprocess.PIPE) |
45 | 74 | sub_cmd = [] |
46 | | - for d in kwargs["distros"].split(","): |
| 75 | + for d in reference.distros: |
47 | 76 | sub_cmd.extend(["-r", f"var-chain-{d}"]) |
48 | | - sub_cmd += ["launch", ",", kwargs["alias"]] |
49 | | - proc = runner.run_in_shell(shell, sub_cmd) |
| 77 | + sub_cmd += ["launch", ",", reference.alias] |
| 78 | + proc = runner.run_in_shell(reference.shell, sub_cmd) |
50 | 79 |
|
51 | 80 | # Check that the env vars were set as expected |
52 | 81 | assert proc.returncode == 0 |
53 | 82 |
|
54 | | - check = json.load(reference.open()) |
| 83 | + check = json.load(reference.path.open()) |
55 | 84 | result = json.loads(proc.stdout) |
56 | 85 | assert result == check |
| 86 | + |
| 87 | + |
| 88 | +@pytest.mark.parametrize("reference", references, ids=lambda f: repr(f)) |
| 89 | +def test_launch(habcached_resolver, reference): |
| 90 | + if sys.platform.startswith("linux"): |
| 91 | + if reference.platform != "linux": |
| 92 | + raise pytest.skip("test doesn't apply on linux") |
| 93 | + elif sys.platform == "win32": |
| 94 | + if reference.platform != "windows": |
| 95 | + raise pytest.skip("test doesn't apply on windows") |
| 96 | + |
| 97 | + forced_requirements = [] |
| 98 | + for d in reference.distros: |
| 99 | + forced_requirements.extend([f"var-chain-{d}"]) |
| 100 | + cfg = habcached_resolver.resolve(",", forced_requirements=forced_requirements) |
| 101 | + |
| 102 | + proc = cfg.launch(reference.alias, args=None, blocking=True, stderr=subprocess.PIPE) |
| 103 | + assert proc.returncode == 0 |
| 104 | + check = json.load(reference.path.open()) |
| 105 | + result = json.loads(proc.output_stdout) |
| 106 | + assert result == check |
0 commit comments