Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions flang/include/flang/Optimizer/Transforms/CUFOpConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ class SymbolTable;

namespace cuf {

/// Patterns that convert CUF operations to runtime calls.
void populateCUFToFIRConversionPatterns(const fir::LLVMTypeConverter &converter,
mlir::DataLayout &dl,
const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns);

/// Patterns that updates fir operations in presence of CUF.
void populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns);

} // namespace cuf

#endif // FORTRAN_OPTIMIZER_TRANSFORMS_CUFOPCONVERSION_H_
150 changes: 90 additions & 60 deletions flang/lib/Optimizer/Transforms/CUFOpConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,15 @@ static bool hasDoubleDescriptors(OpTy op) {
return false;
}

bool isDeviceGlobal(fir::GlobalOp op) {
auto attr = op.getDataAttr();
if (attr && (*attr == cuf::DataAttribute::Device ||
*attr == cuf::DataAttribute::Managed ||
*attr == cuf::DataAttribute::Constant))
return true;
return false;
}

static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
mlir::Location loc, mlir::Type toTy,
mlir::Value val) {
Expand All @@ -89,62 +98,6 @@ static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter,
return val;
}

mlir::Value getDeviceAddress(mlir::PatternRewriter &rewriter,
mlir::OpOperand &operand,
const mlir::SymbolTable &symtab) {
mlir::Value v = operand.get();
auto declareOp = v.getDefiningOp<fir::DeclareOp>();
if (!declareOp)
return v;

auto addrOfOp = declareOp.getMemref().getDefiningOp<fir::AddrOfOp>();
if (!addrOfOp)
return v;

auto globalOp = symtab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getRootReference().getValue());

if (!globalOp)
return v;

bool isDevGlobal{false};
auto attr = globalOp.getDataAttrAttr();
if (attr) {
switch (attr.getValue()) {
case cuf::DataAttribute::Device:
case cuf::DataAttribute::Managed:
case cuf::DataAttribute::Constant:
isDevGlobal = true;
break;
default:
break;
}
}
if (!isDevGlobal)
return v;
mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(operand.getOwner());
auto loc = declareOp.getLoc();
auto mod = declareOp->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);

mlir::func::FuncOp callee =
fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(loc, builder);
auto fTy = callee.getFunctionType();
auto toTy = fTy.getInput(0);
mlir::Value inputArg =
createConvertOp(rewriter, loc, toTy, declareOp.getResult());
mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, inputArg, sourceFile, sourceLine)};
auto call = rewriter.create<fir::CallOp>(loc, callee, args);
mlir::Value cast = createConvertOp(
rewriter, loc, declareOp.getMemref().getType(), call->getResult(0));
return cast;
}

template <typename OpTy>
static mlir::LogicalResult convertOpToCall(OpTy op,
mlir::PatternRewriter &rewriter,
Expand Down Expand Up @@ -422,6 +375,54 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> {
const fir::LLVMTypeConverter *typeConverter;
};

struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> {
using OpRewritePattern::OpRewritePattern;

DeclareOpConversion(mlir::MLIRContext *context,
const mlir::SymbolTable &symtab)
: OpRewritePattern(context), symTab{symtab} {}

mlir::LogicalResult
matchAndRewrite(fir::DeclareOp op,
mlir::PatternRewriter &rewriter) const override {
if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
if (auto global = symTab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getRootReference().getValue())) {
if (isDeviceGlobal(global)) {
rewriter.setInsertionPointAfter(addrOfOp);
auto mod = op->getParentOfType<mlir::ModuleOp>();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Location loc = op.getLoc();
mlir::func::FuncOp callee =
fir::runtime::getRuntimeFunc<mkRTKey(CUFGetDeviceAddress)>(
loc, builder);
auto fTy = callee.getFunctionType();
mlir::Type toTy = fTy.getInput(0);
mlir::Value inputArg =
createConvertOp(rewriter, loc, toTy, addrOfOp.getResult());
mlir::Value sourceFile =
fir::factory::locationToFilename(builder, loc);
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(2));
llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments(
builder, loc, fTy, inputArg, sourceFile, sourceLine)};
auto call = rewriter.create<fir::CallOp>(loc, callee, args);
mlir::Value cast = createConvertOp(
rewriter, loc, op.getMemref().getType(), call->getResult(0));
rewriter.startOpModification(op);
op.getMemrefMutable().assign(cast);
rewriter.finalizeOpModification(op);
return success();
}
}
}
return failure();
}

private:
const mlir::SymbolTable &symTab;
};

struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -511,7 +512,7 @@ static mlir::Value emboxSrc(mlir::PatternRewriter &rewriter,
builder.create<fir::StoreOp>(loc, src, alloc);
addr = alloc;
} else {
addr = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
addr = op.getSrc();
}
llvm::SmallVector<mlir::Value> lenParams;
mlir::Type boxTy = fir::BoxType::get(srcTy);
Expand All @@ -531,7 +532,7 @@ static mlir::Value emboxDst(mlir::PatternRewriter &rewriter,
mlir::Location loc = op.getLoc();
fir::FirOpBuilder builder(rewriter, mod);
mlir::Type dstTy = fir::unwrapRefType(op.getDst().getType());
mlir::Value dstAddr = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
mlir::Value dstAddr = op.getDst();
mlir::Type dstBoxTy = fir::BoxType::get(dstTy);
llvm::SmallVector<mlir::Value> lenParams;
mlir::Value dstBox =
Expand Down Expand Up @@ -652,8 +653,8 @@ struct CUFDataTransferOpConversion
mlir::Value sourceLine =
fir::factory::locationToLineNo(builder, loc, fTy.getInput(5));

mlir::Value dst = getDeviceAddress(rewriter, op.getDstMutable(), symtab);
mlir::Value src = getDeviceAddress(rewriter, op.getSrcMutable(), symtab);
mlir::Value dst = op.getDst();
mlir::Value src = op.getSrc();
// Materialize the src if constant.
if (matchPattern(src.getDefiningOp(), mlir::m_Constant())) {
mlir::Value temp = builder.createTemporary(loc, srcTy);
Expand Down Expand Up @@ -823,6 +824,30 @@ class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> {
"error in CUF op conversion\n");
signalPassFailure();
}

target.addDynamicallyLegalOp<fir::DeclareOp>([&](fir::DeclareOp op) {
if (inDeviceContext(op))
return true;
if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) {
if (auto global = symtab.lookup<fir::GlobalOp>(
addrOfOp.getSymbol().getRootReference().getValue())) {
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType())))
return true;
if (isDeviceGlobal(global))
return false;
}
}
return true;
});

patterns.clear();
cuf::populateFIRCUFConversionPatterns(symtab, patterns);
if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns)))) {
mlir::emitError(mlir::UnknownLoc::get(ctx),
"error in CUF op conversion\n");
signalPassFailure();
}
}
};
} // namespace
Expand All @@ -837,3 +862,8 @@ void cuf::populateCUFToFIRConversionPatterns(
&dl, &converter);
patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab);
}

void cuf::populateFIRCUFConversionPatterns(const mlir::SymbolTable &symtab,
mlir::RewritePatternSet &patterns) {
patterns.insert<DeclareOpConversion>(patterns.getContext(), symtab);
}
29 changes: 16 additions & 13 deletions flang/test/Fir/CUDA/cuda-data-transfer.fir
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,12 @@ func.func @_QPsub8() attributes {fir.bindc_name = "t"} {
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC_CONV:.*]] = fir.convert %[[SRC]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]]
// CHECK: %[[DST:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.convert %[[SRC_CONV]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none


Expand All @@ -223,11 +223,11 @@ func.func @_QPsub9() {
// CHECK: %[[ALLOCA:.*]] = fir.alloca !fir.array<5xi32>
// CHECK: %[[LOCAL:.*]] = fir.declare %[[ALLOCA]]
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmtestsEn) : !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[GBL]]
// CHECK: %[[HOST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[DST:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[HOST]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[DST_CONV:.*]] = fir.convert %[[DST]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DST:.*]] = fir.convert %[[DST_CONV]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<5xi32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]]
// CHECK: %[[DST:.*]] = fir.convert %[[DECL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[SRC:.*]] = fir.convert %[[LOCAL]] : (!fir.ref<!fir.array<5xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: fir.call @_FortranACUFDataTransferPtrPtr(%[[DST]], %[[SRC]], %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.llvm_ptr<i8>, i64, i32, !fir.ref<i8>, i32) -> none

Expand Down Expand Up @@ -380,9 +380,12 @@ func.func @_QPdevice_addr_conv() {
}

// CHECK-LABEL: func.func @_QPdevice_addr_conv()
// CHECK: %[[DEV_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%{{.*}}, %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEV_ADDR_CONV:.*]] = fir.convert %[[DEV_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
// CHECK: fir.embox %[[DEV_ADDR_CONV]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
// CHECK: %[[GBL:.*]] = fir.address_of(@_QMmod1Ea_dev) : !fir.ref<!fir.array<4xf32>>
// CHECK: %[[GBL_CONV:.*]] = fir.convert %[[GBL]] : (!fir.ref<!fir.array<4xf32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[GBL_CONV]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[ADDR_CONV:.*]] = fir.convert %[[ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<4xf32>>
// CHECK: %[[DECL:.*]] = fir.declare %[[ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Ea_dev"} : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.ref<!fir.array<4xf32>>
// CHECK: fir.embox %[[DECL]](%{{.*}}) : (!fir.ref<!fir.array<4xf32>>, !fir.shape<1>) -> !fir.box<!fir.array<4xf32>>
// CHECK: fir.call @_FortranACUFDataTransferCstDesc

func.func @_QQchar_transfer() attributes {fir.bindc_name = "char_transfer"} {
Expand Down
34 changes: 34 additions & 0 deletions flang/test/Fir/CUDA/cuda-global-addr.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
// RUN: fir-opt --cuf-convert %s | FileCheck %s

module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<f80, dense<128> : vector<2xi64>>, #dlti.dl_entry<i128, dense<128> : vector<2xi64>>, #dlti.dl_entry<i64, dense<64> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr<272>, dense<64> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<271>, dense<32> : vector<4xi64>>, #dlti.dl_entry<!llvm.ptr<270>, dense<32> : vector<4xi64>>, #dlti.dl_entry<f128, dense<128> : vector<2xi64>>, #dlti.dl_entry<f64, dense<64> : vector<2xi64>>, #dlti.dl_entry<f16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i32, dense<32> : vector<2xi64>>, #dlti.dl_entry<i16, dense<16> : vector<2xi64>>, #dlti.dl_entry<i8, dense<8> : vector<2xi64>>, #dlti.dl_entry<i1, dense<8> : vector<2xi64>>, #dlti.dl_entry<!llvm.ptr, dense<64> : vector<4xi64>>, #dlti.dl_entry<"dlti.endianness", "little">, #dlti.dl_entry<"dlti.stack_alignment", 128 : i64>>} {
fir.global @_QMmod1Eadev {data_attr = #cuf.cuda<device>} : !fir.array<10xi32> {
%0 = fir.zero_bits !fir.array<10xi32>
fir.has_value %0 : !fir.array<10xi32>
}
func.func @_QQmain() attributes {fir.bindc_name = "test"} {
%c14_i32 = arith.constant 14 : i32
%c6_i32 = arith.constant 6 : i32
%c4 = arith.constant 4 : index
%c1_i32 = arith.constant 1 : i32
%c0_i32 = arith.constant 0 : i32
%c10 = arith.constant 10 : index
%1 = fir.shape %c10 : (index) -> !fir.shape<1>
%3 = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
%4 = fir.declare %3(%1) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>
%5 = fir.alloca i32 {bindc_name = "i", uniq_name = "_QFEi"}
%6 = fir.declare %5 {uniq_name = "_QFEi"} : (!fir.ref<i32>) -> !fir.ref<i32>
fir.store %c0_i32 to %6 : !fir.ref<i32>
%7 = fir.array_coor %4(%1) %c4 : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>, index) -> !fir.ref<i32>
cuf.data_transfer %c1_i32 to %7 {transfer_kind = #cuf.cuda_transfer<host_device>} : i32, !fir.ref<i32>
return
}

}

// CHECK-LABEL: func.func @_QQmain()
// CHECK: %[[ADDR:.*]] = fir.address_of(@_QMmod1Eadev) : !fir.ref<!fir.array<10xi32>>
// CHECK: %[[ADDRPTR:.*]] = fir.convert %[[ADDR]] : (!fir.ref<!fir.array<10xi32>>) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEVICE_ADDR:.*]] = fir.call @_FortranACUFGetDeviceAddress(%[[ADDRPTR]], %{{.*}}, %{{.*}}) : (!fir.llvm_ptr<i8>, !fir.ref<i8>, i32) -> !fir.llvm_ptr<i8>
// CHECK: %[[DEVICE_ADDR_CONV:.*]] = fir.convert %[[DEVICE_ADDR]] : (!fir.llvm_ptr<i8>) -> !fir.ref<!fir.array<10xi32>>
// CHECK: %{{.*}} = fir.declare %[[DEVICE_ADDR_CONV]](%{{.*}}) {data_attr = #cuf.cuda<device>, uniq_name = "_QMmod1Eadev"} : (!fir.ref<!fir.array<10xi32>>, !fir.shape<1>) -> !fir.ref<!fir.array<10xi32>>

Loading