Skip to content

Commit c54e067

Browse files
[TritonGEN] Add predicated store (#5195)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent caa826f commit c54e067

File tree

7 files changed

+110
-28
lines changed

7 files changed

+110
-28
lines changed

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
5353
"TRITON_INTEL_ENABLE_INSTR_SCHED",
5454
"TRITON_INTEL_FAST_MATH",
5555
"TRITON_INTEL_ONE_MATRIX_PER_LOAD_BT",
56-
"TRITON_INTEL_PREDICATED_LOAD",
56+
"TRITON_INTEL_PREDICATED",
5757
"TRITON_INTEL_REDUCE_TRANSPOSE",
5858
// clang-format on
5959
};

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --dump-input-context=20 --check-prefixes=CHECK,NO-PREDICATED
2-
// RUN: env TRITON_INTEL_PREDICATED_LOAD=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --dump-input-context=20 --check-prefixes=CHECK,PREDICATED
1+
// RUN: env TRITON_INTEL_PREDICATED=0 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --dump-input-context=20 --check-prefixes=CHECK,NO-PREDICATED
2+
// RUN: env TRITON_INTEL_PREDICATED=1 triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm --convert-tritongen-to-llvm | FileCheck %s --implicit-check-not=llvm.inline_asm --dump-input-context=20 --check-prefixes=CHECK,PREDICATED
33

44
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
55
// CHECK: llvm.func spir_kernelcc @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<1>, %arg2: !llvm.ptr<1>, %arg3: !llvm.ptr<1>)
@@ -694,21 +694,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
694694
// CHECK-NEXT: [[VEC2:%.*]] = llvm.mlir.undef : vector<1xi32>
695695
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
696696
// CHECK-NEXT: [[IE2:%.*]] = llvm.insertelement [[BCAST1]], [[VEC2]][[[ZERO]] : i32] : vector<1xi32>
697-
// CHECK-NEXT: llvm.cond_br [[ARG2_0]], ^bb1, ^bb2
698-
// CHECK-NEXT: ^bb1:
699-
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[ARG0_0]] : !llvm.ptr<1> to !llvm.ptr<1>
700-
// CHECK-NEXT: llvm.store [[IE2]], [[BCAST2]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>
701-
// CHECK-NEXT: llvm.br ^bb2
702-
// CHECK-NEXT: ^bb2:
697+
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[ARG0_0]] : !llvm.ptr<1> to !llvm.ptr<1>
698+
// PREDICATED-NEXT: [[BCAST3:%.*]] = llvm.bitcast [[IE2]] : vector<1xi32> to vector<1xf32>
699+
// PREDICATED: [[ALIGNMENT:%.*]] = llvm.mlir.constant(4 : i64) : i64
700+
// PREDICATED: llvm.call spir_funccc @llvm.genx.GenISA.PredicatedStore.p1f32.v1f32([[BCAST2]], [[BCAST3]], [[ALIGNMENT]], [[ARG2_0]]) {{.*}} : (!llvm.ptr<1>, vector<1xf32>, i64, i1) -> ()
701+
// NO-PREDICATED: llvm.cond_br [[ARG2_0]], ^bb1, ^bb2
702+
// NO-PREDICATED-NEXT: ^bb1:
703+
// NO-PREDICATED-NEXT: llvm.store [[IE2]], [[BCAST2]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>
704+
// NO-PREDICATED-NEXT: llvm.br ^bb2
705+
// NO-PREDICATED-NEXT: ^bb2:
703706
// CHECK: [[VEC3:%.*]] = llvm.mlir.undef : vector<1xi32>
704707
// CHECK-NEXT: [[ZERO:%.*]] = llvm.mlir.constant(0 : i32) : i32
705708
// CHECK-NEXT: [[IE3:%.*]] = llvm.insertelement {{.*}}, [[VEC3]][[[ZERO]] : i32] : vector<1xi32>
706-
// CHECK: llvm.cond_br [[ARG2_1]], ^bb3, ^bb4
707-
// CHECK-NEXT: ^bb3:
708709
// CHECK-NEXT: [[BCAST2:%.*]] = llvm.bitcast [[ARG0_1]] : !llvm.ptr<1> to !llvm.ptr<1>
709-
// CHECK-NEXT: llvm.store [[IE3]], [[BCAST2]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>
710-
// CHECK-NEXT: llvm.br ^bb4
711-
// CHECK-NEXT: ^bb4:
710+
// PREDICATED-NEXT: [[BCAST3:%.*]] = llvm.bitcast [[IE3]] : vector<1xi32> to vector<1xf32>
711+
// PREDICATED: [[ALIGNMENT:%.*]] = llvm.mlir.constant(4 : i64) : i64
712+
// PREDICATED: llvm.call spir_funccc @llvm.genx.GenISA.PredicatedStore.p1f32.v1f32([[BCAST2]], [[BCAST3]], [[ALIGNMENT]], [[ARG2_1]]) {{.*}} : (!llvm.ptr<1>, vector<1xf32>, i64, i1) -> ()
713+
// NO-PREDICATED: llvm.cond_br [[ARG2_1]], ^bb3, ^bb4
714+
// NO-PREDICATED-NEXT: ^bb3:
715+
// NO-PREDICATED-NEXT: llvm.store [[IE3]], [[BCAST2]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>
716+
// NO-PREDICATED-NEXT: llvm.br ^bb4
717+
// NO-PREDICATED-NEXT: ^bb4:
712718
tt.store %ptrs, %vals, %mask : tensor<256x!tt.ptr<f32>, #blocked0>
713719
tt.return
714720
}
@@ -1345,10 +1351,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
13451351
// CHECK-LABEL: store_f32_scalar
13461352
tt.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
13471353
// CHECK: llvm.icmp "eq"
1348-
// CHECK: llvm.cond_br {{.*}}, ^bb1, ^bb2
1349-
// CHECK-NEXT: ^bb1:
1350-
// CHECK-NEXT: [[BCAST:%.*]] = llvm.bitcast %arg0 : !llvm.ptr<1> to !llvm.ptr<1>
1351-
// CHECK-NEXT: llvm.store {{.*}}, [[BCAST]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>
1354+
// CHECK: [[BCAST:%.*]] = llvm.bitcast %arg0 : !llvm.ptr<1> to !llvm.ptr<1>
1355+
// PREDICATED: llvm.call spir_funccc @llvm.genx.GenISA.PredicatedStore.p1f32.v1f32([[BCAST]], {{.*}}) {{.*}} : (!llvm.ptr<1>, vector<1xf32>, i64, i1) -> ()
1356+
// NO-PREDICATED: llvm.cond_br {{.*}}, ^bb1, ^bb2
1357+
// NO-PREDICATED-NEXT: ^bb1:
1358+
// NO-PREDICATED-NEXT: llvm.store {{.*}}, [[BCAST]] {alignment = 4 : i64} : vector<1xi32>, !llvm.ptr<1>
13521359
tt.store %arg0, %arg1 : !tt.ptr<f32>
13531360
tt.return
13541361
}

test/TritonGEN/tritongen-to-llvm.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,3 +139,12 @@ llvm.func @triton_gen.predicated_load(%ptr : !llvm.ptr<1>, %alignment : i64, %pr
139139
%0 = triton_gen.predicated_load %ptr, %alignment, %predicate, %default_value : !llvm.ptr<1>, i64, i1, i32 -> i32
140140
llvm.return
141141
}
142+
143+
// -----
144+
145+
llvm.func @triton_gen.predicated_store(%ptr : !llvm.ptr<1>, %value : i32, %alignment : i64, %predicate : i1) {
146+
// CHECK: llvm.func @triton_gen.predicated_store(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i64, %arg3: i1) {
147+
// CHECK: llvm.call spir_funccc @llvm.genx.GenISA.PredicatedStore.p1i32.i32(%arg0, %arg1, %arg2, %arg3) {{.*}} : (!llvm.ptr<1>, i32, i64, i1) -> ()
148+
triton_gen.predicated_store %ptr, %value, %alignment, %predicate : !llvm.ptr<1>, i32, i64, i1
149+
llvm.return
150+
}

test/TritonGEN/tritongen.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,4 +95,11 @@ llvm.func @triton_gen.predicated_load(%ptr : !llvm.ptr<1>, %alignment : i64, %pr
9595
%0 = triton_gen.predicated_load %ptr, %alignment, %predicate, %default_value : !llvm.ptr<1>, i64, i1, i32 -> i32
9696
llvm.return
9797
}
98+
99+
llvm.func @triton_gen.predicated_store(%ptr : !llvm.ptr<1>, %value : i32, %alignment : i64, %predicate : i1) {
100+
// CHECK: llvm.func @triton_gen.predicated_store(%arg0: !llvm.ptr<1>, %arg1: i32, %arg2: i64, %arg3: i1) {
101+
// CHECK-NEXT: triton_gen.predicated_store %arg0, %arg1, %arg2, %arg3 : !llvm.ptr<1>, i32, i64, i1
102+
triton_gen.predicated_store %ptr, %value, %alignment, %predicate : !llvm.ptr<1>, i32, i64, i1
103+
llvm.return
104+
}
98105
}

third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,26 @@ def TritonGEN_PredicatedLoadOp
409409
}];
410410
}
411411

412+
def TritonGEN_PredicatedStoreOp
413+
: TritonGEN_Op<"predicated_store"> {
414+
let summary = "Predicated store operation";
415+
let description = [{
416+
The `triton_gen.predicated_store` operation stores a value to memory
417+
conditionally based on the predicate. If the predicate is true, the value
418+
is stored to the specified pointer; otherwise, no operation is performed.
419+
}];
420+
let arguments = (ins
421+
Arg<LLVM_AnyPointer, "", [MemWrite]>:$ptr,
422+
AnyType:$value,
423+
I64: $alignment,
424+
I1:$predicate);
425+
let results = (outs);
426+
let assemblyFormat = [{
427+
$ptr `,` $value `,` $alignment `,` $predicate attr-dict `:` qualified(type($ptr)) `,`
428+
type($value) `,` type($alignment) `,` type($predicate)
429+
}];
430+
}
431+
412432
def TritonGEN_FToTf32Op
413433
: TritonGEN_Op<"f_to_tf32", [SameOperandsAndResultType]> {
414434
let summary = "Rounding instruction from float to tensor float (TF32) data format";

third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -960,6 +960,37 @@ struct TritonPredicatedLoadOpLowering
960960
}
961961
};
962962

963+
struct TritonPredicatedStoreOpLowering
964+
: public ConvertOpToLLVMPattern<TritonGEN::PredicatedStoreOp> {
965+
using ConvertOpToLLVMPattern<
966+
TritonGEN::PredicatedStoreOp>::ConvertOpToLLVMPattern;
967+
968+
LogicalResult
969+
matchAndRewrite(TritonGEN::PredicatedStoreOp op, OpAdaptor adaptor,
970+
ConversionPatternRewriter &rewriter) const override {
971+
MLIRContext *ctx = rewriter.getContext();
972+
Location loc = op->getLoc();
973+
auto b = TritonLLVMOpBuilder(loc, rewriter);
974+
Type valType = op.getValue().getType();
975+
// Create a call to the SPIR-V builtin for predicated store.
976+
std::string typeMangling = getGenISATypeMangling(valType);
977+
std::string ptrTypeMangling = getGenISATypeMangling(valType);
978+
if (auto vecTy = dyn_cast<VectorType>(valType))
979+
ptrTypeMangling = getGenISATypeMangling(vecTy.getElementType());
980+
std::string funcName = "llvm.genx.GenISA.PredicatedStore.p1" +
981+
ptrTypeMangling + "." + typeMangling;
982+
SmallVector<Type> argTypes{ptr_ty(ctx, 1), valType, int_ty(64), int_ty(1)};
983+
SmallVector<Value> args{op.getPtr(), op.getValue(), op.getAlignment(),
984+
op.getPredicate()};
985+
986+
LLVM::CallOp callOp = intel::createDeviceFunctionCall(
987+
rewriter, funcName, void_ty(ctx), argTypes, args, {},
988+
intel::noUnwindWillReturnAttrs);
989+
rewriter.replaceOp(op, callOp);
990+
return success();
991+
}
992+
};
993+
963994
struct TritonFToTf32OpLowering
964995
: public ConvertOpToLLVMPattern<TritonGEN::FToTf32Op> {
965996
using ConvertOpToLLVMPattern<TritonGEN::FToTf32Op>::ConvertOpToLLVMPattern;
@@ -1049,11 +1080,12 @@ struct TritonGENToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
10491080

10501081
void mlir::triton::populateTritonGENToLLVMConversionPatterns(
10511082
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
1052-
patterns.add<
1053-
TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
1054-
TritonMatrix2DBlockStoreLowering, TritonMatrix2DBlockPrefetchLowering,
1055-
TritonSubGroupBlockReadLowering, TritonSubGroupBlockWriteLowering,
1056-
TritonPredicatedLoadOpLowering, TritonFToTf32OpLowering>(converter);
1083+
patterns
1084+
.add<TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
1085+
TritonMatrix2DBlockStoreLowering,
1086+
TritonMatrix2DBlockPrefetchLowering, TritonSubGroupBlockReadLowering,
1087+
TritonSubGroupBlockWriteLowering, TritonPredicatedLoadOpLowering,
1088+
TritonPredicatedStoreOpLowering, TritonFToTf32OpLowering>(converter);
10571089
}
10581090

10591091
void registerConvertTritonTritonGENToLLVMInterface(DialectRegistry &registry) {

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3320,7 +3320,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
33203320
Value ret;
33213321
// Create a predicated load operation.
33223322
if (pred) {
3323-
if (triton::tools::getBoolEnv("TRITON_INTEL_PREDICATED_LOAD"))
3323+
if (triton::tools::getBoolEnv("TRITON_INTEL_PREDICATED"))
33243324
ret = rewriter.create<TritonGEN::PredicatedLoadOp>(
33253325
loc, retTy, addrElem, b.i64_val(alignment), pred, other_);
33263326
else {
@@ -3756,17 +3756,24 @@ struct StoreOpConversion
37563756
vecWord = b.insert_element(vecTy, vecWord, llWord, b.i32_val(index));
37573757
}
37583758

3759+
Value addrElem = b.bitcast(ptrElems[vecStart], ptr_ty(ctx, 1 /*global*/));
3760+
uint32_t alignment = nWords * width / 8;
37593761
auto createStore = [&]() -> ArrayRef<Value> {
3760-
Value addrElem =
3761-
b.bitcast(ptrElems[vecStart], ptr_ty(ctx, 1 /*global*/));
3762-
uint32_t alignment = nWords * width / 8;
37633762
b.store(vecWord, addrElem, alignment);
37643763
return ArrayRef<Value>();
37653764
};
37663765

37673766
if (maskVal) {
37683767
// Create a predicated store operation.
3769-
LLVM::intel::createPredicatedBlock(rewriter, loc, maskVal, createStore);
3768+
if (triton::tools::getBoolEnv("TRITON_INTEL_PREDICATED")) {
3769+
unsigned numElems = valArgTy.getIntOrFloatBitWidth() * nWords /
3770+
valueElemTy.getIntOrFloatBitWidth();
3771+
vecWord = b.bitcast(vecWord, vec_ty(valueElemTy, numElems));
3772+
rewriter.create<TritonGEN::PredicatedStoreOp>(
3773+
loc, addrElem, vecWord, b.i64_val(alignment), maskVal);
3774+
} else
3775+
LLVM::intel::createPredicatedBlock(rewriter, loc, maskVal,
3776+
createStore);
37703777
} else {
37713778
auto _ = createStore();
37723779
}

0 commit comments

Comments
 (0)