Skip to content

Commit 38a7bcf

Browse files
authored
[gpu] Sink UndefOp into gpu kernel (#201)
1 parent 3549263 commit 38a7bcf

File tree

2 files changed

+22
-2
lines changed

2 files changed

+22
-2
lines changed

mlir/lib/Conversion/gpu_to_gpu_runtime.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,25 @@ class ConvertAssert : public mlir::OpConversionPattern<mlir::cf::AssertOp> {
971971
}
972972
};
973973

974+
class ConvertUndef : public mlir::OpConversionPattern<plier::UndefOp> {
975+
public:
976+
using OpConversionPattern::OpConversionPattern;
977+
978+
mlir::LogicalResult
979+
matchAndRewrite(plier::UndefOp op, plier::UndefOp::Adaptor adaptor,
980+
mlir::ConversionPatternRewriter &rewriter) const override {
981+
auto converter = getTypeConverter();
982+
assert(converter);
983+
984+
auto resType = converter->convertType(op.getType());
985+
if (!resType)
986+
return mlir::failure();
987+
988+
rewriter.replaceOpWithNewOp<mlir::spirv::UndefOp>(op, resType);
989+
return mlir::success();
990+
}
991+
};
992+
974993
struct GPUToSpirvPass
975994
: public mlir::PassWrapper<GPUToSpirvPass,
976995
mlir::OperationPass<mlir::ModuleOp>> {
@@ -1025,7 +1044,8 @@ struct GPUToSpirvPass
10251044
.insert<ConvertSubviewOp, ConvertCastOp<mlir::memref::CastOp>,
10261045
ConvertCastOp<mlir::memref::ReinterpretCastOp>, ConvertLoadOp,
10271046
ConvertStoreOp, ConvertAtomicOps, ConvertFunc, ConvertAssert,
1028-
ConvertBarrierOp, ConvertMemFenceOp>(typeConverter, context);
1047+
ConvertBarrierOp, ConvertMemFenceOp, ConvertUndef>(
1048+
typeConverter, context);
10291049

10301050
if (failed(
10311051
applyFullConversion(kernelModules, *target, std::move(patterns))))

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/lower_to_gpu.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1213,7 +1213,7 @@ class GpuLaunchSinkOpsPass
12131213
auto isSinkingBeneficiary = [](mlir::Operation *op) -> bool {
12141214
return isa<arith::ConstantOp, func::ConstantOp, arith::SelectOp,
12151215
arith::CmpIOp, arith::IndexCastOp, arith::MulIOp,
1216-
arith::SubIOp, arith::AddIOp>(op);
1216+
arith::SubIOp, arith::AddIOp, plier::UndefOp>(op);
12171217
};
12181218

12191219
// Pull in instructions that can be sunk

0 commit comments

Comments
 (0)