Skip to content

Commit e55071b

Browse files
authored
[flang][cuda] Extent detection of data transfer with conversion (#163852)
1 parent 527f7f5 commit e55071b

File tree

4 files changed

+46
-11
lines changed

4 files changed

+46
-11
lines changed

flang/include/flang/Lower/CUDA.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ class Location;
2727
class MLIRContext;
2828
} // namespace mlir
2929

30+
namespace hlfir {
31+
class ElementalOp;
32+
} // namespace hlfir
33+
3034
namespace Fortran::lower {
3135

3236
class AbstractConverter;
@@ -58,7 +62,9 @@ cuf::DataAttributeAttr
5862
translateSymbolCUFDataAttribute(mlir::MLIRContext *mlirContext,
5963
const Fortran::semantics::Symbol &sym);
6064

61-
bool isTransferWithConversion(mlir::Value rhs);
65+
/// Check if the rhs has an implicit conversion. Return the elemental op if
66+
/// there is a conversion. Return null otherwise.
67+
hlfir::ElementalOp isTransferWithConversion(mlir::Value rhs);
6268

6369
} // end namespace Fortran::lower
6470

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4987,11 +4987,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
49874987

49884988
// host = device
49894989
if (!lhsIsDevice && rhsIsDevice) {
4990-
if (Fortran::lower::isTransferWithConversion(rhs)) {
4990+
if (auto elementalOp = Fortran::lower::isTransferWithConversion(rhs)) {
49914991
mlir::OpBuilder::InsertionGuard insertionGuard(builder);
4992-
auto elementalOp =
4993-
mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp());
4994-
assert(elementalOp && "expect elemental op");
49954992
auto designateOp =
49964993
*elementalOp.getBody()->getOps<hlfir::DesignateOp>().begin();
49974994
builder.setInsertionPoint(elementalOp);

flang/lib/Lower/CUDA.cpp

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,26 @@ cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
6868
return cuf::getDataAttribute(mlirContext, cudaAttr);
6969
}
7070

71-
bool Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
71+
hlfir::ElementalOp Fortran::lower::isTransferWithConversion(mlir::Value rhs) {
72+
auto isConversionElementalOp = [](hlfir::ElementalOp elOp) {
73+
return llvm::hasSingleElement(
74+
elOp.getBody()->getOps<hlfir::DesignateOp>()) &&
75+
llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 &&
76+
llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) ==
77+
1;
78+
};
79+
if (auto declOp = mlir::dyn_cast<hlfir::DeclareOp>(rhs.getDefiningOp())) {
80+
if (!declOp.getMemref().getDefiningOp())
81+
return {};
82+
if (auto associateOp = mlir::dyn_cast<hlfir::AssociateOp>(
83+
declOp.getMemref().getDefiningOp()))
84+
if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(
85+
associateOp.getSource().getDefiningOp()))
86+
if (isConversionElementalOp(elOp))
87+
return elOp;
88+
}
7289
if (auto elOp = mlir::dyn_cast<hlfir::ElementalOp>(rhs.getDefiningOp()))
73-
if (llvm::hasSingleElement(elOp.getBody()->getOps<hlfir::DesignateOp>()) &&
74-
llvm::hasSingleElement(elOp.getBody()->getOps<fir::LoadOp>()) == 1 &&
75-
llvm::hasSingleElement(elOp.getBody()->getOps<fir::ConvertOp>()) == 1)
76-
return true;
77-
return false;
90+
if (isConversionElementalOp(elOp))
91+
return elOp;
92+
return {};
7893
}

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,3 +542,20 @@ end subroutine
542542
! CHECK-NOT: cuf.data_transfer
543543
! CHECK: hlfir.assign
544544
! CHECK-NOT: cuf.data_transfer
545+
546+
! Data transfer with conversion with more complex elemental
547+
! Check that the data transfer is placed before the elemental op.
548+
subroutine sub29()
549+
real(2), device, allocatable :: a(:)
550+
real(4), allocatable :: ha(:)
551+
allocate(a(10))
552+
allocate(ha(10))
553+
ha = a
554+
deallocate(a)
555+
end subroutine
556+
557+
! CHECK-LABEL: func.func @_QPsub29()
558+
! CHECK: %[[TMP:.*]] = fir.allocmem !fir.array<?xf16>, %24#1 {bindc_name = ".tmp", uniq_name = ""}
559+
! CHECK: %[[TMP_BUFFER:.*]]:2 = hlfir.declare %[[TMP]](%{{.*}}) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<?xf16>>, !fir.shape<1>) -> (!fir.box<!fir.array<?xf16>>, !fir.heap<!fir.array<?xf16>>)
560+
! CHECK: cuf.data_transfer %{{.*}} to %[[TMP_BUFFER]]#0 {transfer_kind = #cuf.cuda_transfer<device_host>} : !fir.box<!fir.heap<!fir.array<?xf16>>>, !fir.box<!fir.array<?xf16>>
561+
! CHECK: hlfir.elemental

0 commit comments

Comments
 (0)