Skip to content

Commit d258c3c

Browse files
committed
Support noarch build variants
This change adds support for noarch build variants. So far we have used the universal variant for kernels that do not have any AoT-compiled code. However, the universal variant has two important issues: 1. A kernel without AoT-compiled code might still be backend-specific. E.g. NVIDIA CuTe-based kernels are not universal in the sense that they don't work on non-NVIDIA GPUs. 2. We cannot specify dependencies per backend. To solve these issues, we introduce the noarch variants to replace universal kernels. Noarch kernels have variants of the shape `torch-<backend>` (e.g. `torch-xpu`). This resolves both issues outlined above. This change introduces support for loading noarch kernels. In the future, we will start emitting deprecation warnings for universal kernels (to eventually remove support).
1 parent 2faa279 commit d258c3c

File tree

3 files changed

+51
-31
lines changed

3 files changed

+51
-31
lines changed

src/kernels/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from .doc import generate_readme_for_kernel
1515
from .wheel import build_variant_to_wheel
1616

17-
BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-universal)")
17+
BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-)")
1818

1919

2020
def main():

src/kernels/utils.py

Lines changed: 43 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,33 @@ def build_variant() -> str:
8484
return f"torch{torch_version.major}{torch_version.minor}-{cxxabi}-{compute_framework}-{cpu}-{os}"
8585

8686

87-
def universal_build_variant() -> str:
87+
def build_variant_noarch() -> str:
    """Return the noarch build variant name for the detected Torch backend.

    Backends are probed in a fixed order -- CUDA, HIP, MPS, XPU, then the
    private-use backend named ``npu`` -- and the first match decides the
    result. When none matches, ``torch-cpu`` is returned.

    Returns:
        `str`: A variant name of the shape ``torch-<backend>``.
    """
    import torch

    version = torch.version

    if version.cuda is not None:
        return "torch-cuda"
    if version.hip is not None:
        return "torch-rocm"
    if torch.backends.mps.is_available():
        return "torch-metal"
    # `torch.version` does not always have an `xpu` attribute, so a missing
    # attribute is treated the same as `None`.
    if getattr(version, "xpu", None) is not None:
        return "torch-xpu"
    if _get_privateuse_backend_name() == "npu":
        return "torch-npu"
    return "torch-cpu"
102+
103+
104+
def build_variant_universal() -> str:
    """Return the name of the universal build variant.

    Returns:
        `str`: Always ``torch-universal``.
    """
    # Once we support other frameworks, detection goes here.
    universal_variant = "torch-universal"
    return universal_variant
90107

91108

109+
def build_variants() -> List[str]:
    """Return compatible build variants in preferred order.

    The order is: the variant from `build_variant()`, then the noarch
    variant, then the universal variant.

    Returns:
        `List[str]`: Build variant names, most preferred first.
    """
    providers = (build_variant, build_variant_noarch, build_variant_universal)
    return [provider() for provider in providers]
112+
113+
92114
def _import_from_path(module_name: str, variant_path: Path) -> ModuleType:
93115
metadata_path = variant_path / "metadata.json"
94116
if metadata_path.exists():
@@ -146,13 +168,12 @@ def install_kernel(
146168
`Tuple[str, Path]`: A tuple containing the package name and the path to the variant directory.
147169
"""
148170
package_name = package_name_from_repo_id(repo_id)
149-
variant = build_variant()
150-
universal_variant = universal_build_variant()
171+
allow_patterns = [f"build/{variant}/*" for variant in build_variants()]
151172
user_agent = _get_user_agent(user_agent=user_agent)
152173
repo_path = Path(
153174
snapshot_download(
154175
repo_id,
155-
allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
176+
allow_patterns=allow_patterns,
156177
cache_dir=CACHE_DIR,
157178
revision=revision,
158179
local_files_only=local_files_only,
@@ -173,23 +194,22 @@ def _find_kernel_in_repo_path(
173194
package_name: str,
174195
variant_locks: Optional[Dict[str, VariantLock]] = None,
175196
) -> Tuple[str, Path]:
176-
specific_variant = build_variant()
177-
universal_variant = universal_build_variant()
178-
179-
specific_variant_path = repo_path / "build" / specific_variant
180-
universal_variant_path = repo_path / "build" / universal_variant
181-
182-
if specific_variant_path.exists():
183-
variant = specific_variant
184-
variant_path = specific_variant_path
185-
elif universal_variant_path.exists():
186-
variant = universal_variant
187-
variant_path = universal_variant_path
188-
else:
197+
variants = build_variants()
198+
variant = None
199+
variant_path = None
200+
for candidate_variant in variants:
201+
variant_path = repo_path / "build" / candidate_variant
202+
if variant_path.exists():
203+
variant = candidate_variant
204+
break
205+
206+
if variant is None:
189207
raise FileNotFoundError(
190-
f"Kernel at path `{repo_path}` does not have one of build variants: {specific_variant}, {universal_variant}"
208+
f"Kernel at path `{repo_path}` does not have one of build variants: {', '.join(variants)}"
191209
)
192210

211+
assert variant_path is not None
212+
193213
if variant_locks is not None:
194214
variant_lock = variant_locks.get(variant)
195215
if variant_lock is None:
@@ -295,13 +315,9 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
295315
Returns:
296316
`ModuleType`: The imported kernel module.
297317
"""
298-
variant = build_variant()
299-
universal_variant = universal_build_variant()
300-
301318
# Presume we were given the top level path of the kernel repository.
302319
for base_path in [repo_path, repo_path / "build"]:
303-
# Prefer the universal variant if it exists.
304-
for v in [universal_variant, variant]:
320+
for v in build_variants():
305321
variant_path = base_path / v
306322
if variant_path.exists():
307323
return _import_from_path(package_name, variant_path)
@@ -337,9 +353,8 @@ def has_kernel(
337353

338354
package_name = package_name_from_repo_id(repo_id)
339355
variant = build_variant()
340-
universal_variant = universal_build_variant()
341356

342-
for variant in [universal_variant, variant]:
357+
for variant in build_variants():
343358
for init_file in ["__init__.py", f"{package_name}/__init__.py"]:
344359
if file_exists(
345360
repo_id,
@@ -379,13 +394,11 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
379394

380395
package_name = package_name_from_repo_id(repo_id)
381396

382-
variant = build_variant()
383-
universal_variant = universal_build_variant()
384-
397+
allow_patterns = [f"build/{variant}/*" for variant in build_variants()]
385398
repo_path = Path(
386399
snapshot_download(
387400
repo_id,
388-
allow_patterns=[f"build/{variant}/*", f"build/{universal_variant}/*"],
401+
allow_patterns=allow_patterns,
389402
cache_dir=CACHE_DIR,
390403
revision=locked_sha,
391404
local_files_only=True,
@@ -399,7 +412,7 @@ def load_kernel(repo_id: str, *, lockfile: Optional[Path] = None) -> ModuleType:
399412
return _import_from_path(package_name, variant_path)
400413
except FileNotFoundError:
401414
raise FileNotFoundError(
402-
f"Locked kernel `{repo_id}` does not have build `{variant}` or was not downloaded with `kernels download <project>`"
415+
f"Locked kernel `{repo_id}` does not have applicable variant or was not downloaded with `kernels download <project>`"
403416
)
404417

405418

tests/test_basic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,13 @@ def test_universal_kernel(universal_kernel):
163163
torch.testing.assert_close(out, out_check, rtol=1e-1, atol=1e-1)
164164

165165

166+
def test_noarch_kernel(device):
    """Load the noarch test kernel, skipping devices outside the supported list."""
    supported_devices = ["cpu", "cuda", "xpu"]
    if device in supported_devices:
        # The test only checks that loading the kernel does not raise.
        get_kernel("kernels-test/silu-and-mul-noarch")
    else:
        pytest.skip(f"Device is not one of: {','.join(supported_devices)}")
171+
172+
166173
@pytest.mark.parametrize(
167174
"repo_revision",
168175
[

0 commit comments

Comments
 (0)