Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion flang/include/flang/Lower/CUDA.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class Location;
class MLIRContext;
} // namespace mlir

namespace hlfir {
class ElementalOp;
} // namespace hlfir

namespace Fortran::lower {

class AbstractConverter;
Expand Down Expand Up @@ -58,7 +62,9 @@ cuf::DataAttributeAttr
translateSymbolCUFDataAttribute(mlir::MLIRContext *mlirContext,
const Fortran::semantics::Symbol &sym);

bool isTransferWithConversion(mlir::Value rhs);
/// Check is the rhs has an implicit conversion. Return the elemental op if
/// there is a conversion. Return null otherwise.
hlfir::ElementalOp isTransferWithConversion(mlir::Value rhs);

} // end namespace Fortran::lower

Expand Down
5 changes: 1 addition & 4 deletions flang/lib/Lower/Bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4987,11 +4987,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {

// host = device
if (!lhsIsDevice && rhsIsDevice) {
if (Fortran::lower::isTransferWithConversion(rhs)) {
if (auto elementalOp = Fortran::lower::isTransferWithConversion(rhs)) {
mlir::OpBuilder::InsertionGuard insertionGuard(builder);
auto elementalOp =
mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp());
assert(elementalOp && "expect elemental op");
auto designateOp =
*elementalOp.getBody()->getOps<hlfir::DesignateOp>().begin();
builder.setInsertionPoint(elementalOp);
Expand Down
27 changes: 21 additions & 6 deletions flang/lib/Lower/CUDA.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,26 @@ cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
return cuf::getDataAttribute(mlirContext, cudaAttr);
}

bool Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
auto isConversionElementalOp = [](hlfir::ElementalOp elOp) {
return llvm::hasSingleElement(
elOp.getBody()->getOps<hlfir::DesignateOp>()) &&
llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 &&
llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) ==
1;
};
if (auto declOp = mlir::dyn_cast<hlfir::DeclareOp>(rhs.getDefiningOp())) {
if (!declOp.getMemref().getDefiningOp())
return {};
if (auto associateOp = mlir::dyn_cast<hlfir::AssociateOp>(
declOp.getMemref().getDefiningOp()))
if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(
associateOp.getSource().getDefiningOp()))
if (isConversionElementalOp(elOp))
return elOp;
}
if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp()))
if (llvm::hasSingleElement(elOp.getBody()->getOps<hlfir::DesignateOp>()) &&
llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 &&
llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) == 1)
return true;
return false;
if (isConversionElementalOp(elOp))
return elOp;
return {};
}
17 changes: 17 additions & 0 deletions flang/test/Lower/CUDA/cuda-data-transfer.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -542,3 +542,20 @@ end subroutine
! CHECK-NOT: cuf.data_transfer
! CHECK: hlfir.assign
! CHECK-NOT: cuf.data_transfer

! Data transfer with conversion with more complex elemental
! Check that the data transfer is palce
subroutine sub29()
real(2), device, allocatable :: a(:)
real(4), allocatable :: ha(:)
allocate(a(10))
allocate(ha(10))
ha = a
deallocate(a)
end subroutine

! CHECK-LABEL: func.func @_QPsub29()
! CHECK: %[[TMP:.*]] = fir.allocmem !fir.array<?xf16>, %24#1 {bindc_name = ".tmp", uniq_name = ""}
! CHECK: %[[TMP_BUFFER:.*]]:2 = hlfir.declare %[[TMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xf16>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf16>>, !fir.heap<!fir.array<?xf16>>)
! CHECK: cuf.data_transfer %{{.*}} to %[[TMP_BUFFER]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.box<!fir.heap<!fir.array<?xf16>>>, !fir.box<!fir.array<?xf16>>
! CHECK: hlfir.elemental
Loading