|
19 | 19 | #include <mlir/IR/BuiltinOps.h>
|
20 | 20 | #include <mlir/IR/BuiltinTypes.h>
|
21 | 21 | #include <mlir/IR/DialectImplementation.h>
|
| 22 | +#include <mlir/IR/Dominance.h> |
22 | 23 | #include <mlir/IR/PatternMatch.h>
|
23 | 24 | #include <mlir/Transforms/InliningUtils.h>
|
24 | 25 |
|
|
28 | 29 | #include <mlir/Dialect/Linalg/IR/Linalg.h>
|
29 | 30 | #include <mlir/Dialect/MemRef/IR/MemRef.h>
|
30 | 31 | #include <mlir/Dialect/SCF/SCF.h>
|
| 32 | +#include <mlir/Dialect/SPIRV/IR/SPIRVOps.h> |
31 | 33 |
|
32 | 34 | #include <llvm/ADT/SmallBitVector.h>
|
33 | 35 | #include <llvm/ADT/TypeSwitch.h>
|
@@ -185,13 +187,52 @@ struct FillExtractSlice
|
185 | 187 | return mlir::success();
|
186 | 188 | }
|
187 | 189 | };
|
| 190 | + |
| 191 | +struct SpirvInputCSE : public mlir::OpRewritePattern<mlir::spirv::LoadOp> { |
| 192 | + using OpRewritePattern::OpRewritePattern; |
| 193 | + |
| 194 | + mlir::LogicalResult |
| 195 | + matchAndRewrite(mlir::spirv::LoadOp op, |
| 196 | + mlir::PatternRewriter &rewriter) const override { |
| 197 | + auto ptr = op.ptr(); |
| 198 | + if (ptr.getType().cast<mlir::spirv::PointerType>().getStorageClass() != |
| 199 | + mlir::spirv::StorageClass::Input) |
| 200 | + return mlir::failure(); |
| 201 | + |
| 202 | + auto func = op->getParentOfType<mlir::spirv::FuncOp>(); |
| 203 | + if (!func) |
| 204 | + return mlir::failure(); |
| 205 | + |
| 206 | + mlir::DominanceInfo dom; |
| 207 | + mlir::spirv::LoadOp prevLoad; |
| 208 | + func->walk([&](mlir::spirv::LoadOp load) -> mlir::WalkResult { |
| 209 | + if (load == op) |
| 210 | + return mlir::WalkResult::interrupt(); |
| 211 | + |
| 212 | + if (load->getOperands() == op->getOperands() && |
| 213 | + load->getResultTypes() == op->getResultTypes() && |
| 214 | + dom.properlyDominates(load.getOperation(), op)) { |
| 215 | + prevLoad = load; |
| 216 | + return mlir::WalkResult::interrupt(); |
| 217 | + } |
| 218 | + |
| 219 | + return mlir::WalkResult::advance(); |
| 220 | + }); |
| 221 | + |
| 222 | + if (!prevLoad) |
| 223 | + return mlir::failure(); |
| 224 | + |
| 225 | + rewriter.replaceOp(op, prevLoad.getResult()); |
| 226 | + return mlir::success(); |
| 227 | + } |
| 228 | +}; |
188 | 229 | } // namespace
|
189 | 230 |
|
190 | 231 | void PlierUtilDialect::getCanonicalizationPatterns(
|
191 | 232 | mlir::RewritePatternSet &results) const {
|
192 | 233 | results.add<DimExpandShape<mlir::tensor::DimOp, mlir::tensor::ExpandShapeOp>,
|
193 | 234 | DimExpandShape<mlir::memref::DimOp, mlir::memref::ExpandShapeOp>,
|
194 |
| - DimInsertSlice, FillExtractSlice>(getContext()); |
| 235 | + DimInsertSlice, FillExtractSlice, SpirvInputCSE>(getContext()); |
195 | 236 | }
|
196 | 237 |
|
197 | 238 | OpaqueType OpaqueType::get(mlir::MLIRContext *context) {
|
|
0 commit comments