77// ===----------------------------------------------------------------------===//
88
99#include " flang/Optimizer/Transforms/CUFGPUToLLVMConversion.h"
10+ #include " flang/Optimizer/Builder/CUFCommon.h"
1011#include " flang/Optimizer/CodeGen/TypeConverter.h"
12+ #include " flang/Optimizer/Dialect/CUF/CUFOps.h"
1113#include " flang/Optimizer/Support/DataLayout.h"
1214#include " flang/Runtime/CUDA/common.h"
1315#include " flang/Support/Fortran.h"
1416#include " mlir/Conversion/LLVMCommon/Pattern.h"
1517#include " mlir/Dialect/GPU/IR/GPUDialect.h"
18+ #include " mlir/Dialect/LLVMIR/NVVMDialect.h"
1619#include " mlir/Pass/Pass.h"
1720#include " mlir/Transforms/DialectConversion.h"
1821#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -175,6 +178,69 @@ struct GPULaunchKernelConversion
175178 }
176179};
177180
181+ static std::string getFuncName (cuf::SharedMemoryOp op) {
182+ if (auto gpuFuncOp = op->getParentOfType <mlir::gpu::GPUFuncOp>())
183+ return gpuFuncOp.getName ().str ();
184+ if (auto funcOp = op->getParentOfType <mlir::func::FuncOp>())
185+ return funcOp.getName ().str ();
186+ if (auto llvmFuncOp = op->getParentOfType <mlir::LLVM::LLVMFuncOp>())
187+ return llvmFuncOp.getSymName ().str ();
188+ return " " ;
189+ }
190+
191+ static mlir::Value createAddressOfOp (mlir::ConversionPatternRewriter &rewriter,
192+ mlir::Location loc,
193+ gpu::GPUModuleOp gpuMod,
194+ std::string &sharedGlobalName) {
195+ auto llvmPtrTy = mlir::LLVM::LLVMPointerType::get (
196+ rewriter.getContext (), mlir::NVVM::NVVMMemorySpace::kSharedMemorySpace );
197+ if (auto g = gpuMod.lookupSymbol <fir::GlobalOp>(sharedGlobalName))
198+ return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
199+ g.getSymName ());
200+ if (auto g = gpuMod.lookupSymbol <mlir::LLVM::GlobalOp>(sharedGlobalName))
201+ return rewriter.create <mlir::LLVM::AddressOfOp>(loc, llvmPtrTy,
202+ g.getSymName ());
203+ return {};
204+ }
205+
206+ struct CUFSharedMemoryOpConversion
207+ : public mlir::ConvertOpToLLVMPattern<cuf::SharedMemoryOp> {
208+ explicit CUFSharedMemoryOpConversion (
209+ const fir::LLVMTypeConverter &typeConverter, mlir::PatternBenefit benefit)
210+ : mlir::ConvertOpToLLVMPattern<cuf::SharedMemoryOp>(typeConverter,
211+ benefit) {}
212+ using OpAdaptor = typename cuf::SharedMemoryOp::Adaptor;
213+
214+ mlir::LogicalResult
215+ matchAndRewrite (cuf::SharedMemoryOp op, OpAdaptor adaptor,
216+ mlir::ConversionPatternRewriter &rewriter) const override {
217+ mlir::Location loc = op->getLoc ();
218+ if (!op.getOffset ())
219+ mlir::emitError (loc,
220+ " cuf.shared_memory must have an offset for code gen" );
221+
222+ auto gpuMod = op->getParentOfType <gpu::GPUModuleOp>();
223+ std::string sharedGlobalName =
224+ (getFuncName (op) + llvm::Twine (cudaSharedMemSuffix)).str ();
225+ mlir::Value sharedGlobalAddr =
226+ createAddressOfOp (rewriter, loc, gpuMod, sharedGlobalName);
227+
228+ if (!sharedGlobalAddr)
229+ mlir::emitError (loc, " Could not find the shared global operation\n " );
230+
231+ auto castPtr = rewriter.create <mlir::LLVM::AddrSpaceCastOp>(
232+ loc, mlir::LLVM::LLVMPointerType::get (rewriter.getContext ()),
233+ sharedGlobalAddr);
234+ mlir::Type baseType = castPtr->getResultTypes ().front ();
235+ llvm::SmallVector<mlir::LLVM::GEPArg> gepArgs = {
236+ static_cast <int32_t >(*op.getOffset ())};
237+ mlir::Value shmemPtr = rewriter.create <mlir::LLVM::GEPOp>(
238+ loc, baseType, rewriter.getI8Type (), castPtr, gepArgs);
239+ rewriter.replaceOp (op, {shmemPtr});
240+ return mlir::success ();
241+ }
242+ };
243+
178244class CUFGPUToLLVMConversion
179245 : public fir::impl::CUFGPUToLLVMConversionBase<CUFGPUToLLVMConversion> {
180246public:
@@ -194,6 +260,7 @@ class CUFGPUToLLVMConversion
194260 /* forceUnifiedTBAATree=*/ false , *dl);
195261 cuf::populateCUFGPUToLLVMConversionPatterns (typeConverter, patterns);
196262 target.addIllegalOp <mlir::gpu::LaunchFuncOp>();
263+ target.addIllegalOp <cuf::SharedMemoryOp>();
197264 target.addLegalDialect <mlir::LLVM::LLVMDialect>();
198265 if (mlir::failed (mlir::applyPartialConversion (getOperation (), target,
199266 std::move (patterns)))) {
@@ -208,5 +275,6 @@ class CUFGPUToLLVMConversion
208275void cuf::populateCUFGPUToLLVMConversionPatterns (
209276 const fir::LLVMTypeConverter &converter, mlir::RewritePatternSet &patterns,
210277 mlir::PatternBenefit benefit) {
211- patterns.add <GPULaunchKernelConversion>(converter, benefit);
278+ patterns.add <CUFSharedMemoryOpConversion, GPULaunchKernelConversion>(
279+ converter, benefit);
212280}
0 commit comments