|
7 | 7 | //===----------------------------------------------------------------------===// |
8 | 8 |
|
9 | 9 | #include "TargetInfo.h" |
| 10 | +#include "SPIRVSubgroupOps.h" |
10 | 11 | #include "Utility.h" |
11 | 12 | #include "mlir/Dialect/SPIRV/IR/SPIRVOps.h" |
12 | 13 | #include "llvm/ADT/TypeSwitch.h" |
@@ -133,46 +134,21 @@ Value warpReduceHelper(RewriterBase &rewriter, Location loc, Value acc, |
133 | 134 | unsigned warpSize) { |
134 | 135 | auto resultType = reduceOp->getResult(0).getType(); |
135 | 136 | Value warpReduce = |
136 | | - llvm::TypeSwitch<mlir::Operation *, Value>(reduceOp) |
137 | | - .Case<arith::AddFOp>([&](auto) { |
138 | | - return createSPIRVGroupOp<spirv::GroupNonUniformFAddOp>( |
139 | | - rewriter, loc, resultType, acc, numLanesToReduce, warpSize); |
140 | | - }) |
141 | | - .Case<arith::AddIOp>([&](auto) { |
142 | | - return createSPIRVGroupOp<spirv::GroupNonUniformIAddOp>( |
143 | | - rewriter, loc, resultType, acc, numLanesToReduce, warpSize); |
144 | | - }) |
145 | | - .Case<arith::MulFOp>([&](auto) { |
146 | | - return createSPIRVGroupOp<spirv::GroupNonUniformFMulOp>( |
147 | | - rewriter, loc, resultType, acc, numLanesToReduce, warpSize); |
148 | | - }) |
149 | | - .Case<arith::MulIOp>([&](auto) { |
150 | | - return createSPIRVGroupOp<spirv::GroupNonUniformIMulOp>( |
151 | | - rewriter, loc, resultType, acc, numLanesToReduce, warpSize); |
152 | | - }) |
153 | | - .Case<arith::MaxNumFOp>([&](auto) { |
154 | | - return createSPIRVGroupOp<spirv::GroupNonUniformFMaxOp>( |
155 | | - rewriter, loc, resultType, acc, numLanesToReduce, warpSize); |
156 | | - }) |
157 | | - .Case<arith::MinNumFOp>([&](auto) { |
158 | | - return createSPIRVGroupOp<spirv::GroupNonUniformFMinOp>( |
159 | | - rewriter, loc, resultType, acc, numLanesToReduce, warpSize); |
160 | | - }) |
161 | | - .Case<arith::AndIOp>([&](auto) { |
162 | | - return createSPIRVGroupOp<spirv::GroupNonUniformBitwiseAndOp>( |
163 | | - rewriter, loc, resultType, acc, numLanesToReduce, warpSize); |
164 | | - }) |
165 | | - .Case<arith::OrIOp>([&](auto) { |
166 | | - return createSPIRVGroupOp<spirv::GroupNonUniformBitwiseOrOp>( |
| 137 | + TypeSwitch<mlir::Operation *, Value>(reduceOp) |
| 138 | + .Case<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp, |
| 139 | + arith::MaxNumFOp, arith::MinNumFOp>([&](auto groupOp) { |
| 140 | + return createSPIRVGroupOp< |
| 141 | + SPIRVArithmeticGroupOpTy<decltype(groupOp)>>( |
167 | 142 | rewriter, loc, resultType, acc, numLanesToReduce, warpSize); |
168 | 143 | }) |
169 | | - .Case<arith::XOrIOp>([&](auto) { |
170 | | - return createSPIRVGroupOp<spirv::GroupNonUniformBitwiseXorOp>( |
| 144 | + .Case<arith::AndIOp, arith::OrIOp, arith::XOrIOp>([&](auto groupOp) { |
| 145 | + if (resultType.isInteger(1)) { |
| 146 | + return createSPIRVGroupOp< |
| 147 | + SPIRVLogicalGroupOpTy<decltype(groupOp)>>( |
| 148 | + rewriter, loc, resultType, acc, numLanesToReduce, warpSize); |
| 149 | + } |
| 150 | + return createSPIRVGroupOp<SPIRVBitwiseGroupOpTy<decltype(groupOp)>>( |
171 | 151 | rewriter, loc, resultType, acc, numLanesToReduce, warpSize); |
172 | | - }) |
173 | | - .Default([](auto) { |
174 | | - llvm_unreachable("Unsupported reduction"); |
175 | | - return Value(); |
176 | 152 | }); |
177 | 153 | return warpReduce; |
178 | 154 | } |
@@ -206,12 +182,9 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, |
206 | 182 | reduceOp->getOperand(1) != block.getArgument(1)) |
207 | 183 | return false; |
208 | 184 |
|
209 | | - auto supportedOp = |
210 | | - llvm::TypeSwitch<mlir::Operation *, bool>(reduceOp) |
211 | | - .Case<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp, |
212 | | - arith::MaxNumFOp, arith::MinNumFOp, arith::AndIOp, arith::OrIOp, |
213 | | - arith::XOrIOp>([&](auto) { return true; }) |
214 | | - .Default([](auto) { return false; }); |
| 185 | + auto supportedOp = isa<arith::AddFOp, arith::AddIOp, arith::MulFOp, |
| 186 | + arith::MulIOp, arith::MaxNumFOp, arith::MinNumFOp, |
| 187 | + arith::AndIOp, arith::OrIOp, arith::XOrIOp>(reduceOp); |
215 | 188 |
|
216 | 189 | if (!supportedOp) |
217 | 190 | return false; |
|
0 commit comments