diff --git a/flang/include/flang/Lower/CUDA.h b/flang/include/flang/Lower/CUDA.h
index ab9dde8ad5198..ef7cdc42d72f2 100644
--- a/flang/include/flang/Lower/CUDA.h
+++ b/flang/include/flang/Lower/CUDA.h
@@ -27,6 +27,10 @@ class Location;
 class MLIRContext;
 } // namespace mlir
 
+namespace hlfir {
+class ElementalOp;
+} // namespace hlfir
+
 namespace Fortran::lower {
 
 class AbstractConverter;
@@ -58,7 +62,9 @@ cuf::DataAttributeAttr
 translateSymbolCUFDataAttribute(mlir::MLIRContext *mlirContext,
                                 const Fortran::semantics::Symbol &sym);
 
-bool isTransferWithConversion(mlir::Value rhs);
+/// Check if the rhs has an implicit conversion. Return the elemental op if
+/// there is a conversion. Return null otherwise.
+hlfir::ElementalOp isTransferWithConversion(mlir::Value rhs);
 
 } // end namespace Fortran::lower
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 68adf346fe8c0..525fb0e9997b7 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -4987,11 +4987,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
 
       // host = device
       if (!lhsIsDevice && rhsIsDevice) {
-        if (Fortran::lower::isTransferWithConversion(rhs)) {
+        if (auto elementalOp = Fortran::lower::isTransferWithConversion(rhs)) {
           mlir::OpBuilder::InsertionGuard insertionGuard(builder);
-          auto elementalOp =
-              mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp());
-          assert(elementalOp && "expect elemental op");
           auto designateOp =
               *elementalOp.getBody()->getOps<hlfir::DesignateOp>().begin();
           builder.setInsertionPoint(elementalOp);
diff --git a/flang/lib/Lower/CUDA.cpp b/flang/lib/Lower/CUDA.cpp
index bb4bdee78f97d..9501b0ec60002 100644
--- a/flang/lib/Lower/CUDA.cpp
+++ b/flang/lib/Lower/CUDA.cpp
@@ -68,11 +68,26 @@ cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
   return cuf::getDataAttribute(mlirContext, cudaAttr);
 }
 
-bool Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
+hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
+  auto isConversionElementalOp = [](hlfir::ElementalOp elOp) {
+    return llvm::hasSingleElement(
+               elOp.getBody()->getOps<fir::ConvertOp>()) &&
+           llvm::hasSingleElement(
+               elOp.getBody()->getOps<hlfir::DesignateOp>()) == 1 &&
+           llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1;
+  };
+  if (auto declOp = mlir::dyn_cast<hlfir::DeclareOp>(rhs.getDefiningOp())) {
+    if (!declOp.getMemref().getDefiningOp())
+      return {};
+    if (auto associateOp = mlir::dyn_cast<hlfir::AssociateOp>(
+            declOp.getMemref().getDefiningOp()))
+      if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(
+              associateOp.getSource().getDefiningOp()))
+        if (isConversionElementalOp(elOp))
+          return elOp;
+  }
   if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp()))
-    if (llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) &&
-        llvm::hasSingleElement(elOp.getBody()->getOps<hlfir::DesignateOp>()) == 1 &&
-        llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1)
-      return true;
-  return false;
+    if (isConversionElementalOp(elOp))
+      return elOp;
+  return {};
 }
diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
index d1c8ecca3b019..b0b8d09c0c55b 100644
--- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf
+++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf
@@ -542,3 +542,20 @@ end subroutine
 ! CHECK-NOT: cuf.data_transfer
 ! CHECK: hlfir.assign
 ! CHECK-NOT: cuf.data_transfer
+
+! Data transfer with conversion with more complex elemental
+! Check that the data transfer is placed before the elemental op.
+subroutine sub29()
+  real(2), device, allocatable :: a(:)
+  real(4), allocatable :: ha(:)
+  allocate(a(10))
+  allocate(ha(10))
+  ha = a
+  deallocate(a)
+end subroutine
+
+! CHECK-LABEL: func.func @_QPsub29()
+! CHECK: %[[TMP:.*]] = fir.allocmem !fir.array<?xf16>, %24#1 {bindc_name = ".tmp", uniq_name = ""}
+! CHECK: %[[TMP_BUFFER:.*]]:2 = hlfir.declare %[[TMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xf16>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf16>>, !fir.heap<!fir.array<?xf16>>)
+! CHECK: cuf.data_transfer %{{.*}} to %[[TMP_BUFFER]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.box<!fir.heap<!fir.array<?xf16>>>, !fir.box<!fir.array<?xf16>>
+! CHECK: hlfir.elemental
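
For reference, the conversion pattern isTransferWithConversion matches is an hlfir.elemental whose body is a single designate/load/convert sequence; in the sub29 case above that elemental additionally feeds an hlfir.associate/hlfir.declare chain, which the updated helper now walks back through. Below is a rough, hand-written HLFIR sketch of such an elemental for `ha = a` (real(2) device rhs, real(4) host lhs); the SSA names and the %shape/%a_box values are illustrative, not compiler output:

    %expr = hlfir.elemental %shape : (!fir.shape<1>) -> !hlfir.expr<?xf32> {
    ^bb0(%i: index):
      // Address one element through the device array's descriptor.
      %addr = hlfir.designate %a_box (%i) : (!fir.box<!fir.heap<!fir.array<?xf16>>>, index) -> !fir.ref<f16>
      %val = fir.load %addr : !fir.ref<f16>
      // Implicit real(2) -> real(4) conversion.
      %cvt = fir.convert %val : (f16) -> f32
      hlfir.yield_element %cvt : f32
    }

With the cuf.data_transfer now emitted before this elemental (into the .tmp buffer checked above), the conversion loop can run on the host copy of the device data.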