Skip to content

Commit 5a894fe

Browse files
Add warp reduce for integer min and max (#2933)
Signed-off-by: Whitney Tsang <[email protected]>
1 parent 585ee9c commit 5a894fe

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/SPIRVSubgroupOps.h

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,6 +35,18 @@ template <> struct SPIRVArithmeticGroupOp<arith::MulFOp> {
3535
// Specialization for arith::MulIOp: `type` names the SPIR-V group op
// spirv::GroupNonUniformIMulOp.
template <> struct SPIRVArithmeticGroupOp<arith::MulIOp> {
3636
using type = spirv::GroupNonUniformIMulOp;
3737
};
38+
// Specialization for arith::MaxSIOp (integer max, signed variant): `type`
// names the SPIR-V group op spirv::GroupNonUniformSMaxOp.
template <> struct SPIRVArithmeticGroupOp<arith::MaxSIOp> {
39+
using type = spirv::GroupNonUniformSMaxOp;
40+
};
41+
// Specialization for arith::MaxUIOp (integer max, unsigned variant): `type`
// names the SPIR-V group op spirv::GroupNonUniformUMaxOp.
template <> struct SPIRVArithmeticGroupOp<arith::MaxUIOp> {
42+
using type = spirv::GroupNonUniformUMaxOp;
43+
};
44+
// Specialization for arith::MinSIOp (integer min, signed variant): `type`
// names the SPIR-V group op spirv::GroupNonUniformSMinOp.
template <> struct SPIRVArithmeticGroupOp<arith::MinSIOp> {
45+
using type = spirv::GroupNonUniformSMinOp;
46+
};
47+
// Specialization for arith::MinUIOp (integer min, unsigned variant): `type`
// names the SPIR-V group op spirv::GroupNonUniformUMinOp.
template <> struct SPIRVArithmeticGroupOp<arith::MinUIOp> {
48+
using type = spirv::GroupNonUniformUMinOp;
49+
};
3850
// Specialization for arith::MaxNumFOp: `type` names the SPIR-V group op
// spirv::GroupNonUniformFMaxOp.
template <> struct SPIRVArithmeticGroupOp<arith::MaxNumFOp> {
3951
using type = spirv::GroupNonUniformFMaxOp;
4052
};

third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,7 @@ Value warpReduceHelper(RewriterBase &rewriter, Location loc, Value acc,
136136
Value warpReduce =
137137
TypeSwitch<mlir::Operation *, Value>(reduceOp)
138138
.Case<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp,
139+
arith::MaxSIOp, arith::MaxUIOp, arith::MinSIOp, arith::MinUIOp,
139140
arith::MaxNumFOp, arith::MinNumFOp>([&](auto groupOp) {
140141
return createSPIRVGroupOp<
141142
SPIRVArithmeticGroupOpTy<decltype(groupOp)>>(
@@ -182,9 +183,11 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
182183
reduceOp->getOperand(1) != block.getArgument(1))
183184
return false;
184185

185-
auto supportedOp = isa<arith::AddFOp, arith::AddIOp, arith::MulFOp,
186-
arith::MulIOp, arith::MaxNumFOp, arith::MinNumFOp,
187-
arith::AndIOp, arith::OrIOp, arith::XOrIOp>(reduceOp);
186+
auto supportedOp =
187+
isa<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp,
188+
arith::MaxSIOp, arith::MaxUIOp, arith::MinSIOp, arith::MinUIOp,
189+
arith::MaxNumFOp, arith::MinNumFOp, arith::AndIOp, arith::OrIOp,
190+
arith::XOrIOp>(reduceOp);
188191

189192
if (!supportedOp)
190193
return false;

0 commit comments

Comments
 (0)