Skip to content

Commit 48d72fe

Browse files
authored
Assume gpu iteration range is less than int32_t max (#175)
* Add pass to insert corresponding cf::AssertOp * Convert cf::AssertOp to spirv::AssumeTrueKHROp
1 parent 76c586c commit 48d72fe

File tree

1 file changed

+62
-7
lines changed

1 file changed

+62
-7
lines changed

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/lower_to_gpu.cpp

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h>
3232
#include <mlir/Dialect/Affine/IR/AffineOps.h>
3333
#include <mlir/Dialect/Arithmetic/Transforms/Passes.h>
34+
#include <mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
3435
#include <mlir/Dialect/Func/IR/FuncOps.h>
3536
#include <mlir/Dialect/Func/Transforms/FuncConversions.h>
3637
#include <mlir/Dialect/GPU/ParallelLoopMapper.h>
@@ -859,6 +860,44 @@ struct KernelMemrefOpsMovementPass
859860
}
860861
};
861862

863+
struct AssumeGpuIdRangePass
864+
: public mlir::PassWrapper<AssumeGpuIdRangePass,
865+
mlir::OperationPass<void>> {
866+
virtual void
867+
getDependentDialects(mlir::DialectRegistry &registry) const override {
868+
registry.insert<mlir::arith::ArithmeticDialect>();
869+
registry.insert<mlir::cf::ControlFlowDialect>();
870+
registry.insert<mlir::gpu::GPUDialect>();
871+
}
872+
873+
void runOnOperation() override {
874+
auto op = getOperation();
875+
876+
mlir::OpBuilder builder(&getContext());
877+
builder.setInsertionPointToStart(&op->getRegion(0).front());
878+
auto maxInt = builder
879+
.create<mlir::arith::ConstantIndexOp>(
880+
builder.getUnknownLoc(),
881+
std::numeric_limits<int32_t>::max() + 1)
882+
.getResult();
883+
884+
op->walk([&](mlir::Operation *nestedOp) {
885+
if (!mlir::isa<mlir::gpu::ThreadIdOp, mlir::gpu::BlockIdOp,
886+
mlir::gpu::GlobalIdOp>(nestedOp))
887+
return;
888+
889+
assert(nestedOp->getNumResults() == 1);
890+
auto res = nestedOp->getResult(0);
891+
assert(res.getType().isa<mlir::IndexType>());
892+
builder.setInsertionPointAfter(nestedOp);
893+
auto loc = op->getLoc();
894+
auto cmp = builder.create<mlir::arith::CmpIOp>(
895+
loc, mlir::arith::CmpIPredicate::slt, res, maxInt);
896+
builder.create<mlir::cf::AssertOp>(loc, cmp, "Invalid gpu id range");
897+
});
898+
}
899+
};
900+
862901
struct AbiAttrsPass
863902
: public mlir::PassWrapper<AbiAttrsPass,
864903
mlir::OperationPass<mlir::gpu::GPUModuleOp>> {
@@ -904,10 +943,12 @@ struct SetSPIRVCapabilitiesPass
904943
spirv::Capability::Float16,
905944
spirv::Capability::Float64,
906945
spirv::Capability::AtomicFloat32AddEXT,
946+
spirv::Capability::ExpectAssumeKHR,
907947
// clang-format on
908948
};
909949
spirv::Extension exts[] = {
910-
spirv::Extension::SPV_EXT_shader_atomic_float_add};
950+
spirv::Extension::SPV_EXT_shader_atomic_float_add,
951+
spirv::Extension::SPV_KHR_expect_assume};
911952
auto triple =
912953
spirv::VerCapExtAttr::get(spirv::Version::V_1_0, caps, exts, context);
913954
auto attr = spirv::TargetEnvAttr::get(
@@ -987,7 +1028,7 @@ static llvm::Optional<mlir::Value> getGpuStream(mlir::OpBuilder &builder,
9871028
class ConvertSubviewOp
9881029
: public mlir::OpConversionPattern<mlir::memref::SubViewOp> {
9891030
public:
990-
using mlir::OpConversionPattern<mlir::memref::SubViewOp>::OpConversionPattern;
1031+
using OpConversionPattern::OpConversionPattern;
9911032

9921033
mlir::LogicalResult
9931034
matchAndRewrite(mlir::memref::SubViewOp op,
@@ -1046,7 +1087,7 @@ class ConvertCastOp : public mlir::OpConversionPattern<T> {
10461087

10471088
class ConvertLoadOp : public mlir::OpConversionPattern<mlir::memref::LoadOp> {
10481089
public:
1049-
using mlir::OpConversionPattern<mlir::memref::LoadOp>::OpConversionPattern;
1090+
using OpConversionPattern::OpConversionPattern;
10501091

10511092
mlir::LogicalResult
10521093
matchAndRewrite(mlir::memref::LoadOp op,
@@ -1073,7 +1114,7 @@ class ConvertLoadOp : public mlir::OpConversionPattern<mlir::memref::LoadOp> {
10731114

10741115
class ConvertStoreOp : public mlir::OpConversionPattern<mlir::memref::StoreOp> {
10751116
public:
1076-
using mlir::OpConversionPattern<mlir::memref::StoreOp>::OpConversionPattern;
1117+
using OpConversionPattern::OpConversionPattern;
10771118

10781119
mlir::LogicalResult
10791120
matchAndRewrite(mlir::memref::StoreOp op,
@@ -1195,7 +1236,7 @@ class ConvertAtomicOps : public mlir::OpConversionPattern<mlir::func::CallOp> {
11951236
// TODO: something better
11961237
class ConvertFunc : public mlir::OpConversionPattern<mlir::FuncOp> {
11971238
public:
1198-
using mlir::OpConversionPattern<mlir::FuncOp>::OpConversionPattern;
1239+
using OpConversionPattern::OpConversionPattern;
11991240

12001241
mlir::LogicalResult
12011242
matchAndRewrite(mlir::FuncOp op, mlir::FuncOp::Adaptor /*adaptor*/,
@@ -1208,6 +1249,19 @@ class ConvertFunc : public mlir::OpConversionPattern<mlir::FuncOp> {
12081249
}
12091250
};
12101251

1252+
class ConvertAssert : public mlir::OpConversionPattern<mlir::cf::AssertOp> {
1253+
public:
1254+
using OpConversionPattern::OpConversionPattern;
1255+
1256+
mlir::LogicalResult
1257+
matchAndRewrite(mlir::cf::AssertOp op, mlir::cf::AssertOp::Adaptor adaptor,
1258+
mlir::ConversionPatternRewriter &rewriter) const override {
1259+
rewriter.replaceOpWithNewOp<mlir::spirv::AssumeTrueKHROp>(op,
1260+
adaptor.getArg());
1261+
return mlir::success();
1262+
}
1263+
};
1264+
12111265
struct GPUToSpirvPass
12121266
: public mlir::PassWrapper<GPUToSpirvPass,
12131267
mlir::OperationPass<mlir::ModuleOp>> {
@@ -1259,8 +1313,8 @@ struct GPUToSpirvPass
12591313
patterns
12601314
.insert<ConvertSubviewOp, ConvertCastOp<mlir::memref::CastOp>,
12611315
ConvertCastOp<mlir::memref::ReinterpretCastOp>, ConvertLoadOp,
1262-
ConvertStoreOp, ConvertAtomicOps, ConvertFunc>(typeConverter,
1263-
context);
1316+
ConvertStoreOp, ConvertAtomicOps, ConvertFunc, ConvertAssert>(
1317+
typeConverter, context);
12641318

12651319
if (failed(
12661320
applyFullConversion(kernelModules, *target, std::move(patterns))))
@@ -3106,6 +3160,7 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
31063160
gpuFuncPM.addPass(mlir::arith::createArithmeticExpandOpsPass());
31073161
gpuFuncPM.addPass(std::make_unique<FlattenScfPass>());
31083162
commonOptPasses(gpuFuncPM);
3163+
gpuFuncPM.addPass(std::make_unique<AssumeGpuIdRangePass>());
31093164

31103165
pm.addNestedPass<mlir::gpu::GPUModuleOp>(std::make_unique<AbiAttrsPass>());
31113166
pm.addPass(std::make_unique<SetSPIRVCapabilitiesPass>());

0 commit comments

Comments
 (0)