From 53140e804fb4bb76b7f492b7f20c7c27a0605e3a Mon Sep 17 00:00:00 2001 From: Tian Wang <133085652+aws-tianquaw@users.noreply.github.com> Date: Sat, 12 Apr 2025 15:52:16 -0700 Subject: [PATCH 1/2] Explicitly constraint build string for GPU packages in generated env.in files --- src/main.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/main.py b/src/main.py index 80799913..e19618aa 100644 --- a/src/main.py +++ b/src/main.py @@ -159,6 +159,8 @@ def _create_new_version_conda_specs( out.append(f"conda-forge::{package_name}") else: channel = match_out.get("channel").channel_name + # E.g. "tensorflow ==2.17.0 cuda120py311h51447cc_202" + version_spec = match_out.spec min_version_inclusive = match_out.get("version") assert str(min_version_inclusive).startswith("==") @@ -167,8 +169,12 @@ def _create_new_version_conda_specs( max_version_str = _get_dependency_upper_bound_for_runtime_upgrade( package_name, min_version_inclusive, runtime_version_upgrade_type ) + version_constraint = f"version='>={min_version_inclusive}{max_version_str}'" + # Explicitly constraint build string for GPU packages, like PyTorch, tensorflow. + if "cuda" in version_spec: + version_constraint += ",build='*cuda*'" - out.append(f"{channel}::{package_name}[version='>={min_version_inclusive}{max_version_str}']") + out.append(f"{channel}::{package_name}[{version_constraint}]") with open(f"{new_version_dir}/{env_in_filename}", "w") as f: f.write("# This file is auto-generated.\n") From 45a0a1cf17151933a6462376ae56e29a47258a71 Mon Sep 17 00:00:00 2001 From: Tian Wang <133085652+aws-tianquaw@users.noreply.github.com> Date: Sat, 12 Apr 2025 15:57:07 -0700 Subject: [PATCH 2/2] Fix build string --- src/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/main.py b/src/main.py index e19618aa..34f5a3e1 100644 --- a/src/main.py +++ b/src/main.py @@ -159,8 +159,8 @@ def _create_new_version_conda_specs( out.append(f"conda-forge::{package_name}") else: channel = match_out.get("channel").channel_name - # E.g. "tensorflow ==2.17.0 cuda120py311h51447cc_202" - version_spec = match_out.spec + # Build string of the package, like "cuda120py311h51447cc_202". + build_string = match_out.spec.split(" ")[-1] min_version_inclusive = match_out.get("version") assert str(min_version_inclusive).startswith("==") @@ -171,7 +171,7 @@ def _create_new_version_conda_specs( ) version_constraint = f"version='>={min_version_inclusive}{max_version_str}'" # Explicitly constraint build string for GPU packages, like PyTorch, tensorflow. - if "cuda" in version_spec: + if "cuda" in build_string: version_constraint += ",build='*cuda*'" out.append(f"{channel}::{package_name}[{version_constraint}]")