Skip to content

Commit 755d416

Browse files
authored
[LAYOUTS] Enable ReduceOp lowering with LinearEncodingAttr (#5477)
ReduceOp's lowering does not support linear layouts yet, so propagation of linear layouts across `tt.reduce` ops will cause codegen to crash. The codegen routine is generic enough to support linear layouts, so just enable it and add a few tests.
1 parent 75fb922 commit 755d416

File tree

3 files changed

+96
-4
lines changed

3 files changed

+96
-4
lines changed

lib/Analysis/Utility.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -219,15 +219,14 @@ bool ReduceOpHelper::isSupportedLayout() {
219219
}
220220

221221
auto srcLayout = getSrcLayout();
222-
if (isa<BlockedEncodingAttr>(srcLayout)) {
222+
if (isa<BlockedEncodingAttr, LinearEncodingAttr, SliceEncodingAttr>(
223+
srcLayout)) {
223224
return true;
224225
}
226+
225227
if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(srcLayout)) {
226228
return mmaLayout.supportReduction();
227229
}
228-
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(srcLayout)) {
229-
return true;
230-
}
231230
return false;
232231
}
233232

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s
2+
3+
#linear = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>
4+
5+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
6+
7+
// CHECK-LABEL: @reduce_linear_layout
8+
tt.func private @reduce_linear_layout(%arg0: tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> {
9+
// CHECK-NEXT: [[SRC0:%.*]] = extractvalue {{.*}} %0, 0
10+
// CHECK-NEXT: [[SRC1:%.*]] = extractvalue {{.*}} %0, 1
11+
// CHECK-NEXT: [[SRC2:%.*]] = extractvalue {{.*}} %0, 2
12+
// CHECK-NEXT: [[SRC3:%.*]] = extractvalue {{.*}} %0, 3
13+
14+
// The layout looks like
15+
// [[ T0:0, T32:0, T0:1, T32:1, ...
16+
// [ T4:0, T36:0, T4:1, T36:1, ...
17+
// [ T0:2, T32:2, T0:3, T32:3, ...
18+
// [ T4:2, T36:2, T4:3, T36:3,
19+
// ...
20+
//
21+
// A reduction along axis=0 consists of adding registers (0, 2) and (1, 3)
22+
// before shuffling.
23+
//
24+
// Columns along axis=0 are contained within a warp, so reduction across warps
25+
// is not needed.
26+
27+
// Reduce within threads
28+
// CHECK-NEXT: [[SUM0:%.*]] = add i32 [[SRC0]], [[SRC2]]
29+
// CHECK-NEXT: [[SUM1:%.*]] = add i32 [[SRC1]], [[SRC3]]
30+
31+
// Reduce within warp.
32+
// CHECK-NEXT: [[W0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM0]], i32 16, i32 31)
33+
// CHECK-NEXT: [[WSUM0:%.*]] = add i32 [[W0]], [[SUM0]]
34+
// CHECK-NEXT: [[W1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM0]], i32 8, i32 31)
35+
// CHECK-NEXT: [[WSUM1:%.*]] = add i32 [[WSUM0]], [[W1]]
36+
// CHECK-NEXT: [[W2:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM1]], i32 4, i32 31)
37+
// CHECK-NEXT: [[WSUM2:%.*]] = add i32 [[WSUM1]], [[W2]]
38+
// CHECK-NEXT: [[W3:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM2]], i32 2, i32 31)
39+
// CHECK-NEXT: [[WSUM3:%.*]] = add i32 [[WSUM2]], [[W3]]
40+
41+
// CHECK-NEXT: [[W4:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM1]], i32 16, i32 31)
42+
// CHECK-NEXT: [[WSUM4:%.*]] = add i32 [[W4]], [[SUM1]]
43+
// CHECK-NEXT: [[W5:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM4]], i32 8, i32 31)
44+
// CHECK-NEXT: [[WSUM5:%.*]] = add i32 [[WSUM4]], [[W5]]
45+
// CHECK-NEXT: [[W6:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM5]], i32 4, i32 31)
46+
// CHECK-NEXT: [[WSUM6:%.*]] = add i32 [[WSUM5]], [[W6]]
47+
// CHECK-NEXT: [[W7:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM6]], i32 2, i32 31)
48+
// CHECK-NEXT: [[WSUM7:%.*]] = add i32 [[WSUM6]], [[W7]]
49+
50+
// CHECK-NEXT: [[DST0:%.*]] = insertvalue { i32, i32 } undef, i32 [[WSUM3]], 0
51+
// CHECK-NEXT: [[DST1:%.*]] = insertvalue { i32, i32 } [[DST0]], i32 [[WSUM7]], 1
52+
53+
%0 = "tt.reduce"(%arg0) ({
54+
^bb0(%arg1: i32, %arg2: i32):
55+
%1 = arith.addi %arg1, %arg2 : i32
56+
tt.reduce.return %1 : i32
57+
}) {axis = 0 : i32} : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>
58+
59+
// CHECK-NEXT: ret { i32, i32 } [[DST1]]
60+
tt.return %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>
61+
}
62+
63+
tt.func @anchor(%ptr: !llvm.ptr, %arg0: tensor<32x16xi32, #linear>) {
64+
%0 = tt.call @reduce_linear_layout(%arg0) : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>>
65+
%1 = builtin.unrealized_conversion_cast %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> to !llvm.struct<(i32, i32)>
66+
llvm.store volatile %1, %ptr : !llvm.struct<(i32, i32)>, !llvm.ptr
67+
tt.return
68+
}
69+
70+
}

test/TritonGPU/combine.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2829,3 +2829,26 @@ tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) {
28292829
}
28302830

28312831
}
2832+
2833+
// -----
2834+
2835+
#linear = #ttg.linear<{register = [[1, 0], [0, 8], [0, 16]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 2], [0, 4]], block = []}>
2836+
#blocked = #ttg.blocked<{sizePerThread = [2, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [1, 0]}>
2837+
2838+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
2839+
2840+
// CHECK-LABEL: reduce_linear_layouts
2841+
tt.func @reduce_linear_layouts(%arg0: tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> {
2842+
// CHECK-NOT: convert_layout
2843+
%0 = ttg.convert_layout %arg0 : tensor<32x32xi32, #linear> -> tensor<32x32xi32, #blocked>
2844+
// CHECK-NEXT: tt.reduce
2845+
%1 = "tt.reduce" (%0) ({
2846+
^bb0(%arg1: i32, %arg2: i32):
2847+
tt.reduce.return %arg1 : i32
2848+
// CHECK: (tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>
2849+
}) {axis = 1 : i32} : (tensor<32x32xi32, #blocked>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
2850+
%2 = ttg.convert_layout %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>>
2851+
tt.return %2 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>>
2852+
}
2853+
2854+
}

0 commit comments

Comments
 (0)