Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.

Commit 9a3e358

Browse files
authored
Add support for Torch 2.10 RC3 (#334)
Add support for Torch 2.10. Related changes in this PR:

- `xpuPackages_2025_3`: init at 2025.3.1
- `rocmPackages_7_0`: 7.0.1 -> 7.0.2
- `rocmPackages_7_1`: init at 7.1.1
1 parent 2d8e2b1 commit 9a3e358

29 files changed

+8877
-3804
lines changed

README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
</div>
1111
<hr/>
1212

13+
**Note:** Torch 2.10 builds are still based on PyTorch release candidates.
14+
Typically the ABI does not break during release candidates. If it does,
15+
you have to recompile your kernels with the final 2.10.0 release.
16+
1317
[Join us on Discord](https://discord.gg/H6Tkmd88N3) for questions and discussions!
1418

1519
This repo contains a Nix package that can be used to build custom machine learning kernels for PyTorch. The kernels are built using the [PyTorch C++ Frontend](https://pytorch.org/cppdocs/frontend.html) and can be loaded from the Hub with the [kernels](https://github.com/huggingface/kernels)

build-variants.json

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
{
22
"aarch64-darwin": {
33
"cpu": [
4+
"torch210-cpu-aarch64-darwin",
45
"torch28-cpu-aarch64-darwin",
56
"torch29-cpu-aarch64-darwin"
67
],
78
"metal": [
9+
"torch210-metal-aarch64-darwin",
810
"torch28-metal-aarch64-darwin",
911
"torch29-metal-aarch64-darwin"
1012
]
1113
},
1214
"aarch64-linux": {
1315
"cpu": [
16+
"torch210-cxx11-cpu-aarch64-linux",
1417
"torch28-cxx11-cpu-aarch64-linux",
1518
"torch29-cxx11-cpu-aarch64-linux"
1619
],
1720
"cuda": [
21+
"torch210-cxx11-cu126-aarch64-linux",
22+
"torch210-cxx11-cu128-aarch64-linux",
23+
"torch210-cxx11-cu130-aarch64-linux",
1824
"torch28-cxx11-cu129-aarch64-linux",
1925
"torch29-cxx11-cu126-aarch64-linux",
2026
"torch29-cxx11-cu128-aarch64-linux",
@@ -23,10 +29,14 @@
2329
},
2430
"x86_64-linux": {
2531
"cpu": [
32+
"torch210-cxx11-cpu-x86_64-linux",
2633
"torch28-cxx11-cpu-x86_64-linux",
2734
"torch29-cxx11-cpu-x86_64-linux"
2835
],
2936
"cuda": [
37+
"torch210-cxx11-cu126-x86_64-linux",
38+
"torch210-cxx11-cu128-x86_64-linux",
39+
"torch210-cxx11-cu130-x86_64-linux",
3040
"torch28-cxx11-cu126-x86_64-linux",
3141
"torch28-cxx11-cu128-x86_64-linux",
3242
"torch28-cxx11-cu129-x86_64-linux",
@@ -35,12 +45,15 @@
3545
"torch29-cxx11-cu130-x86_64-linux"
3646
],
3747
"rocm": [
48+
"torch210-cxx11-rocm70-x86_64-linux",
49+
"torch210-cxx11-rocm71-x86_64-linux",
3850
"torch28-cxx11-rocm63-x86_64-linux",
3951
"torch28-cxx11-rocm64-x86_64-linux",
4052
"torch29-cxx11-rocm63-x86_64-linux",
4153
"torch29-cxx11-rocm64-x86_64-linux"
4254
],
4355
"xpu": [
56+
"torch210-cxx11-xpu20253-x86_64-linux",
4457
"torch28-cxx11-xpu20251-x86_64-linux",
4558
"torch29-cxx11-xpu20252-x86_64-linux"
4659
]

build2cmake/src/templates/xpu/dep-cutlass-sycl.cmake

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
find_package(CutlassSycl)
22

3-
if(DPCPP_VERSION STREQUAL "2025.2")
3+
if(DPCPP_VERSION STREQUAL "2025.3")
4+
set(CUTLASS_SYCL_REVISION "14055e78510b8776ba739755eb57e592fdceefdb" CACHE STRING "CUTLASS revision to use")
5+
elseif(DPCPP_VERSION STREQUAL "2025.2")
46
set(CUTLASS_SYCL_REVISION "14055e78510b8776ba739755eb57e592fdceefdb" CACHE STRING "CUTLASS revision to use")
57
elseif(DPCPP_VERSION STREQUAL "2025.1")
68
set(CUTLASS_SYCL_REVISION "v3.9-0.3" CACHE STRING "CUTLASS revision to use")
@@ -67,7 +69,7 @@ endif()
6769
string(REPLACE "-fsycl-targets=spir64_gen,spir64" "-fsycl-targets=spir64" sycl_link_flags "${sycl_link_flags}")
6870
string(REPLACE "-device pvc,xe-lpg,ats-m150" "-device bmg_g21,pvc" sycl_link_flags "${sycl_link_flags}")
6971
string(APPEND sycl_link_flags "-Xspirv-translator;-spirv-ext=+SPV_INTEL_split_barrier")
70-
if(DPCPP_VERSION STREQUAL "2025.2" OR CUTLASS_SYCL_REVISION STREQUAL "v0.5")
72+
if(DPCPP_VERSION STREQUAL "2025.2" OR DPCPP_VERSION STREQUAL "2025.3" OR CUTLASS_SYCL_REVISION STREQUAL "v0.5")
7173
string(APPEND sycl_link_flags ",+SPV_INTEL_2d_block_io,+SPV_INTEL_subgroup_matrix_multiply_accumulate")
7274
endif()
7375
string(REPLACE "-fsycl-targets=spir64_gen,spir64" "-fsycl-targets=spir64" sycl_flags "${sycl_flags}")

docs/build-variants.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,33 +7,43 @@ available. This list will be updated as new PyTorch versions are released.
77

88
## CPU aarch64-darwin
99

10+
- `torch210-cpu-aarch64-darwin`
1011
- `torch28-cpu-aarch64-darwin`
1112
- `torch29-cpu-aarch64-darwin`
1213

1314
## Metal aarch64-darwin
1415

16+
- `torch210-metal-aarch64-darwin`
1517
- `torch28-metal-aarch64-darwin`
1618
- `torch29-metal-aarch64-darwin`
1719

1820
## CPU aarch64-linux
1921

22+
- `torch210-cxx11-cpu-aarch64-linux`
2023
- `torch28-cxx11-cpu-aarch64-linux`
2124
- `torch29-cxx11-cpu-aarch64-linux`
2225

2326
## CUDA aarch64-linux
2427

28+
- `torch210-cxx11-cu126-aarch64-linux`
29+
- `torch210-cxx11-cu128-aarch64-linux`
30+
- `torch210-cxx11-cu130-aarch64-linux`
2531
- `torch28-cxx11-cu129-aarch64-linux`
2632
- `torch29-cxx11-cu126-aarch64-linux`
2733
- `torch29-cxx11-cu128-aarch64-linux`
2834
- `torch29-cxx11-cu130-aarch64-linux`
2935

3036
## CPU x86_64-linux
3137

38+
- `torch210-cxx11-cpu-x86_64-linux`
3239
- `torch28-cxx11-cpu-x86_64-linux`
3340
- `torch29-cxx11-cpu-x86_64-linux`
3441

3542
## CUDA x86_64-linux
3643

44+
- `torch210-cxx11-cu126-x86_64-linux`
45+
- `torch210-cxx11-cu128-x86_64-linux`
46+
- `torch210-cxx11-cu130-x86_64-linux`
3747
- `torch28-cxx11-cu126-x86_64-linux`
3848
- `torch28-cxx11-cu128-x86_64-linux`
3949
- `torch28-cxx11-cu129-x86_64-linux`
@@ -43,13 +53,16 @@ available. This list will be updated as new PyTorch versions are released.
4353

4454
## ROCm x86_64-linux
4555

56+
- `torch210-cxx11-rocm70-x86_64-linux`
57+
- `torch210-cxx11-rocm71-x86_64-linux`
4658
- `torch28-cxx11-rocm63-x86_64-linux`
4759
- `torch28-cxx11-rocm64-x86_64-linux`
4860
- `torch29-cxx11-rocm63-x86_64-linux`
4961
- `torch29-cxx11-rocm64-x86_64-linux`
5062

5163
## XPU x86_64-linux
5264

65+
- `torch210-cxx11-xpu20253-x86_64-linux`
5366
- `torch28-cxx11-xpu20251-x86_64-linux`
5467
- `torch29-cxx11-xpu20252-x86_64-linux`
5568

flake.lock

Lines changed: 7 additions & 7 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

flake.nix

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33

44
inputs = {
55
flake-utils.url = "github:numtide/flake-utils";
6-
# Put back to nixos-unstable-small the next bump. Exact revision is
7-
# to avoid a rebuild during the hf-nix -> kernel-builder transition.
8-
nixpkgs.url = "github:NixOS/nixpkgs/c543a59edf25ada193719764f3bc0c6ba835f94d";
6+
nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable-small";
97
flake-compat.url = "github:edolstra/flake-compat";
108
};
119

lib/torch-extension/default.nix

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
rocmSupport ? torch.rocmSupport,
33
xpuSupport ? torch.xpuSupport,
44

5+
pkgs,
56
lib,
67
callPackage,
78
stdenv,
@@ -30,7 +31,15 @@ let
3031
);
3132

3233
cuda_nvcc = cudaPackages.cuda_nvcc.override {
33-
backendStdenv = cudaPackages.backendStdenv.override {
34+
backendStdenv = import ../../pkgs/cuda/backendStdenv {
35+
inherit (pkgs)
36+
_cuda
37+
config
38+
lib
39+
pkgs
40+
stdenvAdapters
41+
;
42+
inherit (cudaPackages) cudaMajorMinorVersion;
3443
stdenv = effectiveStdenv;
3544
};
3645
};

overlay.nix

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,22 @@ in
4545

4646
remove-bytecode-hook = prev.callPackage ./pkgs/remove-bytecode-hook { };
4747

48-
stdenvGlibc_2_27 = prev.callPackage ./pkgs/stdenv-glibc-2_27 { };
48+
stdenvGlibc_2_27 = import ./pkgs/stdenv-glibc-2_27 {
49+
# Do not use callPackage, because we want overrides to apply to
50+
# the stdenv itself and not this file.
51+
inherit (final)
52+
config
53+
fetchFromGitHub
54+
overrideCC
55+
wrapBintoolsWith
56+
wrapCCWith
57+
gcc13Stdenv
58+
stdenv
59+
bintools-unwrapped
60+
cudaPackages
61+
libgcc
62+
;
63+
};
4964

5065
ucx = prev.ucx.overrideAttrs (
5166
_: prevAttrs: {
@@ -107,6 +122,11 @@ in
107122
xpuPackages = final.xpuPackages_2025_2;
108123
};
109124

125+
torch-bin_2_10 = mkTorch {
126+
version = "2.10";
127+
xpuPackages = final.xpuPackages_2025_3;
128+
};
129+
110130
torch_2_8 = callPackage ./pkgs/python-modules/torch/source/2_8 {
111131
xpuPackages = final.xpuPackages_2025_1;
112132
};
@@ -139,7 +159,8 @@ in
139159
versions = [
140160
"6.3.4"
141161
"6.4.2"
142-
"7.0.1"
162+
"7.0.2"
163+
"7.1.1"
143164
];
144165
newRocmPackages = final.callPackage ./pkgs/rocm-packages { };
145166
in
@@ -159,6 +180,7 @@ in
159180
xpuVersions = [
160181
"2025.1.3"
161182
"2025.2.1"
183+
"2025.3.1"
162184
];
163185
newXpuPackages = final.callPackage ./pkgs/xpu-packages { };
164186
in

pkgs/aotriton/default.nix

Lines changed: 59 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ let
1717
find "$out" -name .git -print0 | xargs -0 rm -rf
1818
'';
1919
mkImages =
20-
srcs:
20+
version: srcs:
2121
stdenvNoCC.mkDerivation {
22-
name = "images";
22+
name = "images-${version}";
2323

2424
inherit srcs;
2525

@@ -69,7 +69,7 @@ in
6969
"gfx1201"
7070
];
7171

72-
images = mkImages [
72+
images = mkImages version [
7373
(fetchurl {
7474
url = "https://github.com/ROCm/aotriton/releases/download/0.10b/aotriton-0.10b-manylinux_2_28_x86_64-rocm6.3-shared.tar.gz";
7575
hash = "sha256-hhzZ90ee7JQ5M8J8uGkgJH5bXdE5vHwTdsgYCKu31/4=";
@@ -107,7 +107,7 @@ in
107107
"gfx1201"
108108
];
109109

110-
images = mkImages [
110+
images = mkImages version [
111111
(fetchurl {
112112
url = "https://github.com/ROCm/aotriton/releases/download/0.11b/aotriton-0.11b-images-amd-gfx90a.tar.gz";
113113
hash = "sha256-wZpByUgFEKsy5vsF5u0KODLWsHY08FC4NrdgIAvvpzU=";
@@ -132,4 +132,59 @@ in
132132

133133
extraPythonDepends = ps: [ ps.pandas ];
134134
};
135+
136+
aotriton_0_11_1 = generic rec {
137+
version = "0.11.1b";
138+
139+
src = fetchFromGitHub {
140+
owner = "ROCm";
141+
repo = "aotriton";
142+
rev = version;
143+
hash = "sha256-F7JjyS+6gMdCpOFLldTsNJdVzzVwd6lwW7+V8ZOZfig=";
144+
leaveDotGit = true;
145+
inherit postFetch;
146+
};
147+
148+
patches = [
149+
# Fails with: ld.lld: error: unable to insert .comment after .comment
150+
./v0.11.1b-no-ld-script.diff
151+
];
152+
153+
gpuTargets = [
154+
# aotriton GPU support list:
155+
# https://github.com/ROCm/aotriton/blob/main/v2python/gpu_targets.py
156+
"gfx90a"
157+
"gfx942"
158+
"gfx950"
159+
"gfx1100"
160+
"gfx1151"
161+
"gfx1201"
162+
];
163+
164+
images = mkImages version [
165+
(fetchurl {
166+
url = "https://github.com/ROCm/aotriton/releases/download/0.11.1b/aotriton-0.11.1b-images-amd-gfx90a.tar.gz";
167+
hash = "sha256-/p8Etmv1KsJ80CXh2Jz9BJdN0/s64HYZL3g2QaTYD98=";
168+
})
169+
(fetchurl {
170+
url = "https://github.com/ROCm/aotriton/releases/download/0.11.1b/aotriton-0.11.1b-images-amd-gfx942.tar.gz";
171+
hash = "sha256-CnvO4Z07ttVIcyJIwyNPe5JzbCq3p6rmUpS4en/WTAY=";
172+
})
173+
(fetchurl {
174+
url = "https://github.com/ROCm/aotriton/releases/download/0.11.1b/aotriton-0.11.1b-images-amd-gfx950.tar.gz";
175+
hash = "sha256-wbo7/oQhf9Z9890fi2fICn97M9CtTXS0HWVnA24DKs4=";
176+
})
177+
(fetchurl {
178+
url = "https://github.com/ROCm/aotriton/releases/download/0.11.1b/aotriton-0.11.1b-images-amd-gfx11xx.tar.gz";
179+
hash = "sha256-ZjIEDEBdgzvm/3ICkknHdoOLr18Do8E7pOjTeoe3p0A=";
180+
})
181+
(fetchurl {
182+
url = "https://github.com/ROCm/aotriton/releases/download/0.11.1b/aotriton-0.11.1b-images-amd-gfx120x.tar.gz";
183+
hash = "sha256-Ck/zJL/9rAwv3oeop/cFY9PISoCtTo8xNF8rQKE4TpU=";
184+
})
185+
];
186+
187+
extraPythonDepends = ps: [ ps.pandas ];
188+
};
189+
135190
}

0 commit comments

Comments
 (0)