diff --git a/src/main.py b/src/main.py index 80799913..34f5a3e1 100644 --- a/src/main.py +++ b/src/main.py @@ -159,6 +159,8 @@ def _create_new_version_conda_specs( out.append(f"conda-forge::{package_name}") else: channel = match_out.get("channel").channel_name + # Build string of the package, like "cuda120py311h51447cc_202". + build_string = match_out.spec.split(" ")[-1] min_version_inclusive = match_out.get("version") assert str(min_version_inclusive).startswith("==") @@ -167,8 +169,12 @@ def _create_new_version_conda_specs( max_version_str = _get_dependency_upper_bound_for_runtime_upgrade( package_name, min_version_inclusive, runtime_version_upgrade_type ) + version_constraint = f"version='>={min_version_inclusive}{max_version_str}'" + # Explicitly constraint build string for GPU packages, like PyTorch, tensorflow. + if "cuda" in build_string: + version_constraint += ",build='*cuda*'" - out.append(f"{channel}::{package_name}[version='>={min_version_inclusive}{max_version_str}']") + out.append(f"{channel}::{package_name}[{version_constraint}]") with open(f"{new_version_dir}/{env_in_filename}", "w") as f: f.write("# This file is auto-generated.\n")