Skip to content

Commit 7d3dcd9

Browse files
committed
Cherry pick Neuron support for kernels
This change does not add support for `build2cmake`. For building kernels, use the `main` branch until the next major release.
1 parent 2e940cc commit 7d3dcd9

File tree

8 files changed

+133
-24
lines changed

8 files changed

+133
-24
lines changed

builder/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ See [dockerfiles/README.md](./dockerfiles/README.md) for more options, including
6363
| XPU |||| 2 |
6464
| Metal |||| 2 |
6565
| Huawei NPU |||| 3 |
66+
| Neuron | ✅ | ❌ | ❌ | 3 |
67+
68+
**Warning:** Neuron support is experimental and currently requires pre-release packages.
6669

6770
# 📚 Documentation
6871

kernels/src/kernels/layer/kernelize.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
274274

275275
def _validate_device_type(device_type: str) -> None:
276276
"""Validate that the device type is supported."""
277-
supported_devices = {"cpu", "cuda", "mps", "npu", "rocm", "xpu"}
277+
supported_devices = {"cpu", "cuda", "mps", "neuron", "npu", "rocm", "xpu"}
278278
if device_type not in supported_devices:
279279
raise ValueError(
280280
f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
@@ -310,3 +310,9 @@ def _is_rocm_platform():
310310
import torch
311311

312312
return torch.version.hip is not None
313+
314+
315+
def _has_neuron_ops():
316+
import torch
317+
318+
return hasattr(torch, "neuron")

kernels/src/kernels/layer/repos.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ def create_repo(device: Device) -> "DeviceRepos":
3636
return _XPURepos()
3737
elif device.type == "npu":
3838
return _NPURepos()
39+
elif device.type == "neuron":
40+
return _NeuronRepos()
3941
else:
4042
raise ValueError(f"Unknown device type: {device.type}")
4143

@@ -93,6 +95,26 @@ def insert(self, device: Device, repos: dict[Mode, RepositoryProtocol]):
9395
self._repos = repos
9496

9597

98+
class _NeuronRepos(DeviceRepos):
99+
_repos: dict[Mode, RepositoryProtocol]
100+
101+
def __init__(self):
102+
super().__init__()
103+
self._repos = {}
104+
105+
@property
106+
def repos(
107+
self,
108+
) -> dict[Mode, RepositoryProtocol] | None:
109+
return self._repos
110+
111+
def insert(self, device: Device, repos: dict[Mode, RepositoryProtocol]):
112+
if device.type != "neuron":
113+
raise ValueError(f"Device type must be 'neuron', got {device.type}")
114+
115+
self._repos = repos
116+
117+
96118
class _NPURepos(DeviceRepos):
97119
_repos: dict[Mode, RepositoryProtocol]
98120

kernels/src/kernels/python_depends.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
}
1515
},
1616
"metal": {},
17+
"neuron": {
18+
"nki": {
19+
"nix": [],
20+
"python": ["nki"]
21+
}
22+
},
1723
"rocm": {},
1824
"xpu": {
1925
"onednn": {

kernels/src/kernels/utils.py

Lines changed: 17 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,7 @@ def _get_cache_dir() -> str | None:
2828
"""Returns the kernels cache directory."""
2929
cache_dir = os.environ.get("HF_KERNELS_CACHE", None)
3030
if cache_dir is not None:
31-
logging.warning(
32-
"HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead"
33-
)
31+
logging.warning("HF_KERNELS_CACHE will be removed in the future, use KERNELS_CACHE instead")
3432
return cache_dir
3533

3634
return os.environ.get("KERNELS_CACHE", None)
@@ -50,7 +48,11 @@ def _get_privateuse_backend_name() -> str | None:
5048
def backend() -> str:
5149
import torch
5250

53-
if torch.version.cuda is not None:
51+
if hasattr(torch, "neuron"):
52+
# Needs to be sorted before specific Torch builds, since Neuron
53+
# extension can be loaded into e.g. CUDA Torch builds.
54+
return "neuron"
55+
elif torch.version.cuda is not None:
5456
return "cuda"
5557
elif torch.version.hip is not None:
5658
return "hip"
@@ -104,7 +106,11 @@ def build_variant() -> str:
104106
def build_variant_noarch() -> str:
105107
import torch
106108

107-
if torch.version.cuda is not None:
109+
if hasattr(torch, "neuron"):
110+
# Needs to be sorted before specific Torch builds, since Neuron
111+
# extension can be loaded into e.g. CUDA Torch builds.
112+
return "torch-neuron"
113+
elif torch.version.cuda is not None:
108114
return "torch-cuda"
109115
elif torch.version.hip is not None:
110116
return "torch-rocm"
@@ -197,9 +203,7 @@ def install_kernel(
197203
try:
198204
return _find_kernel_in_repo_path(repo_path, package_name, variant_locks)
199205
except FileNotFoundError:
200-
raise FileNotFoundError(
201-
f"Cannot install kernel from repo {repo_id} (revision: {revision})"
202-
)
206+
raise FileNotFoundError(f"Cannot install kernel from repo {repo_id} (revision: {revision})")
203207

204208

205209
def _find_kernel_in_repo_path(
@@ -264,9 +268,7 @@ def install_kernel_all_variants(
264268
if variant_lock is None:
265269
raise ValueError(f"No lock found for build variant: {variant}")
266270

267-
validate_kernel(
268-
repo_path=repo_path, variant=variant, hash=variant_lock.hash
269-
)
271+
validate_kernel(repo_path=repo_path, variant=variant, hash=variant_lock.hash)
270272

271273
return repo_path / "build"
272274

@@ -309,9 +311,7 @@ def get_kernel(
309311
```
310312
"""
311313
revision = select_revision_or_version(repo_id, revision=revision, version=version)
312-
package_name, variant_path = install_kernel(
313-
repo_id, revision=revision, user_agent=user_agent
314-
)
314+
package_name, variant_path = install_kernel(repo_id, revision=revision, user_agent=user_agent)
315315
return _import_from_path(package_name, variant_path)
316316

317317

@@ -344,9 +344,7 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType:
344344
raise FileNotFoundError(f"Could not find package '{package_name}' in {repo_path}")
345345

346346

347-
def has_kernel(
348-
repo_id: str, revision: str | None = None, version: int | str | None = None
349-
) -> bool:
347+
def has_kernel(repo_id: str, revision: str | None = None, version: int | str | None = None) -> bool:
350348
"""
351349
Check whether a kernel build exists for the current environment (Torch version and compute framework).
352350
@@ -419,9 +417,7 @@ def load_kernel(repo_id: str, *, lockfile: Path | None) -> ModuleType:
419417
)
420418

421419
try:
422-
package_name, variant_path = _find_kernel_in_repo_path(
423-
repo_path, package_name, variant_locks=None
424-
)
420+
package_name, variant_path = _find_kernel_in_repo_path(repo_path, package_name, variant_locks=None)
425421
return _import_from_path(package_name, variant_path)
426422
except FileNotFoundError:
427423
raise FileNotFoundError(
@@ -447,9 +443,7 @@ def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleTyp
447443
if locked_sha is None:
448444
raise ValueError(f"Kernel `{repo_id}` is not locked")
449445

450-
package_name, variant_path = install_kernel(
451-
repo_id, locked_sha, local_files_only=local_files_only
452-
)
446+
package_name, variant_path = install_kernel(repo_id, locked_sha, local_files_only=local_files_only)
453447

454448
return _import_from_path(package_name, variant_path)
455449

kernels/tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@
1010
and torch.version.cuda is not None
1111
and torch.cuda.device_count() > 0
1212
)
13+
14+
has_neuron = hasattr(torch, "neuron") and torch.neuron.device_count() > 0
15+
1316
has_rocm = (
1417
hasattr(torch.version, "hip")
1518
and torch.version.hip is not None
@@ -46,6 +49,8 @@ def device():
4649
def pytest_runtest_setup(item):
4750
if "cuda_only" in item.keywords and not has_cuda:
4851
pytest.skip("skipping CUDA-only test on host without CUDA")
52+
if "neuron_only" in item.keywords and not has_neuron:
53+
pytest.skip("skipping Neuron-only test on host without Neuron")
4954
if "rocm_only" in item.keywords and not has_rocm:
5055
pytest.skip("skipping ROCm-only test on host without ROCm")
5156
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):

kernels/tests/test_basic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,13 @@ def test_flattened_build(repo_revision, device):
199199
torch.testing.assert_close(kernel.silu_and_mul(x), silu_and_mul_torch(x))
200200

201201

202+
@pytest.mark.neuron_only
203+
def test_neuron():
204+
relu = get_kernel("kernels-test/relu-nki", version=1)
205+
x = torch.randn((16, 16), dtype=torch.float16).to(device="neuron")
206+
torch.testing.assert_close(relu.relu(x), x.relu())
207+
208+
202209
def silu_and_mul_torch(x: torch.Tensor):
203210
d = x.shape[-1] // 2
204211
return F.silu(x[..., :d]) * x[..., d:]

kernels/tests/test_layer.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,23 @@ class RMSNormWithKernel(RMSNorm):
8484
pass
8585

8686

87+
class ReLU(nn.Module):
88+
def __init__(self):
89+
super().__init__()
90+
# Used to check that we called hub kernel.
91+
self.n_calls = 0
92+
93+
def forward(self, input: torch.Tensor) -> torch.Tensor:
94+
self.n_calls += 1
95+
d = input.shape[-1] // 2
96+
return F.relu(input)
97+
98+
99+
@use_kernel_forward_from_hub("ReLU")
100+
class ReLUWithKernel(ReLU):
101+
pass
102+
103+
87104
class SiluAndMul(nn.Module):
88105
def __init__(self):
89106
super().__init__()
@@ -188,6 +205,55 @@ def test_hub_func(cls):
188205
assert silu_and_mul_with_kernel.n_calls == 0
189206

190207

208+
@pytest.mark.neuron_only
209+
def test_hub_forward_neuron():
210+
torch.manual_seed(0)
211+
212+
mapping = {
213+
"ReLU": {
214+
"neuron": LayerRepository(
215+
repo_id="kernels-test/relu-nki", version=1, layer_name="ReLU"
216+
)
217+
}
218+
}
219+
220+
relu = ReLU()
221+
X = torch.randn((16, 16), device="neuron")
222+
Y = relu(X)
223+
224+
with use_kernel_mapping(mapping):
225+
relu_with_kernel = kernelize(
226+
ReLUWithKernel(), device="neuron", mode=Mode.INFERENCE
227+
)
228+
Y_kernel = relu_with_kernel(X)
229+
230+
torch.testing.assert_close(Y_kernel, Y)
231+
232+
assert relu.n_calls == 1
233+
assert relu_with_kernel.n_calls == 0
234+
235+
# Check that the device type can be determined automatically.
236+
class SMOL(nn.Module):
237+
def __init__(self):
238+
super().__init__()
239+
self.linear = nn.Linear(16, 16)
240+
self.relu = ReLUWithKernel()
241+
242+
def forward(self, x):
243+
return self.relu(self.linear(x))
244+
245+
smol = SMOL().to("neuron")
246+
247+
Y = smol(X)
248+
249+
with use_kernel_mapping(mapping):
250+
smol = kernelize(smol, mode=Mode.INFERENCE)
251+
Y_kernel = smol(X)
252+
253+
torch.testing.assert_close(Y, Y_kernel)
254+
assert smol.relu.n_calls == 1
255+
256+
191257
@pytest.mark.rocm_only
192258
def test_hub_forward_rocm():
193259
torch.manual_seed(0)

0 commit comments

Comments
 (0)