@@ -76,11 +76,6 @@ struct GPULaunchKernelConversion
7676 mlir::LogicalResult
7777 matchAndRewrite (mlir::gpu::LaunchFuncOp op, OpAdaptor adaptor,
7878 mlir::ConversionPatternRewriter &rewriter) const override {
79-
80- if (op.hasClusterSize ()) {
81- return mlir::failure ();
82- }
83-
8479 mlir::Location loc = op.getLoc ();
8580 auto *ctx = rewriter.getContext ();
8681 mlir::ModuleOp mod = op->getParentOfType <mlir::ModuleOp>();
@@ -107,37 +102,65 @@ struct GPULaunchKernelConversion
107102 rewriter.create <LLVM::AddressOfOp>(loc, ptrTy, kernel.getName ());
108103 }
109104
110- auto funcOp = mod.lookupSymbol <mlir::LLVM::LLVMFuncOp>(
111- RTNAME_STRING (CUFLaunchKernel));
112-
113105 auto llvmIntPtrType = mlir::IntegerType::get (
114106 ctx, this ->getTypeConverter ()->getPointerBitwidth (0 ));
115107 auto voidTy = mlir::LLVM::LLVMVoidType::get (ctx);
116- auto funcTy = mlir::LLVM::LLVMFunctionType::get (
117- voidTy,
118- {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
119- llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
120- /* isVarArg=*/ false );
121-
122- auto cufLaunchKernel = mlir::SymbolRefAttr::get (
123- mod.getContext (), RTNAME_STRING (CUFLaunchKernel));
124- if (!funcOp) {
125- mlir::OpBuilder::InsertionGuard insertGuard (rewriter);
126- rewriter.setInsertionPointToStart (mod.getBody ());
127- auto launchKernelFuncOp = rewriter.create <mlir::LLVM::LLVMFuncOp>(
128- loc, RTNAME_STRING (CUFLaunchKernel), funcTy);
129- launchKernelFuncOp.setVisibility (mlir::SymbolTable::Visibility::Private);
130- }
131108
132109 mlir::Value nullPtr = rewriter.create <LLVM::ZeroOp>(loc, ptrTy);
133110
134- rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
135- op, funcTy, cufLaunchKernel,
136- mlir::ValueRange{kernelPtr, adaptor.getGridSizeX (),
137- adaptor.getGridSizeY (), adaptor.getGridSizeZ (),
138- adaptor.getBlockSizeX (), adaptor.getBlockSizeY (),
139- adaptor.getBlockSizeZ (), dynamicMemorySize, kernelArgs,
140- nullPtr});
111+ if (op.hasClusterSize ()) {
112+ auto funcOp = mod.lookupSymbol <mlir::LLVM::LLVMFuncOp>(
113+ RTNAME_STRING (CUFLaunchClusterKernel));
114+ auto funcTy = mlir::LLVM::LLVMFunctionType::get (
115+ voidTy,
116+ {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
117+ llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
118+ llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
119+ /* isVarArg=*/ false );
120+ auto cufLaunchClusterKernel = mlir::SymbolRefAttr::get (
121+ mod.getContext (), RTNAME_STRING (CUFLaunchClusterKernel));
122+ if (!funcOp) {
123+ mlir::OpBuilder::InsertionGuard insertGuard (rewriter);
124+ rewriter.setInsertionPointToStart (mod.getBody ());
125+ auto launchKernelFuncOp = rewriter.create <mlir::LLVM::LLVMFuncOp>(
126+ loc, RTNAME_STRING (CUFLaunchClusterKernel), funcTy);
127+ launchKernelFuncOp.setVisibility (
128+ mlir::SymbolTable::Visibility::Private);
129+ }
130+ rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
131+ op, funcTy, cufLaunchClusterKernel,
132+ mlir::ValueRange{kernelPtr, adaptor.getClusterSizeX (),
133+ adaptor.getClusterSizeY (), adaptor.getClusterSizeZ (),
134+ adaptor.getGridSizeX (), adaptor.getGridSizeY (),
135+ adaptor.getGridSizeZ (), adaptor.getBlockSizeX (),
136+ adaptor.getBlockSizeY (), adaptor.getBlockSizeZ (),
137+ dynamicMemorySize, kernelArgs, nullPtr});
138+ } else {
139+ auto funcOp = mod.lookupSymbol <mlir::LLVM::LLVMFuncOp>(
140+ RTNAME_STRING (CUFLaunchKernel));
141+ auto funcTy = mlir::LLVM::LLVMFunctionType::get (
142+ voidTy,
143+ {ptrTy, llvmIntPtrType, llvmIntPtrType, llvmIntPtrType,
144+ llvmIntPtrType, llvmIntPtrType, llvmIntPtrType, i32Ty, ptrTy, ptrTy},
145+ /* isVarArg=*/ false );
146+ auto cufLaunchKernel = mlir::SymbolRefAttr::get (
147+ mod.getContext (), RTNAME_STRING (CUFLaunchKernel));
148+ if (!funcOp) {
149+ mlir::OpBuilder::InsertionGuard insertGuard (rewriter);
150+ rewriter.setInsertionPointToStart (mod.getBody ());
151+ auto launchKernelFuncOp = rewriter.create <mlir::LLVM::LLVMFuncOp>(
152+ loc, RTNAME_STRING (CUFLaunchKernel), funcTy);
153+ launchKernelFuncOp.setVisibility (
154+ mlir::SymbolTable::Visibility::Private);
155+ }
156+ rewriter.replaceOpWithNewOp <mlir::LLVM::CallOp>(
157+ op, funcTy, cufLaunchKernel,
158+ mlir::ValueRange{kernelPtr, adaptor.getGridSizeX (),
159+ adaptor.getGridSizeY (), adaptor.getGridSizeZ (),
160+ adaptor.getBlockSizeX (), adaptor.getBlockSizeY (),
161+ adaptor.getBlockSizeZ (), dynamicMemorySize,
162+ kernelArgs, nullPtr});
163+ }
141164
142165 return mlir::success ();
143166 }
0 commit comments