Commit a37bbdd

[BACKEND] Fix codegen for ScanOp when there are redundant threads (triton-lang#5641)
This was a mildly tricky bug to track down. Groups of threads holding redundant data weren't being masked out, so they shuffled in values from threads they shouldn't have and accumulated them. E.g. with 32 threads where the first 16 hold unique data and the second half are replicas, lane 16 would shuffle in data from lanes 15, 14, 12, etc. and add them in. If the scan result is used in a way that simply discards the redundant data, such as a store to global memory, the invalid values are never observed; the case that exposed this was a broadcast of the result, which made the invalid values visible.
Parent: d907d46
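
To make the failure mode concrete, here is a small Python model of an inclusive Hillis-Steele warp scan under the scenario from the message above: 32 lanes, the first 16 unique and the upper 16 replicas. This is a sketch of the technique, not the actual lowering; `clamp_lanes=True` models the fix of masking off the free lane-ID bit so each replica lane behaves exactly like the lane it mirrors.

WARP, UNIQUE = 32, 16
data = list(range(1, UNIQUE + 1)) * (WARP // UNIQUE)  # lanes 16..31 replicate lanes 0..15

def warp_scan(clamp_lanes: bool) -> list:
    """Inclusive Hillis-Steele scan; each step shuffles a value up by `offset`."""
    acc = list(data)
    offset = 1
    while offset < UNIQUE:
        # Every lane reads its shuffle source before any lane writes,
        # mimicking a warp-synchronous shfl.up.
        shuffled = [acc[i - offset] if i >= offset else 0 for i in range(WARP)]
        for i in range(WARP):
            # The fix: clear the lane-ID bit that only selects replicas,
            # so lane 16 behaves as lane 0, lane 17 as lane 1, and so on.
            lane = i & (UNIQUE - 1) if clamp_lanes else i
            if lane >= offset:
                acc[i] += shuffled[i]
        offset *= 2
    return acc

expected = [sum(data[:i % UNIQUE + 1]) for i in range(WARP)]
assert warp_scan(clamp_lanes=True) == expected
# Without the mask, lane 16 pulls in values from lanes 15, 14, 12, and 8.
assert warp_scan(clamp_lanes=False) != expected

With clamping, the replica block performs exactly the same reads and writes as lanes 0..15 and therefore reproduces their results.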

File tree: 4 files changed, +115 -16 lines

lib/Analysis/Utility.cpp

Lines changed: 5 additions & 1 deletion

@@ -290,7 +290,7 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
 bool ScanLoweringHelper::isSupported() {
   // TODO: Support the following cases:
   // 1. Scan on non-blocking encodings
-  if (!isa<BlockedEncodingAttr>(getEncoding()))
+  if (!isa<BlockedEncodingAttr>(srcEncoding))
     return false;
   return true;
 }

@@ -306,6 +306,10 @@ unsigned ScanLoweringHelper::getScratchSizeInElems() {
 }

 unsigned ScanLoweringHelper::getScratchSizeInBytes() {
+  // Lowering will fail later if the layout is not supported.
+  if (!isSupported())
+    return 0;
+
   unsigned axisNumWarps = getAxisNumWarpsWithUniqueData();
   if (axisNumWarps == 1)
     return 0;
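
The early return here matters because getScratchSizeInBytes is consulted by shared-memory allocation, which runs before the conversion to LLVM (note the --allocate-shared-memory --convert-triton-gpu-to-llvm pipeline in the new lit test below); returning 0 for an unsupported layout presumably lets allocation proceed cleanly and defers the failure to the lowering, which now emits a proper diagnostic.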

lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp

Lines changed: 9 additions & 1 deletion

@@ -461,7 +461,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
   ScanLoweringHelper helper(op);
   auto loc = helper.getLoc();
   if (!helper.isSupported())
-    return failure();
+    return op.emitError("TODO: unsupported scan layout");

   Value threadId = getThreadId(rewriter, loc);
   auto mod = op->getParentOfType<ModuleOp>();

@@ -470,6 +470,14 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
   Value warpId = udiv(threadId, warpSize);
   Value laneId = urem(threadId, warpSize);

+  // Clamp the lane ID to just threads with unique data within a warp.
+  LinearLayout layout =
+      triton::gpu::toLinearLayout(helper.getShape(), helper.getEncoding());
+  StringAttr kLane = rewriter.getStringAttr("lane");
+  int32_t laneMask = layout.getFreeVariableMasks()[kLane];
+  laneMask = (layout.getInDimSize(kLane) - 1) & ~laneMask;
+  laneId = and_(laneId, i32_val(laneMask));
+
   auto [laneIdAxis, warpIdAxis, flatIdParallel] =
       getDelinearizedIds(rewriter, helper, laneId, warpId);
   auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
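
The mask arithmetic above is compact, so a restatement may help: in a linear layout, a lane-ID bit is a "free variable" if flipping it never changes which tensor elements the lane holds, i.e. lanes that differ only in free bits are replicas of each other. Clearing every free bit maps each replica onto the canonical lane with the same data. A hedged Python restatement of the two mask lines (the names are mine, not the compiler's):

def clamped_lane_id(lane_id: int, lane_dim_size: int, free_mask: int) -> int:
    """Keep only the lane-ID bits that select unique data.

    free_mask: bits of the lane ID that do not affect the mapped element
               (what getFreeVariableMasks()["lane"] reports).
    lane_dim_size: size of the layout's "lane" input dimension.
    """
    lane_mask = (lane_dim_size - 1) & ~free_mask
    return lane_id & lane_mask

# 16 lanes, two elements per lane, an 8-element tensor: only 4 lanes are
# unique, so bits 4 and 8 of the lane ID are free.
assert clamped_lane_id(0b1101, lane_dim_size=16, free_mask=0b1100) == 0b0001

This is exactly the "and i32 ..., 3" that the new lit test checks for @test_1d_grouped below.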

python/test/unit/language/test_core.py

Lines changed: 33 additions & 14 deletions

@@ -2586,20 +2586,6 @@ def kernel(X, Y, Z, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, AXIS: tl.const
     np.testing.assert_equal(z_ref, z_tri)


-scan_layouts = [
-    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
-    BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
-]
-
 # ---------------
 # test histogram
 # ---------------

@@ -2631,6 +2617,24 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
     assert (z_torch == z).all()


+@pytest.mark.parametrize("M, N", [(1, 64), (2, 32), (4, 16), (8, 8), (16, 4), (32, 2), (64, 1)])
+def test_scan_1d(M, N):
+
+    @triton.jit
+    def scan_kernel(out_ptr, in_ptr, M: tl.constexpr, N: tl.constexpr):
+        input = tl.load(in_ptr + tl.arange(0, M))
+        output = tl.cumsum(input).reshape([1, M]).broadcast_to([N, M])
+        tl.store(out_ptr + tl.arange(0, M * N), output.reshape([M * N]))
+
+    x = torch.randint(-100, 100, (M, ), dtype=torch.int32, device='cuda')
+    output = torch.empty(M * N, dtype=torch.int32, device='cuda')
+
+    scan_kernel[(1, )](output, x, M, N)
+
+    ref = torch.cumsum(x, dim=0).reshape([1, M]).broadcast_to([N, M]).reshape([M * N])
+    torch.testing.assert_close(ref.to(torch.int32), output, atol=0, rtol=0)
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("op", ['sum', 'max', 'min'])
 @pytest.mark.parametrize("BLOCK_N", [32, 64, 128])

@@ -2681,6 +2685,21 @@ def kernel(X, Y, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
     np.testing.assert_allclose(y_tri, y_ref, rtol=0.01, atol=1e-3)


+scan_layouts = [
+    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [0, 1], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 4], [4, THREADS_PER_WARP // 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 4], [8, THREADS_PER_WARP // 8], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([4, 1], [4, THREADS_PER_WARP // 4], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([2, 2], [4, THREADS_PER_WARP // 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([2, 2], [8, THREADS_PER_WARP // 8], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]),
+    BlockedLayout([1, 2], [1, THREADS_PER_WARP // 1], [1, 4], [1, 0], [1, 1], [1, 1], [0, 1]),
+]
+
+
 @pytest.mark.parametrize("M, N", [[32, 16], [32, 32], [32, 64], [64, 32]])
 @pytest.mark.parametrize("src_layout", scan_layouts)
 @pytest.mark.parametrize("axis", [0, 1])

test/Conversion/scan_to_llvm.mlir (new file)

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
+// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --canonicalize | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s
+
+#layout = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
+#layout_adj = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [16], warpsPerCTA = [2], order = [0]}>
+#layout_2d = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [8, 2], warpsPerCTA = [2, 1], order = [0,1]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 16 : i32} {
+
+// CHECK-LABEL: @test_1d_simple
+tt.func private @test_1d_simple(%arg0: tensor<8xi32, #layout>) -> tensor<8xi32, #layout> {
+  // CHECK: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  // CHECK: [[LANEID_AXIS:%.*]] = and i32 [[TID]], 7
+  // CHECK: icmp eq i32 [[LANEID_AXIS]], 0
+  %0 = "tt.scan"(%arg0) <{axis = 0 : i32, reverse = false}> ({
+  ^bb0(%arg1: i32, %arg2: i32):
+    %1 = arith.addi %arg1, %arg2 : i32
+    tt.scan.return %1 : i32
+  }) : (tensor<8xi32, #layout>) -> tensor<8xi32, #layout>
+  tt.return %0 : tensor<8xi32, #layout>
+}
+
+// CHECK-LABEL: @test_1d_grouped
+tt.func private @test_1d_grouped(%arg0: tensor<8xi32, #layout_adj>) -> tensor<8xi32, #layout_adj> {
+  // CHECK: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  // CHECK: [[LANEID_AXIS:%.*]] = and i32 [[TID]], 3
+  // CHECK: icmp eq i32 [[LANEID_AXIS]], 0
+  %0 = "tt.scan"(%arg0) <{axis = 0 : i32, reverse = false}> ({
+  ^bb0(%arg1: i32, %arg2: i32):
+    %1 = arith.addi %arg1, %arg2 : i32
+    tt.scan.return %1 : i32
+  }) : (tensor<8xi32, #layout_adj>) -> tensor<8xi32, #layout_adj>
+  tt.return %0 : tensor<8xi32, #layout_adj>
+}
+
+// CHECK-LABEL: @test_2d_grouped
+tt.func private @test_2d_grouped(%arg0: tensor<16x1xi32, #layout_2d>) -> tensor<16x1xi32, #layout_2d> {
+  // CHECK: [[TID:%.*]] = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  // CHECK: [[LANEID_AXIS:%.*]] = and i32 [[TID]], 7
+  // CHECK: icmp eq i32 [[LANEID_AXIS]], 0
+  %0 = "tt.scan"(%arg0) <{axis = 0 : i32, reverse = false}> ({
+  ^bb0(%arg1: i32, %arg2: i32):
+    %1 = arith.addi %arg1, %arg2 : i32
+    tt.scan.return %1 : i32
+  }) : (tensor<16x1xi32, #layout_2d>) -> tensor<16x1xi32, #layout_2d>
+  tt.return %0 : tensor<16x1xi32, #layout_2d>
+}
+
+// This just prevents the test functions from being DCE'd.
+tt.func public @anchor(%ptr: !llvm.ptr, %arg0: !llvm.struct<(i32)>, %arg1: !llvm.struct<(i32, i32)>, %arg2: !llvm.struct<(i32)>) {
+  %0 = builtin.unrealized_conversion_cast %arg0 : !llvm.struct<(i32)> to tensor<8xi32, #layout>
+  %1 = tt.call @test_1d_simple(%0) : (tensor<8xi32, #layout>) -> tensor<8xi32, #layout>
+  %2 = builtin.unrealized_conversion_cast %1 : tensor<8xi32, #layout> to !llvm.struct<(i32)>
+  llvm.store volatile %2, %ptr : !llvm.struct<(i32)>, !llvm.ptr
+
+  %3 = builtin.unrealized_conversion_cast %arg1 : !llvm.struct<(i32, i32)> to tensor<8xi32, #layout_adj>
+  %4 = tt.call @test_1d_grouped(%3) : (tensor<8xi32, #layout_adj>) -> tensor<8xi32, #layout_adj>
+  %5 = builtin.unrealized_conversion_cast %4 : tensor<8xi32, #layout_adj> to !llvm.struct<(i32, i32)>
+  llvm.store volatile %5, %ptr : !llvm.struct<(i32, i32)>, !llvm.ptr
+
+  %6 = builtin.unrealized_conversion_cast %arg2 : !llvm.struct<(i32)> to tensor<16x1xi32, #layout_2d>
+  %7 = tt.call @test_2d_grouped(%6) : (tensor<16x1xi32, #layout_2d>) -> tensor<16x1xi32, #layout_2d>
+  %8 = builtin.unrealized_conversion_cast %7 : tensor<16x1xi32, #layout_2d> to !llvm.struct<(i32)>
+  llvm.store volatile %8, %ptr : !llvm.struct<(i32)>, !llvm.ptr
+
+  tt.return
+}
+
+}
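
The checked mask constants follow from laneMask = (inDimSize - 1) & ~freeMask (my arithmetic, consistent with the CHECK lines): @test_1d_simple spreads 8 one-element lanes over a 16-lane warp, so lane bit 8 is free and the mask is 15 & ~8 = 7; @test_1d_grouped holds two elements per lane, leaving only 4 unique lanes, so bits 4 and 8 are free and the mask is 15 & ~12 = 3; @test_2d_grouped uses threadsPerWarp = [8, 2] on a 16x1 tensor, so the single dim-1 lane bit (8) is free and the mask is again 7.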
