From d35259cddf5f8c2953cf1815288a501665578042 Mon Sep 17 00:00:00 2001 From: Renaud-K Date: Fri, 18 Oct 2024 14:48:35 -0700 Subject: [PATCH] call CUFGetDeviceAddress to get global device address from host address --- .../Optimizer/Transforms/CufOpConversion.h | 2 + .../Optimizer/Transforms/CufOpConversion.cpp | 96 ++++++++++++++++--- flang/test/Fir/CUDA/cuda-data-transfer.fir | 43 +++++++++ 3 files changed, 126 insertions(+), 15 deletions(-) diff --git a/flang/include/flang/Optimizer/Transforms/CufOpConversion.h b/flang/include/flang/Optimizer/Transforms/CufOpConversion.h index 79ce4ac5c6cbc..0a71cdfddec1a 100644 --- a/flang/include/flang/Optimizer/Transforms/CufOpConversion.h +++ b/flang/include/flang/Optimizer/Transforms/CufOpConversion.h @@ -18,12 +18,14 @@ class LLVMTypeConverter; namespace mlir { class DataLayout; +class SymbolTable; } namespace cuf { void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl, + const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns); } // namespace cuf diff --git a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp index 91ef1259332de..9df559ee0ab1f 100644 --- a/flang/lib/Optimizer/Transforms/CufOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CufOpConversion.cpp @@ -77,6 +77,69 @@ static bool hasDoubleDescriptors(OpTy op) { return false; } +static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, + mlir::Location loc, mlir::Type toTy, + mlir::Value val) { + if (val.getType() != toTy) + return rewriter.create(loc, toTy, val); + return val; +} + +mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter, + mlir::OpOperand &operand, + const mlir::SymbolTable &symtab) { + mlir::Value v = operand.get(); + auto declareOp = v.getDefiningOp(); + if (!declareOp) + return v; + + auto addrOfOp = declareOp.getMemref().getDefiningOp(); + if (!addrOfOp) + return v; + + auto globalOp = symtab.lookup( + addrOfOp.getSymbol().getRootReference().getValue()); + + if (!globalOp) + return v; + + bool isDevGlobal{false}; + auto attr = globalOp.getDataAttrAttr(); + if (attr) { + switch (attr.getValue()) { + case cuf::DataAttribute::Device: + case cuf::DataAttribute::Managed: + case cuf::DataAttribute::Pinned: + isDevGlobal = true; + break; + default: + break; + } + } + if (!isDevGlobal) + return v; + mlir::OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(operand.getOwner()); + auto loc = declareOp.getLoc(); + auto mod = declareOp->getParentOfType(); + fir::FirOpBuilder builder(rewriter, mod); + + mlir::func::FuncOp callee = + fir::runtime::getRuntimeFunc(loc, builder); + auto fTy = callee.getFunctionType(); + auto toTy = fTy.getInput(0); + mlir::Value inputArg = + createConvertOp(rewriter, loc, toTy, declareOp.getResult()); + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); + llvm::SmallVector args{fir::runtime::createArguments( + builder, loc, fTy, inputArg, sourceFile, sourceLine)}; + auto call = rewriter.create(loc, callee, args); + + return call->getResult(0); +} + template static mlir::LogicalResult convertOpToCall(OpTy op, mlir::PatternRewriter &rewriter, @@ -363,18 +426,14 @@ struct CufFreeOpConversion : public mlir::OpRewritePattern { } }; -static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, - mlir::Location loc, mlir::Type toTy, - mlir::Value val) { - if (val.getType() != toTy) - return rewriter.create(loc, toTy, val); - return val; -} - struct CufDataTransferOpConversion : public mlir::OpRewritePattern { using OpRewritePattern::OpRewritePattern; + CufDataTransferOpConversion(mlir::MLIRContext *context, + const mlir::SymbolTable &symtab) + : OpRewritePattern(context), symtab{symtab} {} + mlir::LogicalResult matchAndRewrite(cuf::DataTransferOp op, mlir::PatternRewriter &rewriter) const override { @@ -445,9 +504,11 @@ struct CufDataTransferOpConversion mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(5)); - llvm::SmallVector args{fir::runtime::createArguments( - builder, loc, fTy, op.getDst(), op.getSrc(), bytes, modeValue, - sourceFile, sourceLine)}; + mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab); + mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab); + llvm::SmallVector args{ + fir::runtime::createArguments(builder, loc, fTy, dst, src, bytes, + modeValue, sourceFile, sourceLine)}; builder.create(loc, func, args); rewriter.eraseOp(op); return mlir::success(); @@ -552,6 +613,9 @@ struct CufDataTransferOpConversion } return mlir::success(); } + +private: + const mlir::SymbolTable &symtab; }; class CufOpConversion : public fir::impl::CufOpConversionBase { @@ -565,13 +629,15 @@ class CufOpConversion : public fir::impl::CufOpConversionBase { mlir::ModuleOp module = mlir::dyn_cast(op); if (!module) return signalPassFailure(); + mlir::SymbolTable symtab(module); std::optional dl = fir::support::getOrSetDataLayout(module, /*allowDefaultLayout=*/false); fir::LLVMTypeConverter typeConverter(module, /*applyTBAA=*/false, /*forceUnifiedTBAATree=*/false, *dl); target.addLegalDialect(); - cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, patterns); + cuf::populateCUFToFIRConversionPatterns(typeConverter, *dl, symtab, + patterns); if (mlir::failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) { mlir::emitError(mlir::UnknownLoc::get(ctx), @@ -584,9 +650,9 @@ class CufOpConversion : public fir::impl::CufOpConversionBase { void cuf::populateCUFToFIRConversionPatterns( const fir::LLVMTypeConverter &converter, mlir::DataLayout &dl, - mlir::RewritePatternSet &patterns) { + const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) { patterns.insert(patterns.getContext(), &dl, &converter); patterns.insert( - patterns.getContext()); + CufFreeOpConversion>(patterns.getContext()); + patterns.insert(patterns.getContext(), symtab); } diff --git a/flang/test/Fir/CUDA/cuda-data-transfer.fir b/flang/test/Fir/CUDA/cuda-data-transfer.fir index ed894aed5534a..c33c50115b9fc 100644 --- a/flang/test/Fir/CUDA/cuda-data-transfer.fir +++ b/flang/test/Fir/CUDA/cuda-data-transfer.fir @@ -189,4 +189,47 @@ func.func @_QPsub7() { // CHECK: %[[SRC:.*]] = fir.convert %[[IHOST]]#0 : (!fir.ref>) -> !fir.llvm_ptr // CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %[[BYTES]], %c0{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none +fir.global @_QMmtestsEn(dense<[3, 4, 5, 6, 7]> : tensor<5xi32>) {data_attr = #cuf.cuda} : !fir.array<5xi32> +func.func @_QPsub8() attributes {fir.bindc_name = "t"} { + %c5 = arith.constant 5 : index + %0 = fir.alloca !fir.array<5xi32> {bindc_name = "m", uniq_name = "_QFEm"} + %1 = fir.shape %c5 : (index) -> !fir.shape<1> + %2 = fir.declare %0(%1) {uniq_name = "_QFEm"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> + %3 = fir.address_of(@_QMmtestsEn) : !fir.ref> + %4 = fir.declare %3(%1) {data_attr = #cuf.cuda, uniq_name = "_QMmtestsEn"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> + cuf.data_transfer %4 to %2 {transfer_kind = #cuf.cuda_transfer} : !fir.ref>, !fir.ref> + return +} + +// CHECK-LABEL: func.func @_QPsub8() +// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32> +// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]] +// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref> +// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]] +// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none + + +func.func @_QPsub9() { + %c5 = arith.constant 5 : index + %0 = fir.alloca !fir.array<5xi32> {bindc_name = "m", uniq_name = "_QFtest9Em"} + %1 = fir.shape %c5 : (index) -> !fir.shape<1> + %2 = fir.declare %0(%1) {uniq_name = "_QFtest9Em"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> + %3 = fir.address_of(@_QMmtestsEn) : !fir.ref> + %4 = fir.declare %3(%1) {data_attr = #cuf.cuda, uniq_name = "_QMmtestsEn"} : (!fir.ref>, !fir.shape<1>) -> !fir.ref> + cuf.data_transfer %2 to %4 {transfer_kind = #cuf.cuda_transfer} : !fir.ref>, !fir.ref> + return +} + +// CHECK-LABEL: func.func @_QPsub9() +// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32> +// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]] +// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref> +// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]] +// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.ref, i32) -> !fir.llvm_ptr +// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref>) -> !fir.llvm_ptr +// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr, !fir.llvm_ptr, i64, i32, !fir.ref, i32) -> none } // end of module