 #include "flang/Optimizer/Dialect/FIROps.h"
 #include "flang/Optimizer/HLFIR/HLFIROps.h"
 #include "flang/Optimizer/Support/DataLayout.h"
+#include "flang/Optimizer/Transforms/CUFCommon.h"
 #include "flang/Runtime/CUDA/allocatable.h"
 #include "flang/Runtime/CUDA/common.h"
 #include "flang/Runtime/CUDA/descriptor.h"
@@ -620,6 +621,69 @@ struct CufDataTransferOpConversion
   const mlir::SymbolTable &symtab;
 };

+struct CUFLaunchOpConversion
+    : public mlir::OpRewritePattern<cuf::KernelLaunchOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+
+  CUFLaunchOpConversion(mlir::MLIRContext *context,
+                        const mlir::SymbolTable &symTab)
+      : OpRewritePattern(context), symTab{symTab} {}
+
+  mlir::LogicalResult
+  matchAndRewrite(cuf::KernelLaunchOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    mlir::Location loc = op.getLoc();
+    auto idxTy = mlir::IndexType::get(op.getContext());
+    auto zero = rewriter.create<mlir::arith::ConstantOp>(
+        loc, rewriter.getIntegerType(32), rewriter.getI32IntegerAttr(0));
+    auto gridSizeX =
+        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridX());
+    auto gridSizeY =
+        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridY());
+    auto gridSizeZ =
+        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getGridZ());
+    auto blockSizeX =
+        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockX());
+    auto blockSizeY =
+        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockY());
+    auto blockSizeZ =
+        rewriter.create<mlir::arith::IndexCastOp>(loc, idxTy, op.getBlockZ());
+    auto kernelName = mlir::SymbolRefAttr::get(
+        rewriter.getStringAttr(cudaDeviceModuleName),
+        {mlir::SymbolRefAttr::get(
+            rewriter.getContext(),
+            op.getCallee().getLeafReference().getValue())});
+    mlir::Value clusterDimX, clusterDimY, clusterDimZ;
+    if (auto funcOp = symTab.lookup<mlir::func::FuncOp>(
+            op.getCallee().getLeafReference())) {
+      if (auto clusterDimsAttr = funcOp->getAttrOfType<cuf::ClusterDimsAttr>(
+              cuf::getClusterDimsAttrName())) {
+        clusterDimX = rewriter.create<mlir::arith::ConstantIndexOp>(
+            loc, clusterDimsAttr.getX().getInt());
+        clusterDimY = rewriter.create<mlir::arith::ConstantIndexOp>(
+            loc, clusterDimsAttr.getY().getInt());
+        clusterDimZ = rewriter.create<mlir::arith::ConstantIndexOp>(
+            loc, clusterDimsAttr.getZ().getInt());
+      }
+    }
+    auto gpuLaunchOp = rewriter.create<mlir::gpu::LaunchFuncOp>(
+        loc, kernelName, mlir::gpu::KernelDim3{gridSizeX, gridSizeY, gridSizeZ},
+        mlir::gpu::KernelDim3{blockSizeX, blockSizeY, blockSizeZ}, zero,
+        op.getArgs());
+    if (clusterDimX && clusterDimY && clusterDimZ) {
+      gpuLaunchOp.getClusterSizeXMutable().assign(clusterDimX);
+      gpuLaunchOp.getClusterSizeYMutable().assign(clusterDimY);
+      gpuLaunchOp.getClusterSizeZMutable().assign(clusterDimZ);
+    }
+    rewriter.replaceOp(op, gpuLaunchOp);
+    return mlir::success();
+  }
+
+private:
+  const mlir::SymbolTable &symTab;
+};
+
 class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
 public:
   void runOnOperation() override {
@@ -637,7 +701,8 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
         fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false);
     fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false,
                                          /*forceUnifiedTBAATree=*/false, *dl);
-    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect>();
+    target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,
+                           mlir::gpu::GPUDialect>();
     cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab,
                                             patterns);
     if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
@@ -656,5 +721,6 @@ void cuf::populateCUFToFIRConversionPatterns(
   patterns.insert<CufAllocOpConversion>(patterns.getContext(), &dl, &converter);
   patterns.insert<CufAllocateOpConversion, CufDeallocateOpConversion,
                   CufFreeOpConversion>(patterns.getContext());
-  patterns.insert<CufDataTransferOpConversion>(patterns.getContext(), symtab);
+  patterns.insert<CufDataTransferOpConversion, CUFLaunchOpConversion>(
+      patterns.getContext(), symtab);
 }