Skip to content

Commit d8bb02b

Browse files
authored
kernels: add support for Neuron and NKI (huggingface#285)
* Add basic Neuron + NKI support to `kernels`
* Add Neuron layer support
* Fix accidental import removal
* Exclude Neuron in init test
* Sync with latest neuronx
* build2cmake: add Neuron support

  Also add an example kernel using NKI.
* Add Neuron to the builder README
* Fix typo
1 parent 0c9cf0b commit d8bb02b

File tree

18 files changed

+243
-14
lines changed

18 files changed

+243
-14
lines changed

build2cmake/src/config/mod.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ pub struct General {
4444
pub python_depends: Option<Vec<String>>,
4545

4646
pub cuda: Option<CudaGeneral>,
47+
pub neuron: Option<NeuronGeneral>,
4748
pub xpu: Option<XpuGeneral>,
4849
}
4950

@@ -106,6 +107,10 @@ pub struct XpuGeneral {
106107
pub python_depends: Option<Vec<String>>,
107108
}
108109

110+
pub struct NeuronGeneral {
111+
pub python_depends: Option<Vec<String>>,
112+
}
113+
109114
pub struct Hub {
110115
pub repo_id: Option<String>,
111116
pub branch: Option<String>,
@@ -237,16 +242,18 @@ pub enum Backend {
237242
Cpu,
238243
Cuda,
239244
Metal,
245+
Neuron,
240246
Rocm,
241247
Xpu,
242248
}
243249

244250
impl Backend {
245-
pub const fn all() -> [Backend; 5] {
251+
pub const fn all() -> [Backend; 6] {
246252
[
247253
Backend::Cpu,
248254
Backend::Cuda,
249255
Backend::Metal,
256+
Backend::Neuron,
250257
Backend::Rocm,
251258
Backend::Xpu,
252259
]
@@ -259,6 +266,7 @@ impl Display for Backend {
259266
Backend::Cpu => write!(f, "cpu"),
260267
Backend::Cuda => write!(f, "cuda"),
261268
Backend::Metal => write!(f, "metal"),
269+
Backend::Neuron => write!(f, "neuron"),
262270
Backend::Rocm => write!(f, "rocm"),
263271
Backend::Xpu => write!(f, "xpu"),
264272
}
@@ -273,6 +281,7 @@ impl FromStr for Backend {
273281
"cpu" => Ok(Backend::Cpu),
274282
"cuda" => Ok(Backend::Cuda),
275283
"metal" => Ok(Backend::Metal),
284+
"neuron" => Ok(Backend::Neuron),
276285
"rocm" => Ok(Backend::Rocm),
277286
"xpu" => Ok(Backend::Xpu),
278287
_ => Err(format!("Unknown backend: {s}")),

build2cmake/src/config/v1.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ impl TryFrom<Build> for super::Build {
8686
Backend::Cpu,
8787
Backend::Cuda,
8888
Backend::Metal,
89+
Backend::Neuron,
8990
Backend::Rocm,
9091
Backend::Xpu,
9192
]
@@ -102,6 +103,7 @@ impl TryFrom<Build> for super::Build {
102103
license: None,
103104
backends,
104105
hub: None,
106+
neuron: None,
105107
python_depends: None,
106108
cuda: None,
107109
xpu: None,

build2cmake/src/config/v2.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,7 @@ impl TryFrom<Build> for super::Build {
132132
Backend::Cpu,
133133
Backend::Cuda,
134134
Backend::Metal,
135+
Backend::Neuron,
135136
Backend::Rocm,
136137
Backend::Xpu,
137138
]
@@ -168,6 +169,7 @@ impl General {
168169
backends,
169170
cuda,
170171
hub: general.hub.map(Into::into),
172+
neuron: None,
171173
python_depends: None,
172174
xpu: None,
173175
}

build2cmake/src/config/v3.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ pub struct General {
3131

3232
pub hub: Option<Hub>,
3333

34+
pub neuron: Option<NeuronGeneral>,
35+
3436
pub python_depends: Option<Vec<String>>,
3537

3638
pub xpu: Option<XpuGeneral>,
@@ -44,6 +46,12 @@ pub struct CudaGeneral {
4446
pub python_depends: Option<Vec<String>>,
4547
}
4648

49+
#[derive(Debug, Deserialize, Serialize)]
50+
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
51+
pub struct NeuronGeneral {
52+
pub python_depends: Option<Vec<String>>,
53+
}
54+
4755
#[derive(Debug, Deserialize, Serialize)]
4856
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
4957
pub struct XpuGeneral {
@@ -121,6 +129,7 @@ pub enum Backend {
121129
Cpu,
122130
Cuda,
123131
Metal,
132+
Neuron,
124133
Rocm,
125134
Xpu,
126135
}
@@ -150,6 +159,7 @@ impl From<General> for super::General {
150159
backends: general.backends.into_iter().map(Into::into).collect(),
151160
cuda: general.cuda.map(Into::into),
152161
hub: general.hub.map(Into::into),
162+
neuron: general.neuron.map(Into::into),
153163
python_depends: general.python_depends,
154164
xpu: general.xpu.map(Into::into),
155165
}
@@ -166,6 +176,14 @@ impl From<CudaGeneral> for super::CudaGeneral {
166176
}
167177
}
168178

179+
impl From<NeuronGeneral> for super::NeuronGeneral {
180+
fn from(neuron: NeuronGeneral) -> Self {
181+
Self {
182+
python_depends: neuron.python_depends,
183+
}
184+
}
185+
}
186+
169187
impl From<XpuGeneral> for super::XpuGeneral {
170188
fn from(xpu: XpuGeneral) -> Self {
171189
Self {
@@ -201,6 +219,7 @@ impl From<Backend> for super::Backend {
201219
Backend::Cpu => super::Backend::Cpu,
202220
Backend::Cuda => super::Backend::Cuda,
203221
Backend::Metal => super::Backend::Metal,
222+
Backend::Neuron => super::Backend::Neuron,
204223
Backend::Rocm => super::Backend::Rocm,
205224
Backend::Xpu => super::Backend::Xpu,
206225
}
@@ -304,6 +323,7 @@ impl From<super::General> for General {
304323
backends: general.backends.into_iter().map(Into::into).collect(),
305324
cuda: general.cuda.map(Into::into),
306325
hub: general.hub.map(Into::into),
326+
neuron: general.neuron.map(Into::into),
307327
python_depends: general.python_depends,
308328
xpu: general.xpu.map(Into::into),
309329
}
@@ -320,6 +340,14 @@ impl From<super::CudaGeneral> for CudaGeneral {
320340
}
321341
}
322342

343+
impl From<super::NeuronGeneral> for NeuronGeneral {
344+
fn from(neuron: super::NeuronGeneral) -> Self {
345+
Self {
346+
python_depends: neuron.python_depends,
347+
}
348+
}
349+
}
350+
323351
impl From<super::XpuGeneral> for XpuGeneral {
324352
fn from(xpu: super::XpuGeneral) -> Self {
325353
Self {
@@ -355,6 +383,7 @@ impl From<super::Backend> for Backend {
355383
super::Backend::Cpu => Backend::Cpu,
356384
super::Backend::Cuda => Backend::Cuda,
357385
super::Backend::Metal => Backend::Metal,
386+
super::Backend::Neuron => Backend::Neuron,
358387
super::Backend::Rocm => Backend::Rocm,
359388
super::Backend::Xpu => Backend::Xpu,
360389
}

build2cmake/src/python_dependencies.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,12 @@
1414
}
1515
},
1616
"metal": {},
17+
"neuron": {
18+
"nki": {
19+
"nix": [],
20+
"python": ["nki"]
21+
}
22+
},
1723
"rocm": {},
1824
"xpu": {
1925
"onednn": {

build2cmake/src/templates/noarch/setup.py

100755 → 100644
Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#!/usr/bin/env python
22

3-
import shutil
4-
from pathlib import Path
53
from typing import Any
4+
from pathlib import Path
5+
import shutil
6+
import sys
67

78
from setuptools import setup
89
from setuptools.command.build import build
@@ -30,7 +31,10 @@ def run(self) -> None:
3031
"""Execute the build command."""
3132
project_root = Path(__file__).parent
3233

33-
import tomllib
34+
if sys.version_info >= (3, 11):
35+
import tomllib
36+
else:
37+
import tomli as tomllib
3438

3539
with open(project_root / "build.toml", "rb") as f:
3640
build_toml: dict[str, Any] = tomllib.load(f)

builder/README.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ See [dockerfiles/README.md](./dockerfiles/README.md) for more options, including
6363
| XPU |||| 2 |
6464
| Metal |||| 2 |
6565
| Huawei NPU |||| 3 |
66+
| Neuron || x | x | 3 |
67+
68+
**Warning:** Neuron support is experimental and currently requires pre-release packages.
6669

6770
# 📚 Documentation
6871

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[general]
2+
name = "relu-nki"
3+
version = 1
4+
backends = [
5+
"neuron",
6+
]
7+
8+
[general.neuron]
9+
python-depends = ["nki"]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import nki
2+
import nki.language as nl
3+
import nki.isa as nisa
4+
5+
from ._ops import ops
6+
7+
8+
@nki.jit(platform_target="trn2")
9+
def relu(x):
10+
# Check the first dimension's size to ensure it does not exceed on-chip
11+
# memory tile size, since this simple kernel does not tile inputs.
12+
assert x.shape[0] <= nl.tile_size.pmax
13+
x_tile = sbuf.view(dtype=x.dtype, shape=x.shape)
14+
nisa.dma_copy(dst=x_tile, src=x)
15+
out_tile = sbuf.view(dtype=x.dtype, shape=x.shape)
16+
nisa.tensor_scalar(dst=out_tile, data=x_tile, operand0=0, op0=nl.maximum)
17+
c_output = hbm.view(dtype=x.dtype, shape=x.shape)
18+
nisa.dma_copy(dst=c_output, src=out_tile)
19+
return c_output
20+
21+
22+
from . import layers
23+
24+
__all__ = [
25+
"layers",
26+
"relu",
27+
]
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
from .. import relu
5+
6+
7+
class ReLU(nn.Module):
8+
def forward(self, x: torch.Tensor) -> torch.Tensor:
9+
return relu(x)

0 commit comments

Comments
 (0)