Skip to content

Commit 4995d84

Browse files
committed
rebase to latest main
1 parent 41644c2 commit 4995d84

File tree

1 file changed

+153
-8
lines changed

1 file changed

+153
-8
lines changed

install_requirements.py

Lines changed: 153 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,16 @@ def python_is_compatible():
5959

6060

6161
# The pip repository that hosts nightly torch packages.
62-
TORCH_NIGHTLY_URL = "https://download.pytorch.org/whl/nightly/cpu"
62+
# This will be dynamically set based on CUDA availability and CUDA backend enabled/disabled.
63+
TORCH_NIGHTLY_URL_BASE = "https://download.pytorch.org/whl/nightly"
6364

65+
# Supported CUDA versions - modify this to add/remove supported versions
66+
# Format: tuple of (major, minor) version numbers
67+
SUPPORTED_CUDA_VERSIONS = [
68+
(12, 6),
69+
(12, 8),
70+
(12, 9),
71+
]
6472

6573
# Since ExecuTorch often uses main-branch features of pytorch, only the nightly
6674
# pip versions will have the required features.
@@ -71,7 +79,137 @@ def python_is_compatible():
7179
#
7280
# NOTE: If you're changing, make the corresponding change in .ci/docker/ci_commit_pins/pytorch.txt
7381
# by picking the hash from the same date in https://hud.pytorch.org/hud/pytorch/pytorch/nightly/
74-
NIGHTLY_VERSION = "dev20250906"
82+
#
83+
# NOTE: If you're changing, make the corresponding supported CUDA versions in
84+
# SUPPORTED_CUDA_VERSIONS above if needed.
85+
NIGHTLY_VERSION = "dev20250915"
86+
87+
88+
def _check_cuda_enabled():
89+
"""Check if CUDA delegate is enabled via CMAKE_ARGS environment variable."""
90+
cmake_args = os.environ.get("CMAKE_ARGS", "")
91+
return "-DEXECUTORCH_BUILD_CUDA=ON" in cmake_args
92+
93+
94+
def _cuda_version_to_pytorch_suffix(major, minor):
95+
"""
96+
Generate PyTorch CUDA wheel suffix from CUDA version numbers.
97+
98+
Args:
99+
major: CUDA major version (e.g., 12)
100+
minor: CUDA minor version (e.g., 6)
101+
102+
Returns:
103+
PyTorch wheel suffix string (e.g., "cu126")
104+
"""
105+
return f"cu{major}{minor}"
106+
107+
108+
def _get_cuda_version():
109+
"""
110+
Get the CUDA version installed on the system using nvcc command.
111+
Returns a tuple (major, minor).
112+
113+
Raises:
114+
RuntimeError: if nvcc is not found or version cannot be parsed
115+
"""
116+
try:
117+
# Get CUDA version from nvcc (CUDA compiler)
118+
nvcc_result = subprocess.run(
119+
["nvcc", "--version"], capture_output=True, text=True, check=True
120+
)
121+
# Parse nvcc output for CUDA version
122+
# Output contains line like "Cuda compilation tools, release 12.6, V12.6.68"
123+
match = re.search(r"release (\d+)\.(\d+)", nvcc_result.stdout)
124+
if match:
125+
major, minor = int(match.group(1)), int(match.group(2))
126+
127+
# Check if the detected version is supported
128+
if (major, minor) not in SUPPORTED_CUDA_VERSIONS:
129+
available_versions = ", ".join(
130+
[f"{maj}.{min}" for maj, min in SUPPORTED_CUDA_VERSIONS]
131+
)
132+
raise RuntimeError(
133+
f"Detected CUDA version {major}.{minor} is not supported. "
134+
f"Only the following CUDA versions are supported: {available_versions}. "
135+
f"Please install a supported CUDA version or try on CPU-only delegates."
136+
)
137+
138+
return (major, minor)
139+
else:
140+
raise RuntimeError(
141+
"CUDA delegate is enabled but could not parse CUDA version from nvcc output. "
142+
"Please ensure CUDA is properly installed or try on CPU-only delegates."
143+
)
144+
except FileNotFoundError:
145+
raise RuntimeError(
146+
"CUDA delegate is enabled but nvcc (CUDA compiler) is not found in PATH. "
147+
"Please install CUDA toolkit or try on CPU-only delegates."
148+
)
149+
except subprocess.CalledProcessError as e:
150+
raise RuntimeError(
151+
f"CUDA delegate is enabled but nvcc command failed with error: {e}. "
152+
"Please ensure CUDA is properly installed or try on CPU-only delegates."
153+
)
154+
155+
156+
def _get_pytorch_cuda_url(cuda_version):
157+
"""
158+
Get the appropriate PyTorch CUDA URL for the given CUDA version.
159+
160+
Args:
161+
cuda_version: tuple of (major, minor) version numbers
162+
163+
Returns:
164+
URL string for PyTorch CUDA packages
165+
"""
166+
major, minor = cuda_version
167+
# Generate CUDA suffix (version validation is already done in _get_cuda_version)
168+
cuda_suffix = _cuda_version_to_pytorch_suffix(major, minor)
169+
170+
return f"{TORCH_NIGHTLY_URL_BASE}/{cuda_suffix}"
171+
172+
173+
# url for the PyTorch ExecuTorch depending on, which will be set by _determine_torch_url().
174+
# please do not directly rely on it, but use _determine_torch_url() instead.
175+
_torch_url = None
176+
177+
178+
def _determine_torch_url():
179+
"""
180+
Determine the appropriate PyTorch installation URL based on CUDA availability and CMAKE_ARGS.
181+
Uses caching to avoid redundant CUDA detection and print statements.
182+
183+
Returns:
184+
URL string for PyTorch packages
185+
"""
186+
global _torch_url
187+
188+
# Return cached URL if already determined
189+
if _torch_url is not None:
190+
return _torch_url
191+
192+
# Check if CUDA delegate is enabled
193+
if not _check_cuda_enabled():
194+
print("CUDA delegate not enabled, using CPU-only PyTorch")
195+
_torch_url = f"{TORCH_NIGHTLY_URL_BASE}/cpu"
196+
return _torch_url
197+
198+
print("CUDA delegate enabled, detecting CUDA version...")
199+
200+
# Get CUDA version
201+
cuda_version = _get_cuda_version()
202+
203+
major, minor = cuda_version
204+
print(f"Detected CUDA version: {major}.{minor}")
205+
206+
# Get appropriate PyTorch CUDA URL
207+
torch_url = _get_pytorch_cuda_url(cuda_version)
208+
print(f"Using PyTorch URL: {torch_url}")
209+
210+
# Cache the result
211+
_torch_url = torch_url
212+
return torch_url
75213

76214

77215
def install_requirements(use_pytorch_nightly):
@@ -84,12 +222,16 @@ def install_requirements(use_pytorch_nightly):
84222
)
85223
sys.exit(1)
86224

225+
# Determine the appropriate PyTorch URL based on CUDA delegate status
226+
torch_url = _determine_torch_url()
227+
87228
# pip packages needed by exir.
88229
TORCH_PACKAGE = [
89230
# Setting use_pytorch_nightly to false to test the pinned PyTorch commit. Note
90231
# that we don't need to set any version number there because they have already
91232
# been installed on CI before this step, so pip won't reinstall them
92-
f"torch==2.9.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torch",
233+
f"torch==2.10.0.{NIGHTLY_VERSION}" if use_pytorch_nightly else "torch",
234+
f"torchao==0.14.0{NIGHTLY_VERSION}" if use_pytorch_nightly else "torchao",
93235
]
94236

95237
# Install the requirements for core ExecuTorch package.
@@ -105,13 +247,13 @@ def install_requirements(use_pytorch_nightly):
105247
"requirements-dev.txt",
106248
*TORCH_PACKAGE,
107249
"--extra-index-url",
108-
TORCH_NIGHTLY_URL,
250+
torch_url,
109251
],
110252
check=True,
111253
)
112254

113255
LOCAL_REQUIREMENTS = [
114-
"third-party/ao", # We need the latest kernels for fast iteration, so not relying on pypi.
256+
# "third-party/ao", # We need the latest kernels for fast iteration, so not relying on pypi.
115257
] + (
116258
[
117259
"extension/llm/tokenizers", # TODO(larryliu0820): Setup a pypi package for this.
@@ -147,10 +289,13 @@ def install_requirements(use_pytorch_nightly):
147289

148290

149291
def install_optional_example_requirements(use_pytorch_nightly):
292+
# Determine the appropriate PyTorch URL based on CUDA delegate status
293+
torch_url = _determine_torch_url()
294+
150295
print("Installing torch domain libraries")
151296
DOMAIN_LIBRARIES = [
152297
(
153-
f"torchvision==0.24.0.{NIGHTLY_VERSION}"
298+
f"torchvision==0.25.0.{NIGHTLY_VERSION}"
154299
if use_pytorch_nightly
155300
else "torchvision"
156301
),
@@ -165,7 +310,7 @@ def install_optional_example_requirements(use_pytorch_nightly):
165310
"install",
166311
*DOMAIN_LIBRARIES,
167312
"--extra-index-url",
168-
TORCH_NIGHTLY_URL,
313+
torch_url,
169314
],
170315
check=True,
171316
)
@@ -180,7 +325,7 @@ def install_optional_example_requirements(use_pytorch_nightly):
180325
"-r",
181326
"requirements-examples.txt",
182327
"--extra-index-url",
183-
TORCH_NIGHTLY_URL,
328+
torch_url,
184329
"--upgrade-strategy",
185330
"only-if-needed",
186331
],

0 commit comments

Comments
 (0)