Skip to content

Commit 0c87039

Browse files
authored
CSE spirv Input loads (#172)
1 parent 8033a37 commit 0c87039

File tree

2 files changed

+72
-1
lines changed

2 files changed

+72
-1
lines changed

mlir/lib/dialect/plier_util/dialect.cpp

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <mlir/IR/BuiltinOps.h>
2020
#include <mlir/IR/BuiltinTypes.h>
2121
#include <mlir/IR/DialectImplementation.h>
22+
#include <mlir/IR/Dominance.h>
2223
#include <mlir/IR/PatternMatch.h>
2324
#include <mlir/Transforms/InliningUtils.h>
2425

@@ -28,6 +29,7 @@
2829
#include <mlir/Dialect/Linalg/IR/Linalg.h>
2930
#include <mlir/Dialect/MemRef/IR/MemRef.h>
3031
#include <mlir/Dialect/SCF/SCF.h>
32+
#include <mlir/Dialect/SPIRV/IR/SPIRVOps.h>
3133

3234
#include <llvm/ADT/SmallBitVector.h>
3335
#include <llvm/ADT/TypeSwitch.h>
@@ -185,13 +187,52 @@ struct FillExtractSlice
185187
return mlir::success();
186188
}
187189
};
190+
191+
// Rewrite pattern on spirv::LoadOp: a load through a pointer with the Input
// storage class is replaced by the result of another load in the same
// spirv::FuncOp that has the same operands and result types and properly
// dominates it. Fails to match when no such load is found.
struct SpirvInputCSE : public mlir::OpRewritePattern<mlir::spirv::LoadOp> {
192+
using OpRewritePattern::OpRewritePattern;
193+
194+
mlir::LogicalResult
195+
matchAndRewrite(mlir::spirv::LoadOp op,
196+
mlir::PatternRewriter &rewriter) const override {
197+
auto ptr = op.ptr();
198+
// Only loads whose pointer type has the Input storage class are handled.
if (ptr.getType().cast<mlir::spirv::PointerType>().getStorageClass() !=
199+
mlir::spirv::StorageClass::Input)
200+
return mlir::failure();
201+
202+
// The search is limited to the enclosing spirv::FuncOp; bail out if the
// load is not nested in one.
auto func = op->getParentOfType<mlir::spirv::FuncOp>();
203+
if (!func)
204+
return mlir::failure();
205+
206+
mlir::DominanceInfo dom;
207+
// Stays null unless the walk below finds a suitable earlier load.
mlir::spirv::LoadOp prevLoad;
208+
// Visit the loads nested in the function, stopping at the first match.
func->walk([&](mlir::spirv::LoadOp load) -> mlir::WalkResult {
209+
// Reached the load being rewritten: stop the walk without a candidate.
if (load == op)
210+
return mlir::WalkResult::interrupt();
211+
212+
// A candidate must have the same operands and result types as `op` and
// must properly dominate it.
// NOTE(review): attributes of the two loads are not compared here;
// confirm that differing load attributes can never matter for Input
// loads.
if (load->getOperands() == op->getOperands() &&
213+
load->getResultTypes() == op->getResultTypes() &&
214+
dom.properlyDominates(load.getOperation(), op)) {
215+
prevLoad = load;
216+
return mlir::WalkResult::interrupt();
217+
}
218+
219+
return mlir::WalkResult::advance();
220+
});
221+
222+
if (!prevLoad)
223+
return mlir::failure();
224+
225+
// Reuse the value produced by the dominating load in place of `op`.
rewriter.replaceOp(op, prevLoad.getResult());
226+
return mlir::success();
227+
}
228+
};
188229
} // namespace
189230

190231
// Adds the dialect's canonicalization patterns to `results`. In this diff
// hunk the final line of the pattern list is replaced (the `-` line by the `+`
// line) so that SpirvInputCSE is registered alongside the existing patterns.
void PlierUtilDialect::getCanonicalizationPatterns(
191232
mlir::RewritePatternSet &results) const {
192233
results.add<DimExpandShape<mlir::tensor::DimOp, mlir::tensor::ExpandShapeOp>,
193234
DimExpandShape<mlir::memref::DimOp, mlir::memref::ExpandShapeOp>,
194-
DimInsertSlice, FillExtractSlice>(getContext());
235+
DimInsertSlice, FillExtractSlice, SpirvInputCSE>(getContext());
195236
}
196237

197238
OpaqueType OpaqueType::get(mlir::MLIRContext *context) {

numba_dpcomp/numba_dpcomp/mlir/tests/test_gpu.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,36 @@ def func(a, b, c, res):
654654
assert_equal(gpu_res, sim_res)
655655

656656

657+
# GPU test: the kernel calls get_global_id three times; the test asserts that
# the captured IR contains exactly one load of the GlobalInvocationId builtin
# and that the GPU result equals the simulator result.
@require_gpu
658+
def test_input_load_cse():
659+
def func(c):
660+
i = get_global_id(0)
661+
j = get_global_id(1)
662+
k = get_global_id(2)
663+
c[i, j, k] = i + 10 * j + 100 * k
664+
665+
sim_func = kernel_sim(func)
666+
gpu_func = kernel_cached(func)
667+
668+
# `a` is used below only for its shape and dtype; its values are never read.
a = np.array([[[1, 2, 3], [4, 5, 6]]], np.float32)
669+
sim_res = np.zeros(a.shape, a.dtype)
670+
sim_func[a.shape, DEFAULT_LOCAL_SIZE](sim_res)
671+
672+
gpu_res = np.zeros(a.shape, a.dtype)
673+
674+
# Presumably captures the IR printed for the named pass while the kernel is
# compiled and run; verify against print_pass_ir / get_print_buffer.
with print_pass_ir(["SerializeSPIRVPass"], []):
675+
gpu_func[a.shape, DEFAULT_LOCAL_SIZE](gpu_res)
676+
ir = get_print_buffer()
677+
# Exactly one such load may appear; the full IR is the assertion message.
assert (
678+
ir.count(
679+
'spv.Load "Input" %__builtin_var_GlobalInvocationId___addr : vector<3xi64>'
680+
)
681+
== 1
682+
), ir
683+
684+
assert_equal(gpu_res, sim_res)
685+
686+
657687
@require_dpctl
658688
def test_dpctl_simple1():
659689
def func(a, b, c):

0 commit comments

Comments
 (0)