31
31
#include < mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h>
32
32
#include < mlir/Dialect/Affine/IR/AffineOps.h>
33
33
#include < mlir/Dialect/Arithmetic/Transforms/Passes.h>
34
+ #include < mlir/Dialect/ControlFlow/IR/ControlFlowOps.h>
34
35
#include < mlir/Dialect/Func/IR/FuncOps.h>
35
36
#include < mlir/Dialect/Func/Transforms/FuncConversions.h>
36
37
#include < mlir/Dialect/GPU/ParallelLoopMapper.h>
@@ -859,6 +860,44 @@ struct KernelMemrefOpsMovementPass
859
860
}
860
861
};
861
862
863
+ struct AssumeGpuIdRangePass
864
+ : public mlir::PassWrapper<AssumeGpuIdRangePass,
865
+ mlir::OperationPass<void >> {
866
+ virtual void
867
+ getDependentDialects (mlir::DialectRegistry ®istry) const override {
868
+ registry.insert <mlir::arith::ArithmeticDialect>();
869
+ registry.insert <mlir::cf::ControlFlowDialect>();
870
+ registry.insert <mlir::gpu::GPUDialect>();
871
+ }
872
+
873
+ void runOnOperation () override {
874
+ auto op = getOperation ();
875
+
876
+ mlir::OpBuilder builder (&getContext ());
877
+ builder.setInsertionPointToStart (&op->getRegion (0 ).front ());
878
+ auto maxInt = builder
879
+ .create <mlir::arith::ConstantIndexOp>(
880
+ builder.getUnknownLoc (),
881
+ std::numeric_limits<int32_t >::max () + 1 )
882
+ .getResult ();
883
+
884
+ op->walk ([&](mlir::Operation *nestedOp) {
885
+ if (!mlir::isa<mlir::gpu::ThreadIdOp, mlir::gpu::BlockIdOp,
886
+ mlir::gpu::GlobalIdOp>(nestedOp))
887
+ return ;
888
+
889
+ assert (nestedOp->getNumResults () == 1 );
890
+ auto res = nestedOp->getResult (0 );
891
+ assert (res.getType ().isa <mlir::IndexType>());
892
+ builder.setInsertionPointAfter (nestedOp);
893
+ auto loc = op->getLoc ();
894
+ auto cmp = builder.create <mlir::arith::CmpIOp>(
895
+ loc, mlir::arith::CmpIPredicate::slt, res, maxInt);
896
+ builder.create <mlir::cf::AssertOp>(loc, cmp, " Invalid gpu id range" );
897
+ });
898
+ }
899
+ };
900
+
862
901
struct AbiAttrsPass
863
902
: public mlir::PassWrapper<AbiAttrsPass,
864
903
mlir::OperationPass<mlir::gpu::GPUModuleOp>> {
@@ -904,10 +943,12 @@ struct SetSPIRVCapabilitiesPass
904
943
spirv::Capability::Float16,
905
944
spirv::Capability::Float64,
906
945
spirv::Capability::AtomicFloat32AddEXT,
946
+ spirv::Capability::ExpectAssumeKHR,
907
947
// clang-format on
908
948
};
909
949
spirv::Extension exts[] = {
910
- spirv::Extension::SPV_EXT_shader_atomic_float_add};
950
+ spirv::Extension::SPV_EXT_shader_atomic_float_add,
951
+ spirv::Extension::SPV_KHR_expect_assume};
911
952
auto triple =
912
953
spirv::VerCapExtAttr::get (spirv::Version::V_1_0, caps, exts, context);
913
954
auto attr = spirv::TargetEnvAttr::get (
@@ -987,7 +1028,7 @@ static llvm::Optional<mlir::Value> getGpuStream(mlir::OpBuilder &builder,
987
1028
class ConvertSubviewOp
988
1029
: public mlir::OpConversionPattern<mlir::memref::SubViewOp> {
989
1030
public:
990
- using mlir:: OpConversionPattern<mlir::memref::SubViewOp> ::OpConversionPattern;
1031
+ using OpConversionPattern::OpConversionPattern;
991
1032
992
1033
mlir::LogicalResult
993
1034
matchAndRewrite (mlir::memref::SubViewOp op,
@@ -1046,7 +1087,7 @@ class ConvertCastOp : public mlir::OpConversionPattern<T> {
1046
1087
1047
1088
class ConvertLoadOp : public mlir ::OpConversionPattern<mlir::memref::LoadOp> {
1048
1089
public:
1049
- using mlir:: OpConversionPattern<mlir::memref::LoadOp> ::OpConversionPattern;
1090
+ using OpConversionPattern::OpConversionPattern;
1050
1091
1051
1092
mlir::LogicalResult
1052
1093
matchAndRewrite (mlir::memref::LoadOp op,
@@ -1073,7 +1114,7 @@ class ConvertLoadOp : public mlir::OpConversionPattern<mlir::memref::LoadOp> {
1073
1114
1074
1115
class ConvertStoreOp : public mlir ::OpConversionPattern<mlir::memref::StoreOp> {
1075
1116
public:
1076
- using mlir:: OpConversionPattern<mlir::memref::StoreOp> ::OpConversionPattern;
1117
+ using OpConversionPattern::OpConversionPattern;
1077
1118
1078
1119
mlir::LogicalResult
1079
1120
matchAndRewrite (mlir::memref::StoreOp op,
@@ -1195,7 +1236,7 @@ class ConvertAtomicOps : public mlir::OpConversionPattern<mlir::func::CallOp> {
1195
1236
// TODO: something better
1196
1237
class ConvertFunc : public mlir ::OpConversionPattern<mlir::FuncOp> {
1197
1238
public:
1198
- using mlir:: OpConversionPattern<mlir::FuncOp> ::OpConversionPattern;
1239
+ using OpConversionPattern::OpConversionPattern;
1199
1240
1200
1241
mlir::LogicalResult
1201
1242
matchAndRewrite (mlir::FuncOp op, mlir::FuncOp::Adaptor /* adaptor*/ ,
@@ -1208,6 +1249,19 @@ class ConvertFunc : public mlir::OpConversionPattern<mlir::FuncOp> {
1208
1249
}
1209
1250
};
1210
1251
1252
+ class ConvertAssert : public mlir ::OpConversionPattern<mlir::cf::AssertOp> {
1253
+ public:
1254
+ using OpConversionPattern::OpConversionPattern;
1255
+
1256
+ mlir::LogicalResult
1257
+ matchAndRewrite (mlir::cf::AssertOp op, mlir::cf::AssertOp::Adaptor adaptor,
1258
+ mlir::ConversionPatternRewriter &rewriter) const override {
1259
+ rewriter.replaceOpWithNewOp <mlir::spirv::AssumeTrueKHROp>(op,
1260
+ adaptor.getArg ());
1261
+ return mlir::success ();
1262
+ }
1263
+ };
1264
+
1211
1265
struct GPUToSpirvPass
1212
1266
: public mlir::PassWrapper<GPUToSpirvPass,
1213
1267
mlir::OperationPass<mlir::ModuleOp>> {
@@ -1259,8 +1313,8 @@ struct GPUToSpirvPass
1259
1313
patterns
1260
1314
.insert <ConvertSubviewOp, ConvertCastOp<mlir::memref::CastOp>,
1261
1315
ConvertCastOp<mlir::memref::ReinterpretCastOp>, ConvertLoadOp,
1262
- ConvertStoreOp, ConvertAtomicOps, ConvertFunc>(typeConverter,
1263
- context);
1316
+ ConvertStoreOp, ConvertAtomicOps, ConvertFunc, ConvertAssert>(
1317
+ typeConverter, context);
1264
1318
1265
1319
if (failed (
1266
1320
applyFullConversion (kernelModules, *target, std::move (patterns))))
@@ -3106,6 +3160,7 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
3106
3160
gpuFuncPM.addPass (mlir::arith::createArithmeticExpandOpsPass ());
3107
3161
gpuFuncPM.addPass (std::make_unique<FlattenScfPass>());
3108
3162
commonOptPasses (gpuFuncPM);
3163
+ gpuFuncPM.addPass (std::make_unique<AssumeGpuIdRangePass>());
3109
3164
3110
3165
pm.addNestedPass <mlir::gpu::GPUModuleOp>(std::make_unique<AbiAttrsPass>());
3111
3166
pm.addPass (std::make_unique<SetSPIRVCapabilitiesPass>());
0 commit comments