@@ -478,34 +478,55 @@ class ConvertGpuKernelLaunchPattern
478
478
size, 0 );
479
479
});
480
480
481
- auto getKernelParam = [&](unsigned i) -> mlir::Value {
482
- if (op.operands ()[i].getType ().isa <mlir::MemRefType>()) {
481
+ mlir::Value one = rewriter.create <mlir::LLVM::ConstantOp>(
482
+ loc, llvmInt32Type, rewriter.getI32IntegerAttr (1 ));
483
+ auto localMemStorageClass = gpu_runtime::StorageClassAttr::get (
484
+ getContext (), gpu_runtime::StorageClass::local);
485
+ auto computeTypeSize = [&](mlir::Type type) -> mlir::Value {
486
+ // %Size = getelementptr %T* null, int 1
487
+ // %SizeI = ptrtoint %T* %Size to i32
488
+ auto nullPtr = rewriter.create <mlir::LLVM::NullOp>(loc, type);
489
+ auto gep = rewriter.create <mlir::LLVM::GEPOp>(loc, type, nullPtr, one);
490
+ return rewriter.create <mlir::LLVM::PtrToIntOp>(loc, llvmIndexType, gep);
491
+ };
492
+
493
+ auto getKernelParam =
494
+ [&](unsigned i) -> std::pair<mlir::Value, mlir::Value> {
495
+ auto memrefType = op.operands ()[i].getType ().dyn_cast <mlir::MemRefType>();
496
+ auto paramType = paramsStorage[i].getType ();
497
+ if (memrefType) {
483
498
mlir::MemRefDescriptor desc (kernelParams[i]);
484
- return desc.alignedPtr (rewriter, loc);
499
+ if (memrefType.getMemorySpace () == localMemStorageClass) {
500
+ auto rank = static_cast <unsigned >(memrefType.getRank ());
501
+ mlir::Value size = rewriter.create <mlir::LLVM::ConstantOp>(
502
+ loc, llvmIndexType, rewriter.getIntegerAttr (llvmIndexType, 0 ));
503
+ for (auto i : llvm::seq (0u , rank)) {
504
+ auto dim = desc.size (rewriter, loc, i);
505
+ size = rewriter.create <mlir::LLVM::MulOp>(loc, llvmIndexType, size,
506
+ dim);
507
+ }
508
+ auto null = rewriter.create <mlir::LLVM::NullOp>(
509
+ loc, desc.getElementPtrType ());
510
+ return {size, null};
511
+ }
512
+ auto size = computeTypeSize (paramType);
513
+ return {size, desc.alignedPtr (rewriter, loc)};
485
514
}
486
515
487
- return kernelParams[i];
516
+ auto size = computeTypeSize (paramType);
517
+ return {size, kernelParams[i]};
488
518
};
489
519
490
520
mlir::Value paramsArray =
491
521
rewriter.create <mlir::LLVM::UndefOp>(loc, paramsArrayType);
492
- auto one = rewriter
493
- .create <mlir::LLVM::ConstantOp>(
494
- loc, llvmInt32Type, rewriter.getI32IntegerAttr (1 ))
495
- .getResult ();
522
+
496
523
for (auto i : llvm::seq (0u , paramsCount)) {
497
- rewriter. create <mlir::LLVM::StoreOp>(loc, getKernelParam (i),
498
- paramsStorage[i]);
524
+ auto param = getKernelParam (i);
525
+ rewriter. create <mlir::LLVM::StoreOp>(loc, param. second , paramsStorage[i]);
499
526
auto ptr = rewriter.create <mlir::LLVM::BitcastOp>(loc, llvmPointerType,
500
527
paramsStorage[i]);
501
- // %Size = getelementptr %T* null, int 1
502
- // %SizeI = ptrtoint %T* %Size to i32
503
- auto paramPtrType = paramsStorage[i].getType ();
504
- auto nullPtr = rewriter.create <mlir::LLVM::NullOp>(loc, paramPtrType);
505
- auto gep =
506
- rewriter.create <mlir::LLVM::GEPOp>(loc, paramPtrType, nullPtr, one);
507
- auto typeSize =
508
- rewriter.create <mlir::LLVM::PtrToIntOp>(loc, llvmIndexType, gep);
528
+
529
+ auto typeSize = param.first ;
509
530
510
531
mlir::Value range =
511
532
rewriter.create <mlir::LLVM::UndefOp>(loc, llvmRangeType);
0 commit comments