Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.

Commit 2ffd75b

Browse files
authored
Add support for building noarch kernels (#319)
This change adds support for building noarch kernels. So far we have used the universal variant for kernels that do not have any AoT-compiled code. However, the universal variant has two important issues:

1. A kernel without AoT-compiled code might still be backend-specific. E.g. NVIDIA CuTe-based kernels are not universal in the sense that they don't work on non-NVIDIA GPUs.
2. We cannot specify dependencies per backend.

To solve these issues, we introduce the noarch variants to replace universal kernels. Noarch kernels have variants of the shape `torch-<backend>` (e.g. `torch-xpu`). This resolves both issues outlined above.

To support noarch kernels, we update the `build.toml` format to v3, making the following changes:

* `general.universal` is removed.
* `general.backends` is introduced. This required option is used to list what backends the kernel supports.
* `general.cuda-{minver,maxver}` has been moved to the `general.cuda` section.

If a kernel supports backend X and has one or more `kernel.*` sections with `backend = "X"`, then the kernel is an AoT-compiled kernel for that backend. Otherwise, it is a noarch kernel for that backend. Suppose that we have:

```toml
[general]
# ...
backends = ["cuda", "xpu"]
# ...

[kernel.mykernel]
backend = "xpu"
# ...
```

then the XPU kernel will be AoT-compiled (e.g. `build/torch29-cxx11-xpu20252-x86_64-linux`), whereas the CUDA kernel will be noarch (`torch-cuda`).

An older `build.toml` can be updated automatically with `build2cmake update-build build.toml`.
1 parent 269aa54 commit 2ffd75b

File tree

44 files changed

+791
-397
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+791
-397
lines changed

.github/workflows/build_kernel.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,10 @@ jobs:
5858
- name: Test that we can build a test shell (e.g. that gcc corresponds to CUDA-required)
5959
run: ( cd examples/relu && nix build .#devShells.x86_64-linux.test )
6060

61-
- name: Build silu-and-mul-universal kernel
62-
run: ( cd examples/silu-and-mul-universal && nix build .\#redistributable.torch29-cxx11-cu126-x86_64-linux )
63-
- name: Copy silu-and-mul-universal kernel
64-
run: cp -rL examples/silu-and-mul-universal/result silu-and-mul-universal-kernel
61+
- name: Build silu-and-mul kernel
62+
run: ( cd examples/silu-and-mul && nix build .\#redistributable.torch-cuda )
63+
- name: Copy silu-and-mul kernel
64+
run: cp -rL examples/silu-and-mul/result silu-and-mul-kernel
6565

6666
- name: Upload kernel artifacts
6767
uses: actions/upload-artifact@v4
@@ -73,7 +73,7 @@ jobs:
7373
relu-kernel
7474
relu-kernel-cpu
7575
relu-backprop-compile-kernel
76-
silu-and-mul-universal-kernel
76+
silu-and-mul-kernel
7777
7878
test:
7979
name: Test kernels

.github/workflows/build_kernel_windows.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,5 +79,5 @@ jobs:
7979
# - name: Build relu kernel (specific Torch version)
8080
# run: ( cd examples/relu-specific-torch && nix build . )
8181

82-
- name: Build silu-and-mul-universal kernel
83-
run: ( scripts\windows\builder.ps1 -SourceFolder examples/silu-and-mul-universal -BuildConfig Release -Build -Force)
82+
- name: Build silu-and-mul kernel
83+
run: ( scripts\windows\builder.ps1 -SourceFolder examples/silu-and-mul -BuildConfig Release -Build -Force)

build2cmake/flake.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

build2cmake/src/config/common.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
use serde::{Deserialize, Serialize};
2+
3+
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
4+
#[non_exhaustive]
5+
#[serde(rename_all = "lowercase")]
6+
pub enum Dependency {
7+
#[serde(rename = "cutlass_2_10")]
8+
Cutlass2_10,
9+
#[serde(rename = "cutlass_3_5")]
10+
Cutlass3_5,
11+
#[serde(rename = "cutlass_3_6")]
12+
Cutlass3_6,
13+
#[serde(rename = "cutlass_3_8")]
14+
Cutlass3_8,
15+
#[serde(rename = "cutlass_3_9")]
16+
Cutlass3_9,
17+
#[serde(rename = "cutlass_4_0")]
18+
Cutlass4_0,
19+
#[serde(rename = "cutlass_sycl")]
20+
CutlassSycl,
21+
#[serde(rename = "metal-cpp")]
22+
MetalCpp,
23+
Torch,
24+
}

build2cmake/src/config/mod.rs

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
use eyre::Result;
22
use serde::Deserialize;
3+
use serde_value::Value;
34

45
pub mod v1;
56

7+
mod common;
8+
69
mod v2;
7-
use serde_value::Value;
8-
pub use v2::{Backend, Build, Dependency, General, Kernel, Torch};
10+
11+
mod v3;
12+
pub use common::Dependency;
13+
pub use v3::{Backend, Build, General, Kernel, Torch};
914

1015
#[derive(Debug)]
1116
pub enum BuildCompat {
1217
V1(v1::Build),
13-
V2(Build),
18+
V2(v2::Build),
19+
V3(Build),
1420
}
1521

1622
impl<'de> Deserialize<'de> for BuildCompat {
@@ -20,14 +26,11 @@ impl<'de> Deserialize<'de> for BuildCompat {
2026
{
2127
let value = Value::deserialize(deserializer)?;
2228

23-
match v1::Build::deserialize(value.clone()) {
24-
Ok(v1_build) => Ok(BuildCompat::V1(v1_build)),
25-
Err(_) => {
26-
let v2_build: Build =
27-
Build::deserialize(value).map_err(serde::de::Error::custom)?;
28-
Ok(BuildCompat::V2(v2_build))
29-
}
30-
}
29+
v1::Build::deserialize(value.clone())
30+
.map(BuildCompat::V1)
31+
.or_else(|_| v2::Build::deserialize(value.clone()).map(BuildCompat::V2))
32+
.or_else(|_| Build::deserialize(value.clone()).map(BuildCompat::V3))
33+
.map_err(serde::de::Error::custom)
3134
}
3235
}
3336

@@ -36,8 +39,12 @@ impl TryFrom<BuildCompat> for Build {
3639

3740
fn try_from(compat: BuildCompat) -> Result<Self> {
3841
match compat {
39-
BuildCompat::V1(v1_build) => v1_build.try_into(),
40-
BuildCompat::V2(v2_build) => Ok(v2_build),
42+
BuildCompat::V1(v1_build) => {
43+
let v2_build: v2::Build = v1_build.try_into()?;
44+
v2_build.try_into()
45+
}
46+
BuildCompat::V2(v2_build) => v2_build.try_into(),
47+
BuildCompat::V3(v3_build) => Ok(v3_build),
4148
}
4249
}
4350
}

build2cmake/src/config/v1.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::{collections::HashMap, fmt::Display, path::PathBuf};
22

33
use serde::Deserialize;
44

5-
use super::v2::Dependency;
5+
use super::common::Dependency;
66

77
#[derive(Debug, Deserialize)]
88
#[serde(deny_unknown_fields)]

build2cmake/src/config/v2.rs

Lines changed: 5 additions & 167 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,14 @@
1-
use std::{
2-
collections::{BTreeSet, HashMap},
3-
fmt::Display,
4-
path::PathBuf,
5-
str::FromStr,
6-
};
1+
use std::{collections::HashMap, fmt::Display, path::PathBuf};
72

83
use eyre::{bail, Result};
9-
use itertools::Itertools;
104
use serde::{Deserialize, Serialize};
115

126
use crate::version::Version;
137

14-
use super::v1::{self, Language};
8+
use super::{
9+
common::Dependency,
10+
v1::{self, Language},
11+
};
1512

1613
#[derive(Debug, Deserialize, Serialize)]
1714
#[serde(deny_unknown_fields)]
@@ -23,25 +20,6 @@ pub struct Build {
2320
pub kernels: HashMap<String, Kernel>,
2421
}
2522

26-
impl Build {
27-
pub fn has_kernel_with_backend(&self, backend: &Backend) -> bool {
28-
self.backends().contains(backend)
29-
}
30-
31-
pub fn backends(&self) -> BTreeSet<Backend> {
32-
self.kernels
33-
.values()
34-
.map(|kernel| match kernel {
35-
Kernel::Cpu { .. } => Backend::Cpu,
36-
Kernel::Cuda { .. } => Backend::Cuda,
37-
Kernel::Metal { .. } => Backend::Metal,
38-
Kernel::Rocm { .. } => Backend::Rocm,
39-
Kernel::Xpu { .. } => Backend::Xpu,
40-
})
41-
.collect()
42-
}
43-
}
44-
4523
#[derive(Debug, Deserialize, Serialize)]
4624
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
4725
pub struct General {
@@ -58,13 +36,6 @@ pub struct General {
5836
pub python_depends: Option<Vec<PythonDependency>>,
5937
}
6038

61-
impl General {
62-
/// Name of the kernel as a Python extension.
63-
pub fn python_name(&self) -> String {
64-
self.name.replace("-", "_")
65-
}
66-
}
67-
6839
#[derive(Debug, Deserialize, Serialize)]
6940
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
7041
pub struct Hub {
@@ -100,27 +71,6 @@ pub struct Torch {
10071
pub src: Vec<PathBuf>,
10172
}
10273

103-
impl Torch {
104-
pub fn data_globs(&self) -> Option<Vec<String>> {
105-
match self.pyext.as_ref() {
106-
Some(exts) => {
107-
let globs = exts
108-
.iter()
109-
.filter(|&ext| ext != "py" && ext != "pyi")
110-
.map(|ext| format!("\"**/*.{ext}\""))
111-
.collect_vec();
112-
if globs.is_empty() {
113-
None
114-
} else {
115-
Some(globs)
116-
}
117-
}
118-
119-
None => None,
120-
}
121-
}
122-
}
123-
12474
#[derive(Debug, Deserialize, Serialize)]
12575
#[serde(deny_unknown_fields, rename_all = "kebab-case", tag = "backend")]
12676
pub enum Kernel {
@@ -167,118 +117,6 @@ pub enum Kernel {
167117
},
168118
}
169119

170-
impl Kernel {
171-
pub fn cxx_flags(&self) -> Option<&[String]> {
172-
match self {
173-
Kernel::Cpu { cxx_flags, .. }
174-
| Kernel::Cuda { cxx_flags, .. }
175-
| Kernel::Metal { cxx_flags, .. }
176-
| Kernel::Rocm { cxx_flags, .. }
177-
| Kernel::Xpu { cxx_flags, .. } => cxx_flags.as_deref(),
178-
}
179-
}
180-
181-
pub fn include(&self) -> Option<&[String]> {
182-
match self {
183-
Kernel::Cpu { include, .. }
184-
| Kernel::Cuda { include, .. }
185-
| Kernel::Metal { include, .. }
186-
| Kernel::Rocm { include, .. }
187-
| Kernel::Xpu { include, .. } => include.as_deref(),
188-
}
189-
}
190-
191-
pub fn backend(&self) -> Backend {
192-
match self {
193-
Kernel::Cpu { .. } => Backend::Cpu,
194-
Kernel::Cuda { .. } => Backend::Cuda,
195-
Kernel::Metal { .. } => Backend::Metal,
196-
Kernel::Rocm { .. } => Backend::Rocm,
197-
Kernel::Xpu { .. } => Backend::Xpu,
198-
}
199-
}
200-
201-
pub fn depends(&self) -> &[Dependency] {
202-
match self {
203-
Kernel::Cpu { depends, .. }
204-
| Kernel::Cuda { depends, .. }
205-
| Kernel::Metal { depends, .. }
206-
| Kernel::Rocm { depends, .. }
207-
| Kernel::Xpu { depends, .. } => depends,
208-
}
209-
}
210-
211-
pub fn src(&self) -> &[String] {
212-
match self {
213-
Kernel::Cpu { src, .. }
214-
| Kernel::Cuda { src, .. }
215-
| Kernel::Metal { src, .. }
216-
| Kernel::Rocm { src, .. }
217-
| Kernel::Xpu { src, .. } => src,
218-
}
219-
}
220-
}
221-
222-
#[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)]
223-
#[serde(deny_unknown_fields, rename_all = "kebab-case")]
224-
pub enum Backend {
225-
Cpu,
226-
Cuda,
227-
Metal,
228-
Rocm,
229-
Xpu,
230-
}
231-
232-
impl Display for Backend {
233-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
234-
match self {
235-
Backend::Cpu => write!(f, "cpu"),
236-
Backend::Cuda => write!(f, "cuda"),
237-
Backend::Metal => write!(f, "metal"),
238-
Backend::Rocm => write!(f, "rocm"),
239-
Backend::Xpu => write!(f, "xpu"),
240-
}
241-
}
242-
}
243-
244-
impl FromStr for Backend {
245-
type Err = String;
246-
247-
fn from_str(s: &str) -> Result<Self, Self::Err> {
248-
match s.to_lowercase().as_str() {
249-
"cpu" => Ok(Backend::Cpu),
250-
"cuda" => Ok(Backend::Cuda),
251-
"metal" => Ok(Backend::Metal),
252-
"rocm" => Ok(Backend::Rocm),
253-
"xpu" => Ok(Backend::Xpu),
254-
_ => Err(format!("Unknown backend: {s}")),
255-
}
256-
}
257-
}
258-
259-
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
260-
#[non_exhaustive]
261-
#[serde(rename_all = "lowercase")]
262-
pub enum Dependency {
263-
#[serde(rename = "cutlass_2_10")]
264-
Cutlass2_10,
265-
#[serde(rename = "cutlass_3_5")]
266-
Cutlass3_5,
267-
#[serde(rename = "cutlass_3_6")]
268-
Cutlass3_6,
269-
#[serde(rename = "cutlass_3_8")]
270-
Cutlass3_8,
271-
#[serde(rename = "cutlass_3_9")]
272-
Cutlass3_9,
273-
#[serde(rename = "cutlass_4_0")]
274-
Cutlass4_0,
275-
#[serde(rename = "cutlass_sycl")]
276-
CutlassSycl,
277-
#[serde(rename = "metal-cpp")]
278-
MetalCpp,
279-
Torch,
280-
}
281-
282120
impl TryFrom<v1::Build> for Build {
283121
type Error = eyre::Error;
284122

0 commit comments

Comments
 (0)