Skip to content

Commit 1d3345f

Browse files
Merge commit 'bae5ff9923a75a37d4410275f1cd30bccd323e0b'
2 parents 23362af + bae5ff9 commit 1d3345f

File tree

4 files changed

+169
-150
lines changed

4 files changed

+169
-150
lines changed

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1921,8 +1921,8 @@ void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
19211921
{{str_attr("offset"), dstLayout.getTotalOutDimSize()}});
19221922
auto smemBase = LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op);
19231923

1924-
auto emitSt = [&](ConversionPatternRewriter &rewriter, Location loc,
1925-
ArrayRef<Value> vals, Value shmemAddr, int idx,
1924+
auto emitSt = [&](RewriterBase &rewriter, Location loc, ArrayRef<Value> vals,
1925+
Value shmemAddr, int idx,
19261926
VectorType vecTy) -> SmallVector<Value> {
19271927
auto length = vecTy.getNumElements();
19281928
Value valsVec =
@@ -1932,8 +1932,8 @@ void finalizeTensorAtomicResults(Operation *op, RankedTensorType tensorTy,
19321932
return {};
19331933
};
19341934

1935-
auto emitLd = [&](ConversionPatternRewriter &rewriter, Location loc,
1936-
ArrayRef<Value> vals, Value shmemAddr, int idx,
1935+
auto emitLd = [&](RewriterBase &rewriter, Location loc, ArrayRef<Value> vals,
1936+
Value shmemAddr, int idx,
19371937
VectorType vecTy) -> SmallVector<Value> {
19381938
Value loadedVec = targetInfo.loadDShared(rewriter, loc, shmemAddr,
19391939
std::nullopt, vecTy, b.true_val());

test/Conversion/cvt_to_llvm.mlir.unsupported renamed to test/Conversion/cvt_to_llvm.mlir

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,15 +47,13 @@ tt.func private @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xi32, #bl
4747
// `(x//2*4 + x%2)%16 + (x>=16)*2`
4848

4949
// CHECK-DAG: [[X_MOD_2:%.*]] = and i32 [[TID]], 1
50-
// CHECK-DAG: [[X_2_4_LOWER:%.*]] = shl {{.*}} i32 [[IS_UPPER_HALF]], 1
51-
// CHECK-DAG: [[X_2_4_UPPER0:%.*]] = shl {{.*}} i32 [[TID]], 1
52-
// CHECK-DAG: [[X_2_4_UPPER1:%.*]] = and i32 [[X_2_4_UPPER0]], 24
50+
// CHECK-DAG: [[SHL:%.*]] = shl {{.*}}
51+
// CHECK-DAG: [[MASKED:%.*]] = and i32 [[SHL]], 28
52+
// CHECK-DAG: [[IDX0:%.*]] = or disjoint i32 [[MASKED]], [[X_MOD_2]]
5353
// CHECK-DAG: [[X_GE_16:%.*]] = and i32 [[TID]], 16
54+
// CHECK-DAG: [[SWAP_RESULTS:%.*]] = icmp eq i32 [[X_GE_16]], 0
5455
// CHECK-DAG: [[X_GE_16_2:%.*]] = lshr exact i32 [[X_GE_16]], 3
55-
56-
// CHECK-DAG: [[IDX0:%.*]] = or disjoint i32 [[X_2_4_LOWER]], [[X_MOD_2]]
57-
// CHECK-DAG: [[IDX1:%.*]] = or disjoint i32 [[IDX0]], [[X_2_4_UPPER1]]
58-
// CHECK-DAG: [[IDX2:%.*]] = or disjoint i32 [[IDX1]], [[X_GE_16_2]]
56+
// CHECK-DAG: [[IDX2:%.*]] = or disjoint i32 [[IDX0]], [[X_GE_16_2]]
5957

6058
// CHECK-DAG: [[SHFLSRC0:%.*]] = select i1 [[IS_LOWER_HALF]], i32 [[SRC0]], i32 [[SRC4]]
6159
// CHECK-DAG: [[SHFLSRC1:%.*]] = select i1 [[IS_LOWER_HALF]], i32 [[SRC1]], i32 [[SRC5]]
@@ -73,8 +71,7 @@ tt.func private @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xi32, #bl
7371

7472
// For register [4, 8), the upper and lower halves swap.
7573

76-
// CHECK-DAG: [[IDX3:%.*]] = or disjoint i32 [[IDX1]], 2
77-
// CHECK-DAG: [[IDX4:%.*]] = xor i32 [[IDX3]], [[X_GE_16_2]]
74+
// CHECK-DAG: [[IDX4:%.*]] = xor i32 [[IDX2]], 2
7875

7976
// CHECK-DAG: [[SHFLOUT4:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[SHFLSRC4]], i32 [[IDX4]], i32 31)
8077
// CHECK-DAG: [[SHFLOUT5:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[SHFLSRC5]], i32 [[IDX4]], i32 31)
@@ -83,15 +80,13 @@ tt.func private @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xi32, #bl
8380

8481
// For lanes [16, 32), swap the two results.
8582

86-
// CHECK-DAG: [[SWAP_RESULTS:%.*]] = icmp eq i32 [[X_GE_16]], 0
87-
8883
// CHECK: [[DST0:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT0]], i32 [[SHFLOUT4]]
89-
// CHECK: [[DST1:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT1]], i32 [[SHFLOUT5]]
90-
// CHECK: [[DST2:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT2]], i32 [[SHFLOUT6]]
91-
// CHECK: [[DST3:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT3]], i32 [[SHFLOUT7]]
9284
// CHECK: [[DST4:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT4]], i32 [[SHFLOUT0]]
85+
// CHECK: [[DST1:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT1]], i32 [[SHFLOUT5]]
9386
// CHECK: [[DST5:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT5]], i32 [[SHFLOUT1]]
87+
// CHECK: [[DST2:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT2]], i32 [[SHFLOUT6]]
9488
// CHECK: [[DST6:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT6]], i32 [[SHFLOUT2]]
89+
// CHECK: [[DST3:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT3]], i32 [[SHFLOUT7]]
9590
// CHECK: [[DST7:%.*]] = select i1 [[SWAP_RESULTS]], i32 [[SHFLOUT7]], i32 [[SHFLOUT3]]
9691

9792
// CHECK: insertvalue {{.*}}, i32 [[DST0]], 0

test/Conversion/gather_to_llvm.mlir

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s
2+
3+
// Check the optimized LLVM IR, since InstCombine makes the linear layout
4+
// logic understandable enough (in simple cases) to check correctness by eye.
5+
6+
#crazy_2d_src = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}>
7+
#crazy_2d_idx = #ttg.linear<{register = [[2, 0], [0, 2]], lane = [[0, 8], [16, 0], [1, 0], [8, 0], [4, 0]], warp = [[0, 1], [0, 4]], block = []}>
8+
#broadcasted_lane_1d = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
9+
#broadcasted_warp_2d = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 4], order = [1, 0]}>
10+
11+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} {
12+
13+
// CHECK-LABEL: @gather_2d_crazy
14+
tt.func private @gather_2d_crazy(%arg0: tensor<32x16xi32, #crazy_2d_idx>, %arg1: tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx> {
15+
// The specific logic becomes hard to grasp here. Just check the shuffles.
16+
17+
// CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float, float, float } %1, 0
18+
// CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float, float, float } %1, 1
19+
// CHECK-NEXT: [[SRC2:%.*]] = extractvalue { float, float, float, float } %1, 2
20+
// CHECK-NEXT: [[SRC3:%.*]] = extractvalue { float, float, float, float } %1, 3
21+
22+
// CHECK: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32
23+
// CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]],
24+
// CHECK-NEXT: {{%.*}} = bitcast i32 {{%.*}} to float
25+
// CHECK-NEXT: {{%.*}} = icmp eq i32
26+
// CHECK-NEXT: {{%.*}} = select i1
27+
// CHECK-NEXT: [[VALUE2:%.*]] = bitcast float [[SRC2]] to i32
28+
// CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]],
29+
30+
// CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]],
31+
// CHECK-NEXT: {{%.*}} = bitcast i32 {{%.*}} to float
32+
// CHECK-NEXT: {{%.*}} = icmp eq i32
33+
// CHECK-NEXT: {{%.*}} = select i1
34+
// CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE2]],
35+
36+
// CHECK: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32
37+
// CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]],
38+
// CHECK-NEXT: [[VALUE3:%.*]] = bitcast float [[SRC3]] to i32
39+
// CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]],
40+
// CHECK-NEXT: {{%.*}} = icmp eq i32
41+
// CHECK-NEXT: {{%.*}} = select i1
42+
// CHECK-NEXT: {{%.*}} = bitcast i32 {{%.*}} to float
43+
44+
// CHECK: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]],
45+
// CHECK-NEXT: tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE3]],
46+
// CHECK-NEXT: {{%.*}} = icmp eq i32
47+
// CHECK-NEXT: {{%.*}} = select i1
48+
// CHECK-NEXT: {{%.*}} = bitcast i32 {{%.*}} to float
49+
50+
%0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<32x16xf32, #crazy_2d_src>, tensor<32x16xi32, #crazy_2d_idx>) -> tensor<32x16xf32, #crazy_2d_idx>
51+
tt.return %0 : tensor<32x16xf32, #crazy_2d_idx>
52+
}
53+
54+
// There are 16 elements in the tensor. For each warp, each half-warp is mapped
55+
// to the 16 elements, so it doesn't matter if the second half [16, 32) indexes
56+
// into [0, 16), since they contain the same data.
57+
// CHECK-LABEL: @gather_broadcasted_lane_1d
58+
tt.func private @gather_broadcasted_lane_1d(%arg0: tensor<16xi32, #broadcasted_lane_1d>, %arg1: tensor<16xf32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d> {
59+
// CHECK-NEXT: [[SRC:%.*]] = extractvalue { float } %1, 0
60+
// CHECK-NEXT: [[IDX:%.*]] = extractvalue { i32 } %0, 0
61+
62+
// CHECK-NEXT: [[LANEID:%.*]] = and i32 [[IDX]], 15
63+
// CHECK-NEXT: [[VALUE:%.*]] = bitcast float [[SRC]] to i32
64+
// CHECK-NEXT: [[RES_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE]], i32 [[LANEID]], i32 31)
65+
%0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<16xf32, #broadcasted_lane_1d>, tensor<16xi32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d>
66+
67+
// CHECK-NEXT: [[RES:%.*]] = bitcast i32 [[RES_i32]] to float
68+
// CHECK-NEXT: ret float [[RES]]
69+
tt.return %0 : tensor<16xf32, #broadcasted_lane_1d>
70+
}
71+
72+
// Single gather column with 64 elements, all of which have to fit into a single
73+
// warp, so the whole column is broadcasted across the 4 warps. Each processes the
74+
// same data so the warp doesn't matter.
75+
// CHECK-LABEL: @gather_broadcasted_warp_2d
76+
tt.func private @gather_broadcasted_warp_2d(%arg0: tensor<64x1xi32, #broadcasted_warp_2d>, %arg1: tensor<64x1xf32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d> {
77+
// CHECK-NEXT: [[SRC0:%.*]] = extractvalue { float, float } %1, 0
78+
// CHECK-NEXT: [[SRC1:%.*]] = extractvalue { float, float } %1, 1
79+
// CHECK-NEXT: [[IDX0:%.*]] = extractvalue { i32, i32 } %0, 0
80+
// CHECK-NEXT: [[IDX1:%.*]] = extractvalue { i32, i32 } %0, 1
81+
82+
// CHECK-NEXT: [[REGID0:%.*]] = and i32 [[IDX0]], 1
83+
// CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX0]], 1
84+
// CHECK-NEXT: [[LANEID0:%.*]] = and i32 [[TMP]], 31
85+
86+
// CHECK-NEXT: [[VALUE0:%.*]] = bitcast float [[SRC0]] to i32
87+
// CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID0]], i32 31)
88+
// CHECK-NEXT: [[VALUE1:%.*]] = bitcast float [[SRC1]] to i32
89+
// CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID0]], i32 31)
90+
91+
// CHECK-NEXT: [[PICK0:%.*]] = icmp eq i32 [[REGID0]], 0
92+
// CHECK-NEXT: select i1 [[PICK0]], i32 [[RES0_i32]], i32 [[RES1_i32]]
93+
94+
// CHECK: [[REGID1:%.*]] = and i32 [[IDX1]], 1
95+
// CHECK-NEXT: [[TMP:%.*]] = lshr i32 [[IDX1]], 1
96+
// CHECK-NEXT: [[LANEID1:%.*]] = and i32 [[TMP]], 31
97+
98+
// CHECK-NEXT: [[RES0_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE0]], i32 [[LANEID1]], i32 31)
99+
// CHECK-NEXT: [[RES1_i32:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.idx.i32(i32 -1, i32 [[VALUE1]], i32 [[LANEID1]], i32 31)
100+
101+
// CHECK-NEXT: [[PICK1:%.*]] = icmp eq i32 [[REGID1]], 0
102+
// CHECK-NEXT: select i1 [[PICK1]], i32 [[RES0_i32]], i32 [[RES1_i32]]
103+
%0 = tt.gather %arg1[%arg0] {axis = 0 : i32} : (tensor<64x1xf32, #broadcasted_warp_2d>, tensor<64x1xi32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d>
104+
tt.return %0 : tensor<64x1xf32, #broadcasted_warp_2d>
105+
}
106+
107+
// Keep LLVM from DCE'ing the above functions. Use volatile stores to stop LLVM
108+
// from removing unused function results.
109+
tt.func @anchor_warp4(%ptr: !llvm.ptr,
110+
%arg9: tensor<32x16xi32, #crazy_2d_idx>,
111+
%arg10: tensor<32x16xf32, #crazy_2d_src>,
112+
%arg11: tensor<16xi32, #broadcasted_lane_1d>,
113+
%arg12: tensor<16xf32, #broadcasted_lane_1d>,
114+
%arg13: tensor<64x1xi32, #broadcasted_warp_2d>,
115+
%arg14: tensor<64x1xf32, #broadcasted_warp_2d>) {
116+
117+
%12 = tt.call @gather_2d_crazy(%arg9, %arg10) : (tensor<32x16xi32, #crazy_2d_idx>, tensor<32x16xf32, #crazy_2d_src>) -> tensor<32x16xf32, #crazy_2d_idx>
118+
%13 = builtin.unrealized_conversion_cast %12 : tensor<32x16xf32, #crazy_2d_idx> to !llvm.struct<(f32, f32, f32, f32)>
119+
llvm.store volatile %13, %ptr : !llvm.struct<(f32, f32, f32, f32)>, !llvm.ptr
120+
121+
%14 = tt.call @gather_broadcasted_lane_1d(%arg11, %arg12) : (tensor<16xi32, #broadcasted_lane_1d>, tensor<16xf32, #broadcasted_lane_1d>) -> tensor<16xf32, #broadcasted_lane_1d>
122+
%15 = builtin.unrealized_conversion_cast %14 : tensor<16xf32, #broadcasted_lane_1d> to !llvm.struct<(f32)>
123+
llvm.store volatile %15, %ptr : !llvm.struct<(f32)>, !llvm.ptr
124+
125+
%16 = tt.call @gather_broadcasted_warp_2d(%arg13, %arg14) : (tensor<64x1xi32, #broadcasted_warp_2d>, tensor<64x1xf32, #broadcasted_warp_2d>) -> tensor<64x1xf32, #broadcasted_warp_2d>
126+
%17 = builtin.unrealized_conversion_cast %16 : tensor<64x1xf32, #broadcasted_warp_2d> to !llvm.struct<(f32, f32)>
127+
llvm.store volatile %17, %ptr : !llvm.struct<(f32, f32)>, !llvm.ptr
128+
129+
tt.return
130+
}
131+
132+
}

0 commit comments

Comments
 (0)