@@ -1208,6 +1208,37 @@ class ConvertFunc : public mlir::OpConversionPattern<mlir::FuncOp> {
   }
 };
 
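+// Lowers a GPU launch-config query op (SourceOp) to a read of the given
+// SPIR-V builtin variable plus an extract of the requested dimension.
+// Registered with a high benefit (100) so it takes priority over other
+// patterns matching the same op.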
+template <typename SourceOp, mlir::spirv::BuiltIn builtin>
+class LaunchConfigConversion : public mlir::OpConversionPattern<SourceOp> {
+public:
+  LaunchConfigConversion(mlir::TypeConverter &typeConverter,
+                         mlir::MLIRContext *context)
+      : mlir::OpConversionPattern<SourceOp>(typeConverter, context,
+                                            /*benefit*/ 100) {}
+  // using mlir::OpConversionPattern<SourceOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override;
+};
+
+template <typename SourceOp, mlir::spirv::BuiltIn builtin>
+mlir::LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
+    SourceOp op, typename SourceOp::Adaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  auto *typeConverter =
+      this->template getTypeConverter<mlir::SPIRVTypeConverter>();
+  auto indexType = typeConverter->getIndexType();
+
+  // SPIR-V invocation builtin variables are a vector of type <3xi32>.
+  auto spirvBuiltin =
+      mlir::spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
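+  // Extract the component matching the op's dimension (x=0, y=1, z=2).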
+  rewriter.replaceOpWithNewOp<mlir::spirv::CompositeExtractOp>(
+      op, indexType, spirvBuiltin,
+      rewriter.getI32ArrayAttr({static_cast<int32_t>(op.dimension())}));
+  return mlir::success();
+}
+
 struct GPUToSpirvPass
     : public mlir::PassWrapper<GPUToSpirvPass,
                                mlir::OperationPass<mlir::ModuleOp>> {
@@ -1262,6 +1293,10 @@ struct GPUToSpirvPass
                           ConvertStoreOp, ConvertAtomicOps, ConvertFunc>(typeConverter,
                                                                          context);
 
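+    // Lower gpu.block_dim to the SPIR-V WorkgroupSize builtin.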
+    patterns.insert<LaunchConfigConversion<
+        mlir::gpu::BlockDimOp, mlir::spirv::BuiltIn::WorkgroupSize>>(
+        typeConverter, context);
+
     if (failed(
             applyFullConversion(kernelModules, *target, std::move(patterns))))
       return signalPassFailure();
@@ -2986,6 +3021,76 @@ class GpuLaunchSinkOpsPass
   }
 };
 
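+// Rewrites uses of the gpu.launch grid/block size operands inside the launch
+// body into gpu.grid_dim / gpu.block_dim ops, so the body no longer depends
+// on the dimension values computed on the host side.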
+struct SinkGpuDims : public mlir::OpRewritePattern<mlir::gpu::LaunchOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::gpu::LaunchOp op,
+                  mlir::PatternRewriter &rewriter) const override {
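+    // Collect every use of the six dimension operands that lives inside the
+    // launch body, also looking through arith.index_cast where present.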
+    const mlir::Value dimArgs[] = {op.gridSizeX(), op.gridSizeY(),
+                                   op.gridSizeZ(), op.blockSizeX(),
+                                   op.blockSizeY(), op.blockSizeZ()};
+    llvm::SmallVector<std::pair<mlir::OpOperand *, unsigned>> uses;
+    for (auto it : llvm::enumerate(dimArgs)) {
+      auto i = static_cast<unsigned>(it.index());
+      auto addUse = [&](mlir::OpOperand &use) {
+        if (op->isProperAncestor(use.getOwner()))
+          uses.emplace_back(&use, i);
+      };
+      auto val = it.value();
+      for (auto &use : val.getUses())
+        addUse(use);
+
+      if (auto cast = val.getDefiningOp<mlir::arith::IndexCastOp>())
+        for (auto &use : cast.getIn().getUses())
+          addUse(use);
+    }
+
+    if (uses.empty())
+      return mlir::failure();
+
+    std::array<mlir::Value, 6> dims = {}; // TODO: static vector
+
+    auto loc = op->getLoc();
+    rewriter.setInsertionPointToStart(&op.body().front());
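+    // Lazily materialize the dim op for index i (0-2 -> gpu.grid_dim x/y/z,
+    // 3-5 -> gpu.block_dim x/y/z), casting if the use expects another type.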
+    auto getDim = [&](unsigned i, mlir::Type type) -> mlir::Value {
+      assert(i < dims.size());
+      auto dim = dims[i];
+      if (!dim) {
+        if (i < 3) {
+          dim = rewriter.create<mlir::gpu::GridDimOp>(
+              loc, static_cast<mlir::gpu::Dimension>(i));
+        } else {
+          dim = rewriter.create<mlir::gpu::BlockDimOp>(
+              loc, static_cast<mlir::gpu::Dimension>(i - 3));
+        }
+        dims[i] = dim;
+      }
+
+      if (type != dim.getType())
+        dim = rewriter.create<mlir::arith::IndexCastOp>(loc, type, dim);
+
+      return dim;
+    };
+
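+    // Redirect each collected use to the in-body dim value.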
+    for (auto it : uses) {
+      auto *use = it.first;
+      auto dim = it.second;
+      auto owner = use->getOwner();
+      rewriter.updateRootInPlace(owner, [&]() {
+        auto type = use->get().getType();
+        auto newVal = getDim(dim, type);
+        use->set(newVal);
+      });
+    }
+
+    return mlir::success();
+  }
+};
+
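+// Thin pass wrapper that just applies the SinkGpuDims pattern.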
+struct SinkGpuDimsPass : public plier::RewriteWrapperPass<SinkGpuDimsPass, void,
+                                                          void, SinkGpuDims> {};
+
 static void commonOptPasses(mlir::OpPassManager &pm) {
   pm.addPass(plier::createCommonOptsPass());
   pm.addPass(mlir::createCSEPass());
@@ -3018,6 +3123,7 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
   commonOptPasses(funcPM);
   funcPM.addPass(std::make_unique<KernelMemrefOpsMovementPass>());
   funcPM.addPass(std::make_unique<GpuLaunchSinkOpsPass>());
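+  // Rewrite in-body uses of launch dimensions before kernel outlining.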
+  funcPM.addPass(std::make_unique<SinkGpuDimsPass>());
   pm.addPass(mlir::createGpuKernelOutliningPass());
   pm.addPass(mlir::createSymbolDCEPass());