
Commit 5c9e545

[RELAND][BACKEND] Create llvm.store when we do not need predication (#7173)
Relanding triton-lang/triton#7067. Depends on triton-lang/triton#7170. This allows generating better PTX (and sometimes SASS) and helps hide a PTX bug we were hitting. This PR follows triton-lang/triton#4776.
1 parent 16961b7 commit 5c9e545
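
In rough terms, the lowering change behind the updated FileCheck patterns is: when the guard predicate of a shared-memory access is known to be constant true, emit a plain LLVM-dialect store (or load) instead of wrapping the access in predicated PTX inline asm. The sketch below only illustrates the two shapes involved; the function name, SSA names, and alignment are invented for the example and are not taken from real compiler output.

    llvm.func @store_example(%ptr: !llvm.ptr<3>, %val: i32) {
      // Old shape (illustrative): the store was hidden behind predicated PTX
      // inline asm such as "@$p st.shared.b32 [addr], val;", opaque to LLVM.
      // New shape: a plain store through the addrspace(3) pointer. NVPTX still
      // lowers it to st.shared, but the rest of the LLVM pipeline can now see
      // through it, which is presumably where the better PTX (and sometimes
      // SASS) mentioned above comes from.
      llvm.store %val, %ptr {alignment = 4 : i64} : i32, !llvm.ptr<3>
      llvm.return
    }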

4 files changed: +53, -57 lines

test/Conversion/cvt_to_llvm.mlir

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ tt.func private @convert_layout_blocked_blocked(%arg0: tensor<16x16xi32, #blocke
   // to this, we choose to fall back to the shared memory implementation.

   // CHECK-NOT: shfl.sync.idx
-  // CHECK: st.shared
+  // CHECK: store

   %0 = ttg.convert_layout %arg0 : tensor<16x16xi32, #blocked0> -> tensor<16x16xi32, #blocked1>
   tt.return %0 : tensor<16x16xi32, #blocked1>

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 32 additions & 38 deletions
@@ -804,7 +804,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_blocked_blocked
 tt.func @convert_layout_blocked_blocked(%arg0: tensor<32x32xf32, #blocked0>) {
   // CHECK: llvm.mlir.addressof @global_smem
-  // CHECK-COUNT-8: llvm.inline_asm {{.*}} st.shared
+  // CHECK-COUNT-8: llvm.store
   // CHECK-: nvvm.barrier0
   // CHECK-COUNT-8: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<32x32xf32, #blocked0> -> tensor<32x32xf32, #blocked1>
@@ -821,10 +821,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_blocked_blocked_vec
 tt.func @convert_layout_blocked_blocked_vec(%arg0: tensor<32x32xf32, #blocked0>) {
   // CHECK: llvm.mlir.addressof @global_smem
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
+  // CHECK: llvm.store
+  // CHECK: llvm.store
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load
   // CHECK: llvm.load
@@ -859,14 +857,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // CHECK-LABEL: convert_layout_blocked_blocked_multi_rep
 tt.func @convert_layout_blocked_blocked_multi_rep(%arg0: tensor<16x16xf32, #blocked0>) {
   // CHECK: llvm.mlir.addressof @global_smem
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
+  // CHECK: llvm.store
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load
   // CHECK: llvm.load
   // CHECK: nvvm.barrier0
-  // CHECK: llvm.inline_asm
-  // CHECK: st.shared
+  // CHECK: llvm.store
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load
   // CHECK: llvm.load
@@ -1024,10 +1020,8 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK: llvm.mlir.global external @global_smem
 // CHECK-LABEL: convert_layout_mmav2_block
 tt.func @convert_layout_mmav2_blocked(%arg0: tensor<32x16xf32, #mma>) {
-  // CHECK: llvm.inline_asm
-  // CHECK-SAME: st.shared
-  // CHECK: llvm.inline_asm
-  // CHECK-SAME: st.shared
+  // CHECK: llvm.store
+  // CHECK: llvm.store
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<32x16xf32, #mma> -> tensor<32x16xf32, #blocked0>
@@ -1042,7 +1036,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // CHECK-LABEL: convert_layout_mmav2_dot_reg
 tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1>
   tt.return
@@ -1056,7 +1050,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // CHECK-LABEL: convert_layout_mmav2_dot_reg
 tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<1x16xf16, #mma>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<1x16xf16, #mma> -> tensor<1x16xf16, #dot1>
   tt.return
@@ -1072,7 +1066,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_slice_mmav2_blocked_reg
 tt.func @convert_layout_slice_mmav2_blocked_reg(%arg0: tensor<1xf16, #slice>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<1xf16, #slice> -> tensor<1xf16, #blocked>
   tt.return
@@ -1087,7 +1081,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_0
 tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1>
   tt.return
@@ -1102,7 +1096,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_1
 tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0>
   tt.return
@@ -1117,7 +1111,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_2
 tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0>
   tt.return
@@ -1132,7 +1126,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_3
 tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0>
   tt.return
@@ -1146,7 +1140,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // CHECK-LABEL: convert_layout_mmav2_dot_reg
 tt.func @convert_layout_mmav2_dot_reg(%arg0: tensor<16x16xf16, #mma>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma> -> tensor<16x16xf16, #dot1>
   tt.return
@@ -1161,7 +1155,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_0
 tt.func @convert_layout_mmav3_mmav3_0(%arg0: tensor<64x64xf16, #mma0>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma0> -> tensor<64x64xf16, #mma1>
   tt.return
@@ -1176,7 +1170,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_1
 tt.func @convert_layout_mmav3_mmav3_1(%arg0: tensor<64x64xf16, #mma1>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<64x64xf16, #mma1> -> tensor<64x64xf16, #mma0>
   tt.return
@@ -1191,7 +1185,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_2
 tt.func @convert_layout_mmav3_mmav3_2(%arg0: tensor<16x16xf16, #mma1>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mma1> -> tensor<16x16xf16, #mma0>
   tt.return
@@ -1206,7 +1200,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
 // CHECK-LABEL: convert_layout_mmav3_mmav3_3
 tt.func @convert_layout_mmav3_mmav3_3(%arg0: tensor<1x64xf16, #mma1>) {
-  // CHECK-NOT: st.shared
+  // CHECK-NOT: llvm.store
   // CHECK-NOT: llvm.load
   %0 = ttg.convert_layout %arg0 : tensor<1x64xf16, #mma1> -> tensor<1x64xf16, #mma0>
   tt.return
@@ -1221,28 +1215,28 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
 // CHECK: llvm.mlir.global external @global_smem
 // CHECK-LABEL: convert_layout_mmav3_transpose
 tt.func @convert_layout_mmav3_transpose(%arg0: tensor<128x256xf8E5M2, #mma>) {
-  // CHECK-COUNT-16: st.shared.b8
+  // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load {{.*}} -> vector<4xi32>
-  // CHECK-COUNT-16: st.shared.b8
+  // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load {{.*}} -> vector<4xi32>
-  // CHECK-COUNT-16: st.shared.b8
+  // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load {{.*}} -> vector<4xi32>
-  // CHECK-COUNT-16: st.shared.b8
+  // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load {{.*}} -> vector<4xi32>
-  // CHECK-COUNT-16: st.shared.b8
+  // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load {{.*}} -> vector<4xi32>
-  // CHECK-COUNT-16: st.shared.b8
+  // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load {{.*}} -> vector<4xi32>
-  // CHECK-COUNT-16: st.shared.b8
+  // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load {{.*}} -> vector<4xi32>
-  // CHECK-COUNT-16: st.shared.b8
+  // CHECK-COUNT-16: llvm.store {{.*}} : vector<1xi8>
   // CHECK: nvvm.barrier0
   // CHECK: llvm.load {{.*}} -> vector<4xi32>
   %0 = ttg.convert_layout %arg0 : tensor<128x256xf8E5M2, #mma> -> tensor<128x256xf8E5M2, #blocked>
@@ -1301,7 +1295,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 // CHECK-LABEL: convert_blocked_to_blocked_ptr
 tt.func @convert_blocked_to_blocked_ptr(%src:tensor<32x!tt.ptr<f32>, #blocked0>) {
   // CHECK: llvm.ptrtoint
-  // CHECK: inline_asm{{.*}}st.shared
+  // CHECK: llvm.store
   // CHECK: nvvm.barrier0
   // CHECK: llvm.inttoptr
   // CHECK-COUNT-4: llvm.insertvalue
@@ -1319,13 +1313,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
 // CHECK-LABEL: linear_layout_with_multiple_iterations
 tt.func @linear_layout_with_multiple_iterations(%src: tensor<8x4xbf16, #linear>) {
   %cvt = ttg.convert_layout %src : tensor<8x4xbf16, #linear> -> tensor<8x4xbf16, #linear1>
-  // CHECK: inline_asm{{.*}}st.shared.v2
+  // CHECK: llvm.store {{.*}} : vector<2xi16>
   // CHECK: nvvm.barrier0
-  // CHECK: llvm.load
+  // CHECK: llvm.load {{.*}} -> i16
   // CHECK: nvvm.barrier0
-  // CHECK: inline_asm{{.*}}st.shared.v2
+  // CHECK: llvm.store {{.*}} : vector<2xi16>
   // CHECK: nvvm.barrier0
-  // CHECK: llvm.load
+  // CHECK: llvm.load {{.*}} -> i16
   tt.return
 }
 }

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp

Lines changed: 17 additions & 13 deletions
@@ -153,7 +153,7 @@ static std::string getConstraintForBitwidth(unsigned bitwidth) {
 
 static bool isConstantTruePred(Value pred) {
   if (auto constOp = pred.getDefiningOp<LLVM::ConstantOp>()) {
-    return cast<IntegerAttr>(constOp.getValue()).getInt() != 0;
+    return cast<IntegerAttr>(constOp.getValue()).getInt() == -1;
   }
   return false;
 }
@@ -258,19 +258,23 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr,
           .b(elemBitwidth);
   auto *ptrOpr = builder.newAddrOperand(ptr, "r");
 
-  PTXBuilder::Operand *valOpr;
-  std::string constraint = getConstraintForBitwidth(elemBitwidth);
-  if (vec > 1) {
-    SmallVector<std::pair<Value, std::string>> vecVals;
-    for (int i = 0; i < vec; i++) {
-      vecVals.push_back({b.extract_element(val, b.i32_val(i)), constraint});
-    }
-    valOpr = builder.newListOperand(vecVals);
+  if (isConstantTruePred(pred)) {
+    b.store(val, ptr, /*align=*/vec * elemBitwidth / 8);
   } else {
-    valOpr = builder.newOperand(val, constraint);
+    PTXBuilder::Operand *valOpr;
+    std::string constraint = getConstraintForBitwidth(elemBitwidth);
+    if (vec > 1) {
+      SmallVector<std::pair<Value, std::string>> vecVals;
+      for (int i = 0; i < vec; i++) {
+        vecVals.push_back({b.extract_element(val, b.i32_val(i)), constraint});
+      }
+      valOpr = builder.newListOperand(vecVals);
+    } else {
+      valOpr = builder.newOperand(val, constraint);
+    }
+    st(ptrOpr, valOpr).predicate(pred, "b");
+    builder.launch(rewriter, loc, void_ty(ctx));
   }
-  st(ptrOpr, valOpr).predicate(pred, "b");
-  builder.launch(rewriter, loc, void_ty(ctx));
 }
 
 Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
@@ -375,7 +379,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr,
   if (isConstantTruePred(pred)) {
     Type resultTy = vec == 1 ? Type(int_ty(elemBitwidth))
                              : Type(vec_ty(int_ty(elemBitwidth), vec));
-    load = b.load(resultTy, ptr);
+    load = b.load(resultTy, ptr, /*align=*/vec * elemBitwidth / 8);
     if (vec > 1) {
       Type structTy = struct_ty(SmallVector<Type>(vec, int_ty(elemBitwidth)));
       Value structValue = b.undef(structTy);
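
Two details of the TargetInfo.cpp change are worth spelling out. First, the unpredicated path only fires when the mask is literally an LLVM-dialect constant true: IntegerAttr::getInt() returns the sign-extended value, so a true i1 constant reads back as -1, which is what the tightened comparison checks. Second, the unpredicated store and load now pass an explicit alignment of vec * elemBitwidth / 8 bytes, i.e. the full width of the vectorized access. The snippet below is only a sketch of which predicates qualify; the SSA names are invented for the example.

    // Fast path taken: the mask is a literal i1 true. Its IntegerAttr holds the
    // 1-bit value 1, and getInt() sign-extends it, hence the == -1 comparison.
    %true = llvm.mlir.constant(true) : i1
    // Fast path not taken: a constant false, or any mask computed at runtime,
    // keeps the predicated PTX inline-asm path.
    %false = llvm.mlir.constant(false) : i1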

third_party/proton/test/test_cmd.py

Lines changed: 3 additions & 5 deletions
@@ -58,8 +58,6 @@ def test_instrument_exec():
         assert result[5] == ['5', 'matmul_kernel', 'instrument.py:33:20', 'SHARED', 'LOAD']
         assert result[6] == ['6', 'matmul_kernel', 'instrument.py:42:21', 'GLOBAL', 'STORE']
     else:
-        assert [row[0] for row in result] == ['0']
-        assert [row[1] for row in result] == ['matmul_kernel']
-        assert [row[2] for row in result] == ['instrument.py:42:21']
-        assert [row[3] for row in result] == ['SHARED']
-        assert [row[4] for row in result] == ['LOAD']
+        assert len(result) == 2
+        assert result[0] == ['0', 'matmul_kernel', 'instrument.py:42:21', 'SHARED', 'STORE']
+        assert result[1] == ['1', 'matmul_kernel', 'instrument.py:42:21', 'SHARED', 'LOAD']
