Skip to content

Commit 66391a3

Browse files
authored
[XPU][TritonIntelGPUToLLVM] Handle arithmetic reductions of i1 values (#3113)
Use logical operations to represent arithmetic reductions of `i1` values. Only the SPIR-V builtins being used need to be changed as LLVM will handle the scalar arithmetic operations used. Closes #3109 --------- Signed-off-by: victor-eds <[email protected]>
1 parent e27d722 commit 66391a3

File tree

4 files changed

+94
-43
lines changed

4 files changed

+94
-43
lines changed

test/Conversion/intel/tritongpu_to_gen.mlir

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1508,7 +1508,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "triton_
15081508
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [32, 1], warpsPerCTA = [1, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
15091509
#slice = #ttg.slice<{dim = 0, parent = #blocked}>
15101510
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
1511-
tt.func public @reduce_all(%arg: tensor<256x1xi32, #blocked>, %arg_0: tensor<256x1xf32, #blocked>) {
1511+
tt.func public @reduce_all(%arg: tensor<256x1xi32, #blocked>, %arg_0: tensor<256x1xf32, #blocked>, %arg_1: tensor<256x1xi1, #blocked>) {
15121512

15131513
// CHECK: @_Z27__spirv_GroupNonUniformFAddiif
15141514
%0 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
@@ -1573,6 +1573,48 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
15731573
tt.reduce.return %48 : i32
15741574
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
15751575

1576+
// CHECK: @_Z32__spirv_GroupNonUniformLogicalOriib
1577+
%10 = "tt.reduce"(%arg_1) <{axis = 0 : i32}> ({
1578+
^bb0(%arg4: i1, %arg5: i1):
1579+
%48 = arith.addi %arg4, %arg5 : i1
1580+
tt.reduce.return %48 : i1
1581+
}) : (tensor<256x1xi1, #blocked>) -> tensor<1xi1, #slice>
1582+
1583+
// CHECK: @_Z33__spirv_GroupNonUniformLogicalAndiib
1584+
%11 = "tt.reduce"(%arg_1) <{axis = 0 : i32}> ({
1585+
^bb0(%arg4: i1, %arg5: i1):
1586+
%48 = arith.muli %arg4, %arg5 : i1
1587+
tt.reduce.return %48 : i1
1588+
}) : (tensor<256x1xi1, #blocked>) -> tensor<1xi1, #slice>
1589+
1590+
// CHECK: @_Z32__spirv_GroupNonUniformLogicalOriib
1591+
%12 = "tt.reduce"(%arg_1) <{axis = 0 : i32}> ({
1592+
^bb0(%arg4: i1, %arg5: i1):
1593+
%48 = arith.maxsi %arg4, %arg5 : i1
1594+
tt.reduce.return %48 : i1
1595+
}) : (tensor<256x1xi1, #blocked>) -> tensor<1xi1, #slice>
1596+
1597+
// CHECK: @_Z32__spirv_GroupNonUniformLogicalOriib
1598+
%13 = "tt.reduce"(%arg_1) <{axis = 0 : i32}> ({
1599+
^bb0(%arg4: i1, %arg5: i1):
1600+
%48 = arith.maxui %arg4, %arg5 : i1
1601+
tt.reduce.return %48 : i1
1602+
}) : (tensor<256x1xi1, #blocked>) -> tensor<1xi1, #slice>
1603+
1604+
// CHECK: @_Z33__spirv_GroupNonUniformLogicalAndiib
1605+
%14 = "tt.reduce"(%arg_1) <{axis = 0 : i32}> ({
1606+
^bb0(%arg4: i1, %arg5: i1):
1607+
%48 = arith.minsi %arg4, %arg5 : i1
1608+
tt.reduce.return %48 : i1
1609+
}) : (tensor<256x1xi1, #blocked>) -> tensor<1xi1, #slice>
1610+
1611+
// CHECK: @_Z33__spirv_GroupNonUniformLogicalAndiib
1612+
%15 = "tt.reduce"(%arg_1) <{axis = 0 : i32}> ({
1613+
^bb0(%arg4: i1, %arg5: i1):
1614+
%48 = arith.minui %arg4, %arg5 : i1
1615+
tt.reduce.return %48 : i1
1616+
}) : (tensor<256x1xi1, #blocked>) -> tensor<1xi1, #slice>
1617+
15761618
tt.return
15771619
}
15781620
}

third_party/intel/lib/TritonIntelGPUToLLVM/SPIRVSubgroupOps.h

Lines changed: 33 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,56 +21,50 @@ using namespace mlir;
2121

2222
namespace mlir::triton::intel {
2323

24-
template <typename OpTy> struct SPIRVArithmeticGroupOp {};
24+
template <typename OpTy> struct SPIRVGroupOp {};
2525

26-
template <> struct SPIRVArithmeticGroupOp<arith::AddFOp> {
26+
template <> struct SPIRVGroupOp<arith::AddFOp> {
2727
using type = spirv::GroupNonUniformFAddOp;
2828
};
29-
template <> struct SPIRVArithmeticGroupOp<arith::AddIOp> {
29+
template <> struct SPIRVGroupOp<arith::AddIOp> {
3030
using type = spirv::GroupNonUniformIAddOp;
3131
};
32-
template <> struct SPIRVArithmeticGroupOp<arith::MulFOp> {
32+
template <> struct SPIRVGroupOp<arith::MulFOp> {
3333
using type = spirv::GroupNonUniformFMulOp;
3434
};
35-
template <> struct SPIRVArithmeticGroupOp<arith::MulIOp> {
35+
template <> struct SPIRVGroupOp<arith::MulIOp> {
3636
using type = spirv::GroupNonUniformIMulOp;
3737
};
38-
template <> struct SPIRVArithmeticGroupOp<arith::MaxSIOp> {
38+
template <> struct SPIRVGroupOp<arith::MaxSIOp> {
3939
using type = spirv::GroupNonUniformSMaxOp;
4040
};
41-
template <> struct SPIRVArithmeticGroupOp<arith::MaxUIOp> {
41+
template <> struct SPIRVGroupOp<arith::MaxUIOp> {
4242
using type = spirv::GroupNonUniformUMaxOp;
4343
};
44-
template <> struct SPIRVArithmeticGroupOp<arith::MinSIOp> {
44+
template <> struct SPIRVGroupOp<arith::MinSIOp> {
4545
using type = spirv::GroupNonUniformSMinOp;
4646
};
47-
template <> struct SPIRVArithmeticGroupOp<arith::MinUIOp> {
47+
template <> struct SPIRVGroupOp<arith::MinUIOp> {
4848
using type = spirv::GroupNonUniformUMinOp;
4949
};
50-
template <> struct SPIRVArithmeticGroupOp<arith::MaxNumFOp> {
50+
template <> struct SPIRVGroupOp<arith::MaxNumFOp> {
5151
using type = spirv::GroupNonUniformFMaxOp;
5252
};
53-
template <> struct SPIRVArithmeticGroupOp<arith::MinNumFOp> {
53+
template <> struct SPIRVGroupOp<arith::MinNumFOp> {
5454
using type = spirv::GroupNonUniformFMinOp;
5555
};
56-
57-
template <typename OpTy>
58-
using SPIRVArithmeticGroupOpTy = typename SPIRVArithmeticGroupOp<OpTy>::type;
59-
60-
template <typename OpTy> struct SPIRVBitwiseGroupOp {};
61-
62-
template <> struct SPIRVBitwiseGroupOp<arith::AndIOp> {
56+
template <> struct SPIRVGroupOp<arith::AndIOp> {
6357
using type = spirv::GroupNonUniformBitwiseAndOp;
6458
};
65-
template <> struct SPIRVBitwiseGroupOp<arith::OrIOp> {
59+
template <> struct SPIRVGroupOp<arith::OrIOp> {
6660
using type = spirv::GroupNonUniformBitwiseOrOp;
6761
};
68-
template <> struct SPIRVBitwiseGroupOp<arith::XOrIOp> {
62+
template <> struct SPIRVGroupOp<arith::XOrIOp> {
6963
using type = spirv::GroupNonUniformBitwiseXorOp;
7064
};
7165

7266
template <typename OpTy>
73-
using SPIRVBitwiseGroupOpTy = typename SPIRVBitwiseGroupOp<OpTy>::type;
67+
using SPIRVGroupOpTy = typename SPIRVGroupOp<OpTy>::type;
7468

7569
template <typename OpTy> struct SPIRVLogicalGroupOp {};
7670

@@ -83,6 +77,24 @@ template <> struct SPIRVLogicalGroupOp<arith::OrIOp> {
8377
template <> struct SPIRVLogicalGroupOp<arith::XOrIOp> {
8478
using type = spirv::GroupNonUniformLogicalXorOp;
8579
};
80+
template <> struct SPIRVLogicalGroupOp<arith::AddIOp> {
81+
using type = spirv::GroupNonUniformLogicalOrOp;
82+
};
83+
template <> struct SPIRVLogicalGroupOp<arith::MulIOp> {
84+
using type = spirv::GroupNonUniformLogicalAndOp;
85+
};
86+
template <> struct SPIRVLogicalGroupOp<arith::MaxUIOp> {
87+
using type = spirv::GroupNonUniformLogicalOrOp;
88+
};
89+
template <> struct SPIRVLogicalGroupOp<arith::MaxSIOp> {
90+
using type = spirv::GroupNonUniformLogicalOrOp;
91+
};
92+
template <> struct SPIRVLogicalGroupOp<arith::MinUIOp> {
93+
using type = spirv::GroupNonUniformLogicalAndOp;
94+
};
95+
template <> struct SPIRVLogicalGroupOp<arith::MinSIOp> {
96+
using type = spirv::GroupNonUniformLogicalAndOp;
97+
};
8698

8799
template <typename OpTy>
88100
using SPIRVLogicalGroupOpTy = typename SPIRVLogicalGroupOp<OpTy>::type;

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -133,25 +133,23 @@ Value warpReduceHelper(RewriterBase &rewriter, Location loc, Value acc,
133133
Operation *reduceOp, unsigned numLanesToReduce,
134134
unsigned warpSize) {
135135
auto resultType = reduceOp->getResult(0).getType();
136-
Value warpReduce =
137-
TypeSwitch<mlir::Operation *, Value>(reduceOp)
138-
.Case<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp,
139-
arith::MaxSIOp, arith::MaxUIOp, arith::MinSIOp, arith::MinUIOp,
140-
arith::MaxNumFOp, arith::MinNumFOp>([&](auto groupOp) {
141-
return createSPIRVGroupOp<
142-
SPIRVArithmeticGroupOpTy<decltype(groupOp)>>(
143-
rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
144-
})
145-
.Case<arith::AndIOp, arith::OrIOp, arith::XOrIOp>([&](auto groupOp) {
146-
if (resultType.isInteger(1)) {
147-
return createSPIRVGroupOp<
148-
SPIRVLogicalGroupOpTy<decltype(groupOp)>>(
149-
rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
150-
}
151-
return createSPIRVGroupOp<SPIRVBitwiseGroupOpTy<decltype(groupOp)>>(
152-
rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
153-
});
154-
return warpReduce;
136+
// Use bit-equivalent logical operation for Boolean values.
137+
if (resultType.isInteger(1))
138+
return TypeSwitch<mlir::Operation *, Value>(reduceOp)
139+
.Case<arith::AddIOp, arith::MulIOp, arith::MaxSIOp, arith::MaxUIOp,
140+
arith::MinSIOp, arith::MinUIOp, arith::AndIOp, arith::OrIOp,
141+
arith::XOrIOp>([&](auto groupOp) {
142+
return createSPIRVGroupOp<SPIRVLogicalGroupOpTy<decltype(groupOp)>>(
143+
rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
144+
});
145+
return TypeSwitch<mlir::Operation *, Value>(reduceOp)
146+
.Case<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp,
147+
arith::MaxSIOp, arith::MaxUIOp, arith::MinSIOp, arith::MinUIOp,
148+
arith::MaxNumFOp, arith::MinNumFOp, arith::AndIOp, arith::OrIOp,
149+
arith::XOrIOp>([&](auto groupOp) {
150+
return createSPIRVGroupOp<SPIRVGroupOpTy<decltype(groupOp)>>(
151+
rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
152+
});
155153
}
156154

157155
} // namespace

third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -598,8 +598,7 @@ class ReduceOpConversion : public ConvertTritonGPUOpToLLVMPattern<ReduceOp> {
598598
// FIXME: support all possible reduction modes
599599
TypeSwitch<Operation *>(combine).Case<arith::AddFOp, arith::MaxNumFOp>(
600600
[&](auto reduce) {
601-
rewriter.replaceOpWithNewOp<
602-
intel::SPIRVArithmeticGroupOpTy<decltype(reduce)>>(
601+
rewriter.replaceOpWithNewOp<intel::SPIRVGroupOpTy<decltype(reduce)>>(
603602
op, typeConverter->convertType(op.getType(0)),
604603
spirv::Scope::Subgroup, spirv::GroupOperation::Reduce,
605604
adaptor.getSrcs()[0], Value());

0 commit comments

Comments
 (0)