Skip to content

Commit a47bbbe

Browse files
authored
Sink Global/Local sizes into kernel (#158)
1 parent f2d1eb0 commit a47bbbe

File tree

2 files changed

+112
-2
lines changed

2 files changed

+112
-2
lines changed

numba_dpcomp/numba_dpcomp/mlir/kernel_impl.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -73,8 +73,12 @@ def _get_default_local_size():
7373

7474
# Lowering registered under the name '_get_default_local_size'.
# NOTE(review): this is a diff hunk, so both versions of the body are shown.
# Lines following a `NN-` marker are the removed body; lines following a
# `NN+` marker are the added body.
@registry.register_func('_get_default_local_size', _get_default_local_size)
7575
def _get_default_local_size_impl(builder, *args):
76-
# Removed version: plain Python ints were passed as the `outputs` placeholders
# and the external call result was returned unchanged.
res = (0,0,0)
77-
return builder.external_call('get_default_local_size', inputs=args, outputs=res)
76+
# Added version: the three `outputs` placeholders are zeros cast to the
# builder's index type.
index_type = builder.index
77+
i64 = builder.int64
78+
zero = builder.cast(0, index_type)
79+
res = (zero,zero,zero)
80+
res = builder.external_call('get_default_local_size', inputs=args, outputs=res)
81+
# Each value returned by the external call is cast to int64 before being
# handed back as a tuple.
return tuple(builder.cast(r, i64) for r in res)
7882

7983
@infer_global(_get_default_local_size)
8084
class _GetDefaultLocalSizeId(ConcreteTemplate):

numba_dpcomp/numba_dpcomp/mlir_compiler/lib/pipelines/lower_to_gpu.cpp

Lines changed: 106 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1208,6 +1208,37 @@ class ConvertFunc : public mlir::OpConversionPattern<mlir::FuncOp> {
12081208
}
12091209
};
12101210

1211+
// Conversion pattern parameterized by the source op type (`SourceOp`) and the
// SPIR-V builtin (`builtin`) it is lowered to. Only the declaration lives
// here; matchAndRewrite is defined out of class.
template <typename SourceOp, mlir::spirv::BuiltIn builtin>
1212+
class LaunchConfigConversion : public mlir::OpConversionPattern<SourceOp> {
1213+
public:
1214+
// Registers the pattern with an explicit benefit of 100.
// NOTE(review): presumably so this pattern is preferred over other patterns
// matching the same op — confirm against the rest of the pattern set.
LaunchConfigConversion(mlir::TypeConverter &typeConverter,
1215+
mlir::MLIRContext *context)
1216+
: mlir::OpConversionPattern<SourceOp>(typeConverter, context,
1217+
/*benefit*/ 100) {}
1218+
// using mlir::OpConversionPattern<SourceOp>::OpConversionPattern;
1219+
1220+
mlir::LogicalResult
1221+
matchAndRewrite(SourceOp op, typename SourceOp::Adaptor adaptor,
1222+
mlir::ConversionPatternRewriter &rewriter) const override;
1223+
};
1224+
1225+
// Replaces `op` with a spirv CompositeExtractOp that picks element
// `op.dimension()` out of the value obtained for `builtin` via
// getBuiltinVariableValue. The result uses the SPIR-V type converter's index
// type. `adaptor` is not used. Always returns success.
template <typename SourceOp, mlir::spirv::BuiltIn builtin>
1226+
mlir::LogicalResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
1227+
SourceOp op, typename SourceOp::Adaptor adaptor,
1228+
mlir::ConversionPatternRewriter &rewriter) const {
1229+
auto *typeConverter =
1230+
this->template getTypeConverter<mlir::SPIRVTypeConverter>();
1231+
auto indexType = typeConverter->getIndexType();
1232+
1233+
// SPIR-V invocation builtin variables are a vector of type <3xi32>
1234+
auto spirvBuiltin =
1235+
mlir::spirv::getBuiltinVariableValue(op, builtin, indexType, rewriter);
1236+
// The op's dimension is used as the single extraction index.
rewriter.replaceOpWithNewOp<mlir::spirv::CompositeExtractOp>(
1237+
op, indexType, spirvBuiltin,
1238+
rewriter.getI32ArrayAttr({static_cast<int32_t>(op.dimension())}));
1239+
return mlir::success();
1240+
}
1241+
12111242
struct GPUToSpirvPass
12121243
: public mlir::PassWrapper<GPUToSpirvPass,
12131244
mlir::OperationPass<mlir::ModuleOp>> {
@@ -1262,6 +1293,10 @@ struct GPUToSpirvPass
12621293
ConvertStoreOp, ConvertAtomicOps, ConvertFunc>(typeConverter,
12631294
context);
12641295

1296+
patterns.insert<LaunchConfigConversion<
1297+
mlir::gpu::BlockDimOp, mlir::spirv::BuiltIn::WorkgroupSize>>(
1298+
typeConverter, context);
1299+
12651300
if (failed(
12661301
applyFullConversion(kernelModules, *target, std::move(patterns))))
12671302
return signalPassFailure();
@@ -2986,6 +3021,76 @@ class GpuLaunchSinkOpsPass
29863021
}
29873022
};
29883023

3024+
// Rewrite pattern on gpu LaunchOp. Uses of the launch's grid/block size
// operands that sit inside the launch op are redirected to GridDimOp /
// BlockDimOp results created at the start of the launch body.
struct SinkGpuDims : public mlir::OpRewritePattern<mlir::gpu::LaunchOp> {
3025+
using OpRewritePattern::OpRewritePattern;
3026+
3027+
// Returns failure when no size operand is used inside the launch op,
// success after all such uses have been rewired.
mlir::LogicalResult
3028+
matchAndRewrite(mlir::gpu::LaunchOp op,
3029+
mlir::PatternRewriter &rewriter) const override {
3030+
// Indices 0..2 are grid sizes X/Y/Z, indices 3..5 are block sizes X/Y/Z;
// getDim below relies on this ordering.
const mlir::Value dimArgs[] = {op.gridSizeX(), op.gridSizeY(),
3031+
op.gridSizeZ(), op.blockSizeX(),
3032+
op.blockSizeY(), op.blockSizeZ()};
3033+
// Each entry pairs an operand use with the index into dimArgs it came from.
llvm::SmallVector<std::pair<mlir::OpOperand *, unsigned>> uses;
3034+
for (auto it : llvm::enumerate(dimArgs)) {
3035+
auto i = static_cast<unsigned>(it.index());
3036+
// Record only uses whose owner is nested inside the launch op.
auto addUse = [&](mlir::OpOperand &use) {
3037+
if (op->isProperAncestor(use.getOwner()))
3038+
uses.emplace_back(&use, i);
3039+
};
3040+
auto val = it.value();
3041+
for (auto &use : val.getUses())
3042+
addUse(use);
3043+
3044+
// If the size is produced by an IndexCastOp, also collect uses of the
// cast's input value, under the same dimension index.
if (auto cast = val.getDefiningOp<mlir::arith::IndexCastOp>())
3045+
for (auto &use : cast.getIn().getUses())
3046+
addUse(use);
3047+
}
3048+
3049+
// Nothing inside the launch op refers to the sizes: no rewrite needed.
if (uses.empty())
3050+
return mlir::failure();
3051+
3052+
// Lazily filled cache of the created dim ops, indexed like dimArgs.
std::array<mlir::Value, 6> dims = {}; // TODO: static vector
3053+
3054+
auto loc = op->getLoc();
3055+
// All ops created below go to the start of the first block of the body.
rewriter.setInsertionPointToStart(&op.body().front());
3056+
// Returns the in-kernel value for dimension `i` with the requested `type`.
// The GridDimOp/BlockDimOp is created on first request and then reused; an
// IndexCastOp is created on every call where `type` differs from its type.
auto getDim = [&](unsigned i, mlir::Type type) -> mlir::Value {
3057+
assert(i < dims.size());
3058+
auto dim = dims[i];
3059+
if (!dim) {
3060+
if (i < 3) {
3061+
dim = rewriter.create<mlir::gpu::GridDimOp>(
3062+
loc, static_cast<mlir::gpu::Dimension>(i));
3063+
} else {
3064+
dim = rewriter.create<mlir::gpu::BlockDimOp>(
3065+
loc, static_cast<mlir::gpu::Dimension>(i - 3));
3066+
}
3067+
dims[i] = dim;
3068+
}
3069+
3070+
if (type != dim.getType())
3071+
dim = rewriter.create<mlir::arith::IndexCastOp>(loc, type, dim);
3072+
3073+
return dim;
3074+
};
3075+
3076+
// Rewire each recorded use to the in-kernel value, keeping the type of the
// value it currently refers to.
for (auto it : uses) {
3077+
auto *use = it.first;
3078+
auto dim = it.second;
3079+
auto owner = use->getOwner();
3080+
// The operand change is made inside updateRootInPlace on the owning op.
rewriter.updateRootInPlace(owner, [&]() {
3081+
auto type = use->get().getType();
3082+
auto newVal = getDim(dim, type);
3083+
use->set(newVal);
3084+
});
3085+
}
3086+
3087+
return mlir::success();
3088+
}
3089+
};
3090+
3091+
// Pass type built from plier::RewriteWrapperPass with SinkGpuDims as its
// pattern argument.
// NOTE(review): the meaning of the two `void` template arguments is defined by
// plier::RewriteWrapperPass, which is not visible here — confirm there.
struct SinkGpuDimsPass : public plier::RewriteWrapperPass<SinkGpuDimsPass, void,
3092+
void, SinkGpuDims> {};
3093+
29893094
static void commonOptPasses(mlir::OpPassManager &pm) {
29903095
pm.addPass(plier::createCommonOptsPass());
29913096
pm.addPass(mlir::createCSEPass());
@@ -3018,6 +3123,7 @@ static void populateLowerToGPUPipelineLow(mlir::OpPassManager &pm) {
30183123
commonOptPasses(funcPM);
30193124
funcPM.addPass(std::make_unique<KernelMemrefOpsMovementPass>());
30203125
funcPM.addPass(std::make_unique<GpuLaunchSinkOpsPass>());
3126+
funcPM.addPass(std::make_unique<SinkGpuDimsPass>());
30213127
pm.addPass(mlir::createGpuKernelOutliningPass());
30223128
pm.addPass(mlir::createSymbolDCEPass());
30233129

0 commit comments

Comments
 (0)