Skip to content

Commit a10c083

Browse files
authored
Some additional optimizations for gpu (#153)
1 parent 6455d47 commit a10c083

File tree

2 files changed

+102
-4
lines changed

2 files changed

+102
-4
lines changed

mlir/lib/dialect/plier_util/dialect.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,56 @@ struct FillExtractSlice
179179
return mlir::success();
180180
}
181181
};
182+
183+
struct ReinterpretOfReinterpret
184+
: public mlir::OpRewritePattern<mlir::memref::ReinterpretCastOp> {
185+
using OpRewritePattern::OpRewritePattern;
186+
187+
mlir::LogicalResult
188+
matchAndRewrite(mlir::memref::ReinterpretCastOp op,
189+
mlir::PatternRewriter &rewriter) const override {
190+
mlir::OpFoldResult zero = rewriter.getI64IntegerAttr(0);
191+
auto getCastSrc = [&](mlir::Value src) -> mlir::Value {
192+
if (auto prev = src.getDefiningOp<mlir::memref::ReinterpretCastOp>()) {
193+
return prev.source();
194+
} else if (auto prev = src.getDefiningOp<mlir::memref::CastOp>()) {
195+
return prev.source();
196+
} else if (auto prev = src.getDefiningOp<mlir::memref::SubViewOp>()) {
197+
if (llvm::all_of(prev.getMixedOffsets(),
198+
[&](auto val) { return val == zero; }))
199+
return prev.source();
200+
}
201+
202+
return nullptr;
203+
};
204+
205+
mlir::Value src = op.source();
206+
while (auto prev = getCastSrc(src))
207+
src = prev;
208+
209+
if (src == op.source())
210+
return mlir::failure();
211+
212+
auto offsets = op.getMixedOffsets();
213+
if (offsets.size() != 1)
214+
return mlir::failure();
215+
216+
auto sizes = op.getMixedSizes();
217+
auto strides = op.getMixedStrides();
218+
219+
rewriter.replaceOpWithNewOp<mlir::memref::ReinterpretCastOp>(
220+
op, op.getType(), src, offsets.front(), sizes, strides);
221+
return mlir::success();
222+
}
223+
};
182224
} // namespace
183225

184226
void PlierUtilDialect::getCanonicalizationPatterns(
185227
mlir::RewritePatternSet &results) const {
186228
results.add<DimExpandShape<mlir::tensor::DimOp, mlir::tensor::ExpandShapeOp>,
187229
DimExpandShape<mlir::memref::DimOp, mlir::memref::ExpandShapeOp>,
188-
DimInsertSlice, FillExtractSlice>(getContext());
230+
DimInsertSlice, FillExtractSlice, ReinterpretOfReinterpret>(
231+
getContext());
189232
}
190233

191234
OpaqueType OpaqueType::get(mlir::MLIRContext *context) {

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/lower_to_gpu.cpp

Lines changed: 58 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,57 @@ struct GPULowerDefaultLocalSize
13411341
}
13421342
};
13431343

1344+
struct FlattenScfIf : public mlir::OpRewritePattern<mlir::scf::IfOp> {
1345+
using OpRewritePattern::OpRewritePattern;
1346+
1347+
mlir::LogicalResult
1348+
matchAndRewrite(mlir::scf::IfOp op,
1349+
mlir::PatternRewriter &rewriter) const override {
1350+
if (op->getNumResults() == 0)
1351+
return mlir::failure();
1352+
1353+
auto arithDialect =
1354+
getContext()->getOrLoadDialect<mlir::arith::ArithmeticDialect>();
1355+
auto canFlatten = [&](mlir::Operation *op) {
1356+
return op->getDialect() == arithDialect;
1357+
};
1358+
1359+
auto &trueBody = op.getThenRegion().front();
1360+
auto &falseBody = op.getElseRegion().front();
1361+
for (auto *block : {&trueBody, &falseBody})
1362+
for (auto &op : block->without_terminator())
1363+
if (!canFlatten(&op))
1364+
return mlir::failure();
1365+
1366+
mlir::BlockAndValueMapping mapper;
1367+
for (auto *block : {&trueBody, &falseBody})
1368+
for (auto &op : block->without_terminator())
1369+
rewriter.clone(op, mapper);
1370+
1371+
auto trueYield = mlir::cast<mlir::scf::YieldOp>(trueBody.getTerminator());
1372+
auto falseYield = mlir::cast<mlir::scf::YieldOp>(falseBody.getTerminator());
1373+
1374+
llvm::SmallVector<mlir::Value> results;
1375+
results.reserve(op->getNumResults());
1376+
1377+
auto loc = op->getLoc();
1378+
auto cond = op.getCondition();
1379+
for (auto it : llvm::zip(trueYield.getResults(), falseYield.getResults())) {
1380+
auto trueVal = mapper.lookupOrDefault(std::get<0>(it));
1381+
auto falseVal = mapper.lookupOrDefault(std::get<1>(it));
1382+
auto res =
1383+
rewriter.create<mlir::arith::SelectOp>(loc, cond, trueVal, falseVal);
1384+
results.emplace_back(res);
1385+
}
1386+
1387+
rewriter.replaceOp(op, results);
1388+
return mlir::success();
1389+
}
1390+
};
1391+
1392+
struct FlattenScfPass : public plier::RewriteWrapperPass<FlattenScfPass, void,
1393+
void, FlattenScfIf> {};
1394+
13441395
template <typename Op, typename F>
13451396
static mlir::LogicalResult createGpuKernelLoad(mlir::PatternRewriter &builder,
13461397
Op &&op, F &&func) {
@@ -2972,11 +3023,15 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
29723023
pm.addPass(mlir::createSymbolDCEPass());
29733024

29743025
pm.addNestedPass<mlir::FuncOp>(std::make_unique<GPULowerDefaultLocalSize>());
2975-
pm.nest<mlir::gpu::GPUModuleOp>().addNestedPass<mlir::gpu::GPUFuncOp>(
2976-
mlir::arith::createArithmeticExpandOpsPass());
29773026
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
29783027
pm.addPass(mlir::createSymbolDCEPass());
2979-
pm.addNestedPass<mlir::gpu::GPUModuleOp>(mlir::createCanonicalizerPass());
3028+
3029+
auto &gpuFuncPM =
3030+
pm.nest<mlir::gpu::GPUModuleOp>().nest<mlir::gpu::GPUFuncOp>();
3031+
gpuFuncPM.addPass(mlir::arith::createArithmeticExpandOpsPass());
3032+
gpuFuncPM.addPass(std::make_unique<FlattenScfPass>());
3033+
commonOptPasses(gpuFuncPM);
3034+
29803035
pm.addNestedPass<mlir::gpu::GPUModuleOp>(std::make_unique<AbiAttrsPass>());
29813036
pm.addPass(std::make_unique<SetSPIRVCapabilitiesPass>());
29823037
pm.addPass(std::make_unique<GPUToSpirvPass>());

0 commit comments

Comments
 (0)