diff --git a/API.md b/API.md index 07992d6..895ad4e 100644 --- a/API.md +++ b/API.md @@ -8,11 +8,17 @@ import torchruntime You can use the command line: `python -m torchruntime install ` +CLI flags: `--policy `, `--preview`, `--no-unsupported`, `--uv` + Or you can use the library: ```py torchruntime.install(["torch", "torchvision<0.20"]) ``` +Optional flags: +- `preview=True` to allow preview builds (e.g. ROCm 6.4, nightly builds, XPU test index) +- `unsupported=False` to forbid EOL/unsupported builds (e.g. Torch-DirectML, IPEX, Torch 1.x) + On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`). ## Test torch diff --git a/README.md b/README.md index 6558d5e..5bdd2bd 100644 --- a/README.md +++ b/README.md @@ -34,6 +34,10 @@ On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the approp **Tip:** You can also add the `--uv` flag to install packages using [uv](https://docs.astral.sh/uv/) (instead of `pip`). For e.g. `python -m torchruntime install --uv` +Build-selection options: +- `--policy `: `compat` (default), `stable`, `preview` (or `nightly`) +- Overrides: `--preview`, `--no-unsupported` + ### Step 2. Configure torch This should be run inside your program, to initialize the required environment variables (if any) for the variant of torch being used. diff --git a/tests/test_installer.py b/tests/test_installer.py index 789b155..fb49cae 100644 --- a/tests/test_installer.py +++ b/tests/test_installer.py @@ -35,7 +35,7 @@ def test_cuda_platform_windows_installs_triton(monkeypatch): def test_cuda_nightly_platform_linux(monkeypatch): monkeypatch.setattr("torchruntime.installer.os_name", "Linux") packages = ["torch", "torchvision"] - result = get_install_commands("nightly/cu112", packages) + result = get_install_commands("nightly/cu112", packages, preview=True) expected_url = "https://download.pytorch.org/whl/nightly/cu112" assert result == [packages + ["--index-url", expected_url]] @@ -43,11 +43,18 @@ def test_cuda_nightly_platform_linux(monkeypatch): def test_cuda_nightly_platform_windows_installs_triton(monkeypatch): monkeypatch.setattr("torchruntime.installer.os_name", "Windows") packages = ["torch", "torchvision"] - result = get_install_commands("nightly/cu112", packages) + result = get_install_commands("nightly/cu112", packages, preview=True) expected_url = "https://download.pytorch.org/whl/nightly/cu112" assert result == [packages + ["--index-url", expected_url], ["triton-windows"]] +def test_cuda_nightly_platform_requires_preview(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Linux") + packages = ["torch", "torchvision"] + with pytest.raises(ValueError, match="preview"): + get_install_commands("nightly/cu112", packages, preview=False) + + def test_rocm_4_platform_does_not_install_triton(monkeypatch): monkeypatch.setattr("torchruntime.installer.os_name", "Linux") packages = ["torch", "torchvision"] @@ -71,25 +78,43 @@ def test_rocm_6_platform_linux_installs_triton(monkeypatch): def test_xpu_platform_windows_with_torch_only(monkeypatch): monkeypatch.setattr("torchruntime.installer.os_name", "Windows") packages = ["torch"] - result = get_install_commands("xpu", packages) - expected_url = "https://download.pytorch.org/whl/test/xpu" + result = get_install_commands("xpu", packages, preview=False) + expected_url = "https://download.pytorch.org/whl/xpu" + assert result == [packages + ["--index-url", expected_url]] + + +def test_xpu_platform_windows_with_torchvision(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Windows") + packages = ["torch", "torchvision"] + result = get_install_commands("xpu", packages, preview=False) + expected_url = "https://download.pytorch.org/whl/xpu" assert result == [packages + ["--index-url", expected_url]] -def test_xpu_platform_windows_with_torchvision(monkeypatch, capsys): +def test_xpu_platform_windows_preview(monkeypatch): monkeypatch.setattr("torchruntime.installer.os_name", "Windows") packages = ["torch", "torchvision"] - result = get_install_commands("xpu", packages) - expected_url = "https://download.pytorch.org/whl/nightly/xpu" + result = get_install_commands("xpu", packages, preview=True) + expected_url = "https://download.pytorch.org/whl/test/xpu" assert result == [packages + ["--index-url", expected_url]] - captured = capsys.readouterr() - assert "[WARNING]" in captured.out def test_xpu_platform_linux(monkeypatch): monkeypatch.setattr("torchruntime.installer.os_name", "Linux") packages = ["torch", "torchvision"] - result = get_install_commands("xpu", packages) + result = get_install_commands("xpu", packages, preview=False) + expected_url = "https://download.pytorch.org/whl/xpu" + triton_index_url = "https://download.pytorch.org/whl" + assert result == [ + packages + ["--index-url", expected_url], + ["pytorch-triton-xpu", "--index-url", triton_index_url], + ] + + +def test_xpu_platform_linux_preview(monkeypatch): + monkeypatch.setattr("torchruntime.installer.os_name", "Linux") + packages = ["torch", "torchvision"] + result = get_install_commands("xpu", packages, preview=True) expected_url = "https://download.pytorch.org/whl/test/xpu" triton_index_url = "https://download.pytorch.org/whl" assert result == [ diff --git a/tests/test_platform_detection.py b/tests/test_platform_detection.py index 11de967..f0821b5 100644 --- a/tests/test_platform_detection.py +++ b/tests/test_platform_detection.py @@ -36,7 +36,8 @@ def test_amd_gpu_navi4_linux(monkeypatch): with pytest.raises(NotImplementedError): get_torch_platform(gpu_infos) else: - assert get_torch_platform(gpu_infos) == "rocm6.4" + assert get_torch_platform(gpu_infos) == "cpu" + assert get_torch_platform(gpu_infos, preview=True) == "rocm6.4" def test_amd_gpu_navi3_linux(monkeypatch, capsys): @@ -89,6 +90,14 @@ def test_amd_gpu_ellesmere_linux(monkeypatch): assert get_torch_platform(gpu_infos) == "rocm4.2" +def test_amd_gpu_ellesmere_linux_unsupported_false_raises(monkeypatch): + monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux") + monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64") + gpu_infos = [GPU(AMD, "AMD", 0x1234, "Ellesmere", True)] + with pytest.raises(ValueError, match="End-of-Life"): + get_torch_platform(gpu_infos, unsupported=False) + + def test_amd_gpu_unsupported_linux(monkeypatch, capsys): monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux") monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64") diff --git a/tests/test_policy_parsing.py b/tests/test_policy_parsing.py new file mode 100644 index 0000000..e26f69f --- /dev/null +++ b/tests/test_policy_parsing.py @@ -0,0 +1,81 @@ +import pytest +from torchruntime.utils.args import parse_policy_args + + +def test_default_policy(): + args = ["pkg1"] + preview, unsupported, cleaned = parse_policy_args(args) + assert preview is False + assert unsupported is True + assert cleaned == ["pkg1"] + + +def test_stable_policy(): + args = ["--policy", "stable", "pkg1"] + preview, unsupported, cleaned = parse_policy_args(args) + assert preview is False + assert unsupported is False + assert cleaned == ["pkg1"] + + +def test_nightly_policy(): + args = ["--policy", "nightly"] + preview, unsupported, cleaned = parse_policy_args(args) + assert preview is True + assert unsupported is True + assert cleaned == [] + + +def test_preview_policy_alias(): + args = ["--policy", "preview"] + preview, unsupported, cleaned = parse_policy_args(args) + assert preview is True + assert unsupported is True + assert cleaned == [] + + +def test_policy_equals_syntax(): + args = ["--policy=stable", "pkg1"] + preview, unsupported, cleaned = parse_policy_args(args) + assert preview is False + assert unsupported is False + assert cleaned == ["pkg1"] + + +def test_policy_override_preview(): + # stable is p=F, u=F. --preview should make p=T + args = ["--policy", "stable", "--preview"] + preview, unsupported, cleaned = parse_policy_args(args) + assert preview is True + assert unsupported is False + +def test_policy_override_unsupported(): + # nightly is p=T, u=T. --no-unsupported should make u=F + args = ["--policy", "nightly", "--no-unsupported"] + preview, unsupported, cleaned = parse_policy_args(args) + assert preview is True + assert unsupported is False + + +def test_unknown_policy(): + args = ["--policy", "nonexistent"] + with pytest.raises(ValueError, match="Unknown policy"): + parse_policy_args(args) + + +def test_missing_policy_arg(): + args = ["--policy"] + with pytest.raises(ValueError, match="--policy requires an argument"): + parse_policy_args(args) + + +def test_mixed_args(): + args = ["torch", "--preview", "--policy", "stable", "--uv"] + # stable: p=F, u=F + # --preview: p=T + # Result: p=T, u=F + # cleaned: ["torch", "--uv"] + preview, unsupported, cleaned = parse_policy_args(args) + assert preview is True + assert unsupported is False + assert cleaned == ["torch", "--uv"] diff --git a/tests/test_segmentation.py b/tests/test_segmentation.py new file mode 100644 index 0000000..dcc553a --- /dev/null +++ b/tests/test_segmentation.py @@ -0,0 +1,68 @@ +import pytest +from torchruntime.device_db import GPU +from torchruntime.platform_detection import AMD, INTEL, get_torch_platform, py_version + + +def test_preview_rocm_6_4_selection(monkeypatch): + monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux") + monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64") + gpu_infos = [GPU(AMD, "AMD", 0x1234, "Navi 41", True)] + + if py_version < (3, 9): + pytest.skip("Navi 4 requires Python 3.9+") + + # Default: preview=False -> cpu + assert get_torch_platform(gpu_infos) == "cpu" + assert get_torch_platform(gpu_infos, preview=False) == "cpu" + + # preview=True -> rocm6.4 + assert get_torch_platform(gpu_infos, preview=True) == "rocm6.4" + + +def test_eol_rocm_5_2_selection(monkeypatch): + monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux") + monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64") + gpu_infos = [GPU(AMD, "AMD", 0x1234, "Navi 10", True)] + + assert get_torch_platform(gpu_infos) == "rocm5.2" + + with pytest.raises(ValueError, match="considered End-of-Life"): + get_torch_platform(gpu_infos, unsupported=False) + + +def test_eol_rocm42_selection(monkeypatch): + monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux") + monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64") + # Ellesmere (e.g. RX 580) + gpu_infos = [GPU(AMD, "AMD", "67df", "Ellesmere [Radeon RX 470/480/570/570X/580/580X/590]", True)] + + # Default: unsupported=True -> rocm4.2 + assert get_torch_platform(gpu_infos) == "rocm4.2" + + # unsupported=False -> raises ValueError + with pytest.raises(ValueError, match="considered End-of-Life"): + get_torch_platform(gpu_infos, unsupported=False) + + +def test_eol_directml_selection(monkeypatch): + monkeypatch.setattr("torchruntime.platform_detection.os_name", "Windows") + monkeypatch.setattr("torchruntime.platform_detection.arch", "amd64") + gpu_infos = [GPU(AMD, "AMD", 0x1234, "Radeon", True)] + + assert get_torch_platform(gpu_infos) == "directml" + + with pytest.raises(ValueError, match="considered End-of-Life"): + get_torch_platform(gpu_infos, unsupported=False) + + +def test_eol_ipex_selection(monkeypatch): + monkeypatch.setattr("torchruntime.platform_detection.os_name", "Linux") + monkeypatch.setattr("torchruntime.platform_detection.arch", "x86_64") + monkeypatch.setattr("torchruntime.platform_detection.py_version", (3, 8)) + gpu_infos = [GPU(INTEL, "Intel", 0x1234, "Iris", True)] + + assert get_torch_platform(gpu_infos) == "ipex" + + # unsupported=False -> raises ValueError + with pytest.raises(ValueError, match="considered End-of-Life"): + get_torch_platform(gpu_infos, unsupported=False) diff --git a/torchruntime/__main__.py b/torchruntime/__main__.py index fefd567..3b78ad5 100644 --- a/torchruntime/__main__.py +++ b/torchruntime/__main__.py @@ -1,6 +1,7 @@ from .installer import install from .utils.torch_test import test from .utils import info +from .utils.args import parse_policy_args def print_usage(entry_command: str): @@ -16,6 +17,9 @@ def print_usage(entry_command: str): Examples: {entry_command} install {entry_command} install --uv + {entry_command} install --preview + {entry_command} install --no-unsupported + {entry_command} install --policy stable {entry_command} install torch==2.2.0 torchvision==0.17.0 {entry_command} install --uv torch>=2.0.0 torchaudio {entry_command} install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0 @@ -35,6 +39,9 @@ def print_usage(entry_command: str): Options: --uv Use uv instead of pip for installation + --preview Allow preview builds (e.g. ROCm 6.4) + --no-unsupported Forbid EOL/unsupported builds (e.g. DirectML / IPEX / Torch 1.x) + --policy Set configuration policy (stable, compat, preview|nightly). Default: compat Version specification formats (follows pip format): package==2.1.0 Exact version @@ -62,15 +69,28 @@ def main(): if command == "install": args = sys.argv[2:] if len(sys.argv) > 2 else [] - use_uv = "--uv" in args - # Remove --uv from args to get package list - package_versions = [arg for arg in args if arg != "--uv"] if args else None - install(package_versions, use_uv=use_uv) + try: + preview, unsupported, cleaned_args = parse_policy_args(args) + except ValueError as e: + print(f"Error: {e}") + return + + use_uv = "--uv" in cleaned_args + # Remove --uv from package list + package_versions = [arg for arg in cleaned_args if arg != "--uv"] + install(package_versions, use_uv=use_uv, preview=preview, unsupported=unsupported) elif command == "test": subcommand = sys.argv[2] if len(sys.argv) > 2 else "all" test(subcommand) elif command == "info": - info() + args = sys.argv[2:] if len(sys.argv) > 2 else [] + try: + preview, unsupported, _ = parse_policy_args(args) + except ValueError as e: + print(f"Error: {e}") + return + from .utils import info + info(preview=preview, unsupported=unsupported) else: print(f"Unknown command: {command}") entry_path = sys.argv[0] diff --git a/torchruntime/consts.py b/torchruntime/consts.py index 0832c3a..67e334b 100644 --- a/torchruntime/consts.py +++ b/torchruntime/consts.py @@ -3,3 +3,10 @@ AMD = "1002" NVIDIA = "10de" INTEL = "8086" + +POLICIES = { + "stable": (False, False), # preview=False, unsupported=False + "compat": (False, True), # preview=False, unsupported=True (Default) + "preview": (True, True), # preview=True, unsupported=True + "nightly": (True, True), # alias for preview +} diff --git a/torchruntime/installer.py b/torchruntime/installer.py index 1153188..ecb68fd 100644 --- a/torchruntime/installer.py +++ b/torchruntime/installer.py @@ -2,20 +2,17 @@ import sys import platform import subprocess - -from .consts import CONTACT_LINK from .device_db import get_gpus from .platform_detection import get_torch_platform os_name = platform.system() -PIP_PREFIX = [sys.executable, "-m", "pip", "install"] CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$") ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$") ROCM_VERSION_REGEX = re.compile(r"^(?:nightly/)?rocm(?P\d+)\.(?P\d+)$") -def get_install_commands(torch_platform, packages): +def get_install_commands(torch_platform, packages, preview=False): """ Generates pip installation commands for PyTorch and related packages based on the specified platform. @@ -30,6 +27,7 @@ def get_install_commands(torch_platform, packages): packages (list of str): List of package names (and optionally versions in pip format). Examples: - ["torch", "torchvision"] - ["torch>=2.0", "torchaudio==0.16.0"] + preview (bool): If True, allow preview/nightly builds. Defaults to False. Returns: list of list of str: Each sublist contains a pip install command (excluding the `pip install` prefix). @@ -41,7 +39,7 @@ def get_install_commands(torch_platform, packages): ValueError: If an unsupported platform is provided. Notes: - - For "xpu" on Windows, if torchvision or torchaudio are included, the function switches to nightly builds. + - For "xpu", if preview is True, the function installs from the test index. - For "directml", the "torch-directml" package is returned as part of the installation commands. - For "ipex", the "intel-extension-for-pytorch" package is returned as part of the installation commands. - For Windows CUDA, the function also installs "triton-windows" (for torch.compile and Triton kernels). @@ -54,6 +52,9 @@ def get_install_commands(torch_platform, packages): if torch_platform == "cpu": return [packages] + if torch_platform.startswith("nightly/") and not preview: + raise ValueError("preview=True is required for nightly builds") + if CUDA_REGEX.match(torch_platform) or ROCM_REGEX.match(torch_platform): index_url = f"https://download.pytorch.org/whl/{torch_platform}" cmds = [packages + ["--index-url", index_url]] @@ -69,15 +70,11 @@ def get_install_commands(torch_platform, packages): return cmds if torch_platform == "xpu": - if os_name == "Windows" and ("torchvision" in packages or "torchaudio" in packages): - print( - f"[WARNING] The preview build of 'xpu' on Windows currently only supports torch, not torchvision/torchaudio. " - f"torchruntime will instead use the nightly build, to get the 'xpu' version of torchaudio and torchvision as well. " - f"Please contact torchruntime if this is no longer accurate: {CONTACT_LINK}" - ) - index_url = f"https://download.pytorch.org/whl/nightly/{torch_platform}" - else: - index_url = f"https://download.pytorch.org/whl/test/{torch_platform}" + index_url = ( + f"https://download.pytorch.org/whl/test/{torch_platform}" + if preview + else f"https://download.pytorch.org/whl/{torch_platform}" + ) cmds = [packages + ["--index-url", index_url]] if os_name == "Linux": @@ -108,14 +105,16 @@ def run_commands(cmds): subprocess.run(cmd) -def install(packages=[], use_uv=False): +def install(packages=[], use_uv=False, preview=False, unsupported=True): """ packages: a list of strings with package names (and optionally their versions in pip-format). e.g. ["torch", "torchvision"] or ["torch>=2.0", "torchaudio==0.16.0"]. Defaults to ["torch", "torchvision", "torchaudio"]. use_uv: bool, whether to use uv for installation. Defaults to False. + preview: bool, whether to allow preview/nightly builds. Defaults to False. + unsupported: bool, whether to allow EOL/unsupported builds. Defaults to True. """ gpu_infos = get_gpus() - torch_platform = get_torch_platform(gpu_infos, packages=packages) - cmds = get_install_commands(torch_platform, packages) + torch_platform = get_torch_platform(gpu_infos, packages=packages, preview=preview, unsupported=unsupported) + cmds = get_install_commands(torch_platform, packages, preview=preview) cmds = get_pip_commands(cmds, use_uv=use_uv) run_commands(cmds) diff --git a/torchruntime/platform_detection.py b/torchruntime/platform_detection.py index f5b785f..70bffa3 100644 --- a/torchruntime/platform_detection.py +++ b/torchruntime/platform_detection.py @@ -51,13 +51,15 @@ def _packages_require_cuda_12_4(packages): return False -def get_torch_platform(gpu_infos, packages=[]): +def get_torch_platform(gpu_infos, packages=[], preview=False, unsupported=True): """ Determine the appropriate PyTorch platform to use based on the system architecture, OS, and GPU information. Args: gpu_infos (list of `torchruntime.device_db.GPU` instances) packages (list of str): Optional list of torch/torchvision/torchaudio requirement strings. + preview (bool): If True, allow preview/nightly builds (e.g. rocm6.4). Defaults to False. + unsupported (bool): If False, forbid EOL/unsupported builds (e.g. cu118). Defaults to True. Returns: str: A string representing the platform to use. Possible values: @@ -70,6 +72,7 @@ def get_torch_platform(gpu_infos, packages=[]): Raises: NotImplementedError: For unsupported architectures, OS-GPU combinations, or multiple GPU vendors. + ValueError: If unsupported=False and only an EOL build is available. Warning: Outputs warnings for deprecated Python versions or fallback configurations. """ @@ -95,12 +98,23 @@ def get_torch_platform(gpu_infos, packages=[]): integrated_devices.append(device) if discrete_devices: - return _get_platform_for_discrete(discrete_devices, packages=packages) + platform = _get_platform_for_discrete(discrete_devices, packages=packages, preview=preview) + else: + platform = _get_platform_for_integrated(integrated_devices, preview=preview) - return _get_platform_for_integrated(integrated_devices) + # Segmentation Logic + EOL_PLATFORMS = {"directml", "ipex", "rocm5.2", "rocm4.2"} + + if not unsupported and platform in EOL_PLATFORMS: + raise ValueError( + f"The recommended platform '{platform}' is considered End-of-Life (EOL) and is forbidden because 'unsupported' is set to False. " + f"Please use a more recent GPU or set unsupported=True to allow this installation." + ) + + return platform -def _get_platform_for_discrete(gpu_infos, packages=None): +def _get_platform_for_discrete(gpu_infos, packages=None, preview=False): vendor_ids = set(gpu.vendor_id for gpu in gpu_infos) if len(vendor_ids) > 1: @@ -124,7 +138,13 @@ def _get_platform_for_discrete(gpu_infos, packages=None): raise NotImplementedError( f"Torch does not support Navi 4x series of GPUs on Python 3.8. Please switch to a newer Python version to use the latest version of torch!" ) - return "rocm6.4" + if preview: + return "rocm6.4" + print( + "[WARNING] Navi 4x series GPUs require preview ROCm builds (rocm6.4). " + "torchruntime will fall back to CPU unless preview=True is enabled." + ) + return "cpu" if any(device_name.startswith("Navi") for device_name in device_names) and any( device_name.startswith("Vega 2") for device_name in device_names ): # lowest-common denominator is rocm5.7, which works with both Navi and Vega 20 @@ -201,7 +221,7 @@ def _get_platform_for_discrete(gpu_infos, packages=None): return "cpu" -def _get_platform_for_integrated(gpu_infos): +def _get_platform_for_integrated(gpu_infos, preview=False): gpu = gpu_infos[0] if os_name == "Windows": @@ -235,3 +255,4 @@ def _get_platform_for_integrated(gpu_infos): ) return "cpu" + diff --git a/torchruntime/utils/__init__.py b/torchruntime/utils/__init__.py index 3865d7c..546cfdc 100644 --- a/torchruntime/utils/__init__.py +++ b/torchruntime/utils/__init__.py @@ -7,7 +7,7 @@ ) -def info(): +def info(preview=False, unsupported=True): from torchruntime.device_db import get_gpus from torchruntime.platform_detection import get_torch_platform from torchruntime.configuration import configure @@ -23,7 +23,7 @@ def info(): print("") print("--- RECOMMENDED TORCH PLATFORM ---") - torch_platform = get_torch_platform(gpu_infos) + torch_platform = get_torch_platform(gpu_infos, preview=preview, unsupported=unsupported) print(torch_platform) print("") diff --git a/torchruntime/utils/args.py b/torchruntime/utils/args.py new file mode 100644 index 0000000..8aed327 --- /dev/null +++ b/torchruntime/utils/args.py @@ -0,0 +1,68 @@ +from ..consts import POLICIES + + +def parse_policy_args(args): + """ + Parses arguments for policy and flags. + Returns (preview, unsupported, cleaned_args) + + Supports both `--policy NAME` and `--policy=NAME`. + + Logic: + 1. Determine base configuration from the LAST provided --policy argument (or default 'compat'). + 2. Apply explicit flags (--preview, --no-unsupported) which ALWAYS override the policy. + 3. Remove policy and flags from args to produce cleaned_args. + """ + # Default: compat + preview, unsupported = POLICIES["compat"] + + # 1. Scan for the last policy to set the baseline + last_policy_name = None + i = 0 + while i < len(args): + arg = args[i] + if arg == "--policy": + if i + 1 < len(args): + last_policy_name = args[i+1] + i += 2 + else: + # We will catch this error in the second pass or we can raise it now. + # Raising now is safer. + raise ValueError("--policy requires an argument") + elif arg.startswith("--policy="): + last_policy_name = arg.split("=", 1)[1] + if not last_policy_name: + raise ValueError("--policy requires an argument") + i += 1 + else: + i += 1 + + if last_policy_name: + if last_policy_name in POLICIES: + preview, unsupported = POLICIES[last_policy_name] + else: + raise ValueError(f"Unknown policy: {last_policy_name}") + + # 2. Apply flags and build cleaned_args + cleaned_args = [] + i = 0 + while i < len(args): + arg = args[i] + if arg == "--policy": + # Skip policy and its value (already processed) + i += 2 + continue + elif arg.startswith("--policy="): + i += 1 + continue + elif arg == "--preview": + preview = True + i += 1 + elif arg == "--no-unsupported": + unsupported = False + i += 1 + else: + cleaned_args.append(arg) + i += 1 + + return preview, unsupported, cleaned_args