Skip to content

Commit 821a434

Browse files
clementvalkrishna2803
authored andcommitted
[flang][cuda] Generate cuf.allocate for descriptor with CUDA components (llvm#152041)
The descriptor for derived-type with CUDA components are allocated in managed memory. The lowering was calling the standard runtime on allocate statement where it should be a `cuf.allocate` operation.
1 parent 927125b commit 821a434

File tree

5 files changed

+39
-5
lines changed

5 files changed

+39
-5
lines changed

flang/include/flang/Semantics/tools.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ inline bool HasCUDAAttr(const Symbol &sym) {
223223
return false;
224224
}
225225

226+
bool HasCUDAComponent(const Symbol &sym);
227+
226228
inline bool IsCUDAShared(const Symbol &sym) {
227229
if (const auto *details{sym.GetUltimate().detailsIf<ObjectEntityDetails>()}) {
228230
if (details->cudaDataAttr() &&

flang/lib/Lower/Allocatable.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -466,7 +466,9 @@ class AllocateStmtHelper {
466466

467467
void genSimpleAllocation(const Allocation &alloc,
468468
const fir::MutableBoxValue &box) {
469-
bool isCudaSymbol = Fortran::semantics::HasCUDAAttr(alloc.getSymbol());
469+
bool isCudaAllocate =
470+
Fortran::semantics::HasCUDAAttr(alloc.getSymbol()) ||
471+
Fortran::semantics::HasCUDAComponent(alloc.getSymbol());
470472
bool isCudaDeviceContext = cuf::isCUDADeviceContext(builder.getRegion());
471473
bool inlineAllocation = !box.isDerived() && !errorManager.hasStatSpec() &&
472474
!alloc.type.IsPolymorphic() &&
@@ -475,7 +477,7 @@ class AllocateStmtHelper {
475477
unsigned allocatorIdx = Fortran::lower::getAllocatorIdx(alloc.getSymbol());
476478

477479
if (inlineAllocation &&
478-
((isCudaSymbol && isCudaDeviceContext) || !isCudaSymbol)) {
480+
((isCudaAllocate && isCudaDeviceContext) || !isCudaAllocate)) {
479481
// Pointers must use PointerAllocate so that their deallocations
480482
// can be validated.
481483
genInlinedAllocation(alloc, box);
@@ -494,7 +496,7 @@ class AllocateStmtHelper {
494496
genSetDeferredLengthParameters(alloc, box);
495497
genAllocateObjectBounds(alloc, box);
496498
mlir::Value stat;
497-
if (!isCudaSymbol) {
499+
if (!isCudaAllocate) {
498500
stat = genRuntimeAllocate(builder, loc, box, errorManager);
499501
setPinnedToFalse();
500502
} else {

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -814,8 +814,10 @@ initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter,
814814
baseTy = boxTy.getEleTy();
815815
baseTy = fir::unwrapRefType(baseTy);
816816

817-
if (mlir::isa<fir::SequenceType>(baseTy))
818-
TODO(loc, "array of derived-type with device component");
817+
if (mlir::isa<fir::SequenceType>(baseTy) &&
818+
(fir::isAllocatableType(fir::getBase(exv).getType()) ||
819+
fir::isPointerType(fir::getBase(exv).getType())))
820+
return; // Allocator index need to be set after allocation.
819821

820822
auto recTy =
821823
mlir::dyn_cast<fir::RecordType>(fir::unwrapSequenceType(baseTy));

flang/lib/Semantics/tools.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,21 @@ bool IsDeviceAllocatable(const Symbol &symbol) {
10941094
return false;
10951095
}
10961096

1097+
bool HasCUDAComponent(const Symbol &symbol) {
1098+
if (const auto *details{symbol.GetUltimate()
1099+
.detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
1100+
const Fortran::semantics::DeclTypeSpec *type{details->type()};
1101+
const Fortran::semantics::DerivedTypeSpec *derived{
1102+
type ? type->AsDerived() : nullptr};
1103+
if (derived) {
1104+
if (FindCUDADeviceAllocatableUltimateComponent(*derived)) {
1105+
return true;
1106+
}
1107+
}
1108+
}
1109+
return false;
1110+
}
1111+
10971112
UltimateComponentIterator::const_iterator
10981113
FindCUDADeviceAllocatableUltimateComponent(const DerivedTypeSpec &derived) {
10991114
UltimateComponentIterator ultimates{derived};

flang/test/Lower/CUDA/cuda-allocatable.cuf

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ module globals
66
real, device, allocatable :: a_device(:)
77
real, managed, allocatable :: a_managed(:)
88
real, pinned, allocatable :: a_pinned(:)
9+
type :: t1
10+
integer :: a
11+
real, dimension(:), allocatable, device :: b
12+
end type
913
end module
1014

1115
! CHECK-LABEL: fir.global @_QMglobalsEa_device {data_attr = #cuf.cuda<device>} : !fir.box<!fir.heap<!fir.array<?xf32>>>
@@ -222,3 +226,12 @@ end
222226
! CHECK: %[[FALSE:.*]] = arith.constant false
223227
! CHECK: %[[FLASE_CONV:.*]] = fir.convert %[[FALSE]] : (i1) -> !fir.logical<4>
224228
! CHECK: fir.store %[[FLASE_CONV]] to %[[PLOG_DECL]]#0 : !fir.ref<!fir.logical<4>>
229+
230+
subroutine cuda_component()
231+
use globals
232+
type(t1), pointer, dimension(:) :: d
233+
allocate(d(10))
234+
end subroutine
235+
236+
! CHECK-LABEL: func.func @_QPcuda_component()
237+
! CHECK: cuf.allocate

0 commit comments

Comments
 (0)