Skip to content

Commit 3c47763

Browse files
Mogball (meta-codesync[bot])
authored and committed
[Cherry-pick] [Dialect] Make warp specialization require at least 4 warps (#8005) (#548)
Summary: Cherry-picked from upstream OAI repository. Original Commit: cfc0a9d Original Author: Jeff Niu Original Date: 2025-08-29 09:45:18 -0700 Original commit message: ``` [Dialect] Make warp specialization require at least 4 warps (#8005) The warpgroup allocator makes fairly strong assumptions that the default number of warps is at least 4. Untangling this is non-trivial (see #7940 which is WIP to fix this). For now, just add an error message to prevent the compiler from crashing and confusing users. ``` This PR was automatically cherry-picked from the upstream triton-lang/triton repository. Pull Request resolved: #548 Reviewed By: agron911, htyu Differential Revision: D85908114 Pulled By: dshi7 fbshipit-source-id: af0a7353916a7ac70a151129d26357dee2f83518
1 parent 380d218 commit 3c47763

File tree

7 files changed

+74
-65
lines changed

7 files changed

+74
-65
lines changed

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -908,6 +908,14 @@ LogicalResult WarpSpecializeOp::verify() {
908908
"cannot be nested inside another `ttg.warp_specialize` op");
909909
}
910910

911+
std::optional<int> numWarps = maybeLookupNumWarps(*this);
912+
if (numWarps && *numWarps % 4 != 0) {
913+
return mlir::emitError(getLoc())
914+
<< "warp-specialized kernels requires "
915+
"num_warps to be a multiple of 4 but num_warps="
916+
<< *numWarps;
917+
}
918+
911919
return success();
912920
}
913921

python/test/unit/language/test_tlx.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1661,7 +1661,7 @@ def tcgen5_fa_kernel(a_ptr, stride_am, stride_ak, b_ptr, stride_bk, stride_bn, c
16611661

16621662
kern_kwargs = {'BLOCK_M': M, 'BLOCK_K': K, 'BLOCK_N': N}
16631663
kernel = tcgen5_fa_kernel[(1, 1)](a, a.stride(0), a.stride(1), b, b.stride(0), b.stride(1), c, c.stride(0),
1664-
c.stride(1), d, d.stride(0), d.stride(1), **kern_kwargs, num_warps=1)
1664+
c.stride(1), d, d.stride(0), d.stride(1), **kern_kwargs, num_warps=4)
16651665

16661666
ttgir = kernel.asm["ttgir"]
16671667
assert ttgir.count("ttng.tmem_alloc") == 2

test/Conversion/triton_to_tritongpu.mlir

Lines changed: 0 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -168,64 +168,3 @@ tt.func @cf_br(%ptr: !tt.ptr<i32>) {
168168
tt.store %ptrs, %arg0 : tensor<128x!tt.ptr<i32>>
169169
tt.return
170170
}
171-
172-
// -----
173-
174-
// CHECK-LABEL: @legalize_warp_specialize
175-
tt.func @legalize_warp_specialize(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>) {
176-
ttg.warp_specialize(%arg0)
177-
default {
178-
ttg.warp_yield
179-
}
180-
partition0(%arg2: !tt.ptr<i32>) num_warps(2) {
181-
// CHECK: tt.splat {{.*}} : !tt.ptr<i32> -> tensor<256x!tt.ptr<i32>, #blocked>
182-
// CHECK: tt.load {{.*}} : tensor<256x!tt.ptr<i32>, #blocked>
183-
%splatted = tt.splat %arg2 : !tt.ptr<i32> -> tensor<256x!tt.ptr<i32>>
184-
%input = tt.load %splatted : tensor<256x!tt.ptr<i32>>
185-
ttg.warp_return
186-
} : (!tt.ptr<i32>) -> ()
187-
tt.return
188-
}
189-
190-
191-
// -----
192-
// CHECK-DAG: [[DEFAULT:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [2], order = [0]}>
193-
// CHECK-DAG: [[WS1:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
194-
// CHECK: @legalize_warp_partition
195-
module attributes {tlx.has_warp_spec_ops = true, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
196-
tt.func public @legalize_warp_partition(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
197-
%c1024_i32 = arith.constant 1024 : i32
198-
%0 = tt.get_program_id x : i32
199-
%1 = arith.muli %0, %c1024_i32 : i32
200-
ttg.warp_specialize(%arg3, %1, %arg5)
201-
// CHECK: default
202-
default {
203-
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
204-
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
205-
%4 = arith.addi %3, %2 : tensor<1024xi32>
206-
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
207-
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
208-
// CHECK: tt.load {{.*}} : tensor<1024x!tt.ptr<f32>, [[DEFAULT]]
209-
%7 = tt.load %6 : tensor<1024x!tt.ptr<f32>>
210-
%8 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
211-
%9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
212-
tt.store %9, %7 : tensor<1024x!tt.ptr<f32>>
213-
ttg.warp_yield
214-
}
215-
// CHECK: partition0
216-
partition0(%arg7: !tt.ptr<f32>, %arg8: i32, %arg9: !tt.ptr<f32>) num_warps(1) {
217-
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
218-
%3 = tt.splat %arg8 : i32 -> tensor<1024xi32>
219-
%4 = arith.addi %3, %2 : tensor<1024xi32>
220-
%5 = tt.splat %arg7 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
221-
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
222-
// CHECK: tt.load {{.*}} : tensor<1024x!tt.ptr<f32>, [[WS1]]
223-
%7 = tt.load %6 : tensor<1024x!tt.ptr<f32>>
224-
%8 = tt.splat %arg9 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
225-
%9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
226-
tt.store %9, %7 : tensor<1024x!tt.ptr<f32>>
227-
ttg.warp_return
228-
} : (!tt.ptr<f32>, i32, !tt.ptr<f32>) -> ()
229-
tt.return
230-
}
231-
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
// RUN: triton-opt %s -split-input-file -convert-triton-to-tritongpu='target=cuda:80 num-warps=4' | FileCheck %s
2+
3+
// CHECK-LABEL: @legalize_warp_specialize
4+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
5+
tt.func @legalize_warp_specialize(%arg0: !tt.ptr<i32>, %arg1: !tt.ptr<i32>) {
6+
ttg.warp_specialize(%arg0)
7+
default {
8+
ttg.warp_yield
9+
}
10+
partition0(%arg2: !tt.ptr<i32>) num_warps(2) {
11+
// CHECK: tt.splat {{.*}} : !tt.ptr<i32> -> tensor<256x!tt.ptr<i32>, #blocked>
12+
// CHECK: tt.load {{.*}} : tensor<256x!tt.ptr<i32>, #blocked>
13+
%splatted = tt.splat %arg2 : !tt.ptr<i32> -> tensor<256x!tt.ptr<i32>>
14+
%input = tt.load %splatted : tensor<256x!tt.ptr<i32>>
15+
ttg.warp_return
16+
} : (!tt.ptr<i32>) -> ()
17+
tt.return
18+
}
19+
}
20+
21+
22+
// -----
23+
// CHECK-DAG: [[DEFAULT:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
24+
// CHECK-DAG: [[WS1:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [1], order = [0]}>
25+
// CHECK: @legalize_warp_partition
26+
module attributes {tlx.has_warp_spec_ops = true, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
27+
tt.func public @legalize_warp_partition(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
28+
%c1024_i32 = arith.constant 1024 : i32
29+
%0 = tt.get_program_id x : i32
30+
%1 = arith.muli %0, %c1024_i32 : i32
31+
ttg.warp_specialize(%arg3, %1, %arg5)
32+
// CHECK: default
33+
default {
34+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
35+
%3 = tt.splat %1 : i32 -> tensor<1024xi32>
36+
%4 = arith.addi %3, %2 : tensor<1024xi32>
37+
%5 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
38+
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
39+
// CHECK: tt.load {{.*}} : tensor<1024x!tt.ptr<f32>, [[DEFAULT]]
40+
%7 = tt.load %6 : tensor<1024x!tt.ptr<f32>>
41+
%8 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
42+
%9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
43+
tt.store %9, %7 : tensor<1024x!tt.ptr<f32>>
44+
ttg.warp_yield
45+
}
46+
// CHECK: partition0
47+
partition0(%arg7: !tt.ptr<f32>, %arg8: i32, %arg9: !tt.ptr<f32>) num_warps(1) {
48+
%2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32>
49+
%3 = tt.splat %arg8 : i32 -> tensor<1024xi32>
50+
%4 = arith.addi %3, %2 : tensor<1024xi32>
51+
%5 = tt.splat %arg7 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
52+
%6 = tt.addptr %5, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
53+
// CHECK: tt.load {{.*}} : tensor<1024x!tt.ptr<f32>, [[WS1]]
54+
%7 = tt.load %6 : tensor<1024x!tt.ptr<f32>>
55+
%8 = tt.splat %arg9 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>>
56+
%9 = tt.addptr %8, %4 : tensor<1024x!tt.ptr<f32>>, tensor<1024xi32>
57+
tt.store %9, %7 : tensor<1024x!tt.ptr<f32>>
58+
ttg.warp_return
59+
} : (!tt.ptr<f32>, i32, !tt.ptr<f32>) -> ()
60+
tt.return
61+
}
62+
}

test/TLX/propagate-layout.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,7 @@ module attributes {tlx.has_explicit_local_mem_access = true, "ttg.num-ctas" = 1
558558
// CHECK-DAG: #[[$TMEM:.*]] = #ttng.tensor_memory_encoding<blockM = 64, blockN = 32, unpacked = true>
559559
// CHECK-DAG: #[[$TMEM1:.*]] = #ttng.tensor_memory_encoding<blockM = 64, blockN = 32, unpacked = false>
560560

561-
module attributes {tlx.has_explicit_local_mem_access = true, tlx.has_tlx_ops = true, tlx.has_warp_spec_ops = true, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
561+
module attributes {tlx.has_explicit_local_mem_access = true, tlx.has_tlx_ops = true, tlx.has_warp_spec_ops = true, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
562562
// CHECK-LABEL: @tcgen5_fa_kernel
563563
tt.func public @tcgen5_fa_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
564564
%0 = ttg.local_alloc : () -> !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable>

test/TLX/rewrite-local-alias.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
// CHECK-DAG: #[[$TMEM:.*]] = #ttng.tensor_memory_encoding<blockM = 64, blockN = 32, unpacked = true>
1414
// CHECK-DAG: #[[$TMEM1:.*]] = #ttng.tensor_memory_encoding<blockM = 64, blockN = 32, unpacked = false>
1515

16-
module attributes {tlx.has_explicit_local_mem_access = true, tlx.has_tlx_ops = true, tlx.has_warp_spec_ops = true, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
16+
module attributes {tlx.has_explicit_local_mem_access = true, tlx.has_tlx_ops = true, tlx.has_warp_spec_ops = true, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
1717
// CHECK-LABEL: @tcgen5_fa_kernel
1818
tt.func public @tcgen5_fa_kernel(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: i32 {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
1919
// CHECK: %[[$LOCAL_ALLOC:.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x64x16xf16, #[[$SHARED]], #smem, mutable>

test/TLX/tlx-verifier.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11

22
// RUN: triton-opt -split-input-file -pass-pipeline='builtin.module(triton-tlx-fixup{num-warps=8 target=cuda:90 num-ctas=2 threads-per-warp=32})' --verify-diagnostics %s
33

4-
module attributes {tlx.has_warp_spec_ops = true, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
4+
module attributes {tlx.has_warp_spec_ops = true, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
55
tt.func public @legalize_warp_partition(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg5: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
66
%c1024_i32 = arith.constant 1024 : i32
77
%0 = tt.get_program_id x : i32

0 commit comments

Comments
 (0)