Skip to content

Commit 1637cca

Browse files
authored
[gpu] Refactor ConvertGpuKernelLaunchPattern and some preparations for dynamic size gpu local memory (#206)
1 parent fbe474d commit 1637cca

File tree

1 file changed

+39
-18
lines changed

1 file changed

+39
-18
lines changed

mlir/lib/Conversion/gpu_runtime_to_llvm.cpp

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -478,34 +478,55 @@ class ConvertGpuKernelLaunchPattern
478478
size, 0);
479479
});
480480

481-
auto getKernelParam = [&](unsigned i) -> mlir::Value {
482-
if (op.operands()[i].getType().isa<mlir::MemRefType>()) {
481+
mlir::Value one = rewriter.create<mlir::LLVM::ConstantOp>(
482+
loc, llvmInt32Type, rewriter.getI32IntegerAttr(1));
483+
auto localMemStorageClass = gpu_runtime::StorageClassAttr::get(
484+
getContext(), gpu_runtime::StorageClass::local);
485+
auto computeTypeSize = [&](mlir::Type type) -> mlir::Value {
486+
// %Size = getelementptr %T* null, int 1
487+
// %SizeI = ptrtoint %T* %Size to i32
488+
auto nullPtr = rewriter.create<mlir::LLVM::NullOp>(loc, type);
489+
auto gep = rewriter.create<mlir::LLVM::GEPOp>(loc, type, nullPtr, one);
490+
return rewriter.create<mlir::LLVM::PtrToIntOp>(loc, llvmIndexType, gep);
491+
};
492+
493+
auto getKernelParam =
494+
[&](unsigned i) -> std::pair<mlir::Value, mlir::Value> {
495+
auto memrefType = op.operands()[i].getType().dyn_cast<mlir::MemRefType>();
496+
auto paramType = paramsStorage[i].getType();
497+
if (memrefType) {
483498
mlir::MemRefDescriptor desc(kernelParams[i]);
484-
return desc.alignedPtr(rewriter, loc);
499+
if (memrefType.getMemorySpace() == localMemStorageClass) {
500+
auto rank = static_cast<unsigned>(memrefType.getRank());
501+
mlir::Value size = rewriter.create<mlir::LLVM::ConstantOp>(
502+
loc, llvmIndexType, rewriter.getIntegerAttr(llvmIndexType, 0));
503+
for (auto i : llvm::seq(0u, rank)) {
504+
auto dim = desc.size(rewriter, loc, i);
505+
size = rewriter.create<mlir::LLVM::MulOp>(loc, llvmIndexType, size,
506+
dim);
507+
}
508+
auto null = rewriter.create<mlir::LLVM::NullOp>(
509+
loc, desc.getElementPtrType());
510+
return {size, null};
511+
}
512+
auto size = computeTypeSize(paramType);
513+
return {size, desc.alignedPtr(rewriter, loc)};
485514
}
486515

487-
return kernelParams[i];
516+
auto size = computeTypeSize(paramType);
517+
return {size, kernelParams[i]};
488518
};
489519

490520
mlir::Value paramsArray =
491521
rewriter.create<mlir::LLVM::UndefOp>(loc, paramsArrayType);
492-
auto one = rewriter
493-
.create<mlir::LLVM::ConstantOp>(
494-
loc, llvmInt32Type, rewriter.getI32IntegerAttr(1))
495-
.getResult();
522+
496523
for (auto i : llvm::seq(0u, paramsCount)) {
497-
rewriter.create<mlir::LLVM::StoreOp>(loc, getKernelParam(i),
498-
paramsStorage[i]);
524+
auto param = getKernelParam(i);
525+
rewriter.create<mlir::LLVM::StoreOp>(loc, param.second, paramsStorage[i]);
499526
auto ptr = rewriter.create<mlir::LLVM::BitcastOp>(loc, llvmPointerType,
500527
paramsStorage[i]);
501-
// %Size = getelementptr %T* null, int 1
502-
// %SizeI = ptrtoint %T* %Size to i32
503-
auto paramPtrType = paramsStorage[i].getType();
504-
auto nullPtr = rewriter.create<mlir::LLVM::NullOp>(loc, paramPtrType);
505-
auto gep =
506-
rewriter.create<mlir::LLVM::GEPOp>(loc, paramPtrType, nullPtr, one);
507-
auto typeSize =
508-
rewriter.create<mlir::LLVM::PtrToIntOp>(loc, llvmIndexType, gep);
528+
529+
auto typeSize = param.first;
509530

510531
mlir::Value range =
511532
rewriter.create<mlir::LLVM::UndefOp>(loc, llvmRangeType);

0 commit comments

Comments
 (0)