Skip to content

Commit d897355

Browse files
authored
[flang][cuda] Set the allocator of derived type component after allocation (#152379)
- Move the allocator index setup after the allocate statement; otherwise the derived-type descriptor is not yet allocated. - Support arrays of derived type with device components.
1 parent 885ddf4 commit d897355

File tree

7 files changed

+205
-89
lines changed

7 files changed

+205
-89
lines changed

flang/include/flang/Lower/Cuda.h renamed to flang/include/flang/Lower/CUDA.h

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===-- Lower/Cuda.h -- Cuda Fortran utilities ------------------*- C++ -*-===//
1+
//===-- Lower/CUDA.h -- CUDA Fortran utilities ------------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -15,6 +15,7 @@
1515

1616
#include "flang/Optimizer/Builder/FIRBuilder.h"
1717
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
18+
#include "flang/Runtime/allocator-registry-consts.h"
1819
#include "flang/Semantics/tools.h"
1920
#include "mlir/Dialect/Func/IR/FuncOps.h"
2021
#include "mlir/Dialect/OpenACC/OpenACC.h"
@@ -37,6 +38,15 @@ static inline unsigned getAllocatorIdx(const Fortran::semantics::Symbol &sym) {
3738
return kDefaultAllocator;
3839
}
3940

41+
/// Set the allocator index on the device allocatable components of the
/// derived-type entity \p sym, whose descriptor is held in \p box.
/// \p converter supplies the FIR builder and the current source location
/// used for the operations that are generated.
void initializeDeviceComponentAllocator(
42+
Fortran::lower::AbstractConverter &converter,
43+
const Fortran::semantics::Symbol &sym, const fir::MutableBoxValue &box);
44+
45+
/// Look up the component \p sym by name in the record type \p recTy and
/// append the field index value(s) addressing it to \p coordinates.
/// \return the type of the component's field.
mlir::Type gatherDeviceComponentCoordinatesAndType(
46+
fir::FirOpBuilder &builder, mlir::Location loc,
47+
const Fortran::semantics::Symbol &sym, fir::RecordType recTy,
48+
llvm::SmallVector<mlir::Value> &coordinates);
49+
4050
} // end namespace Fortran::lower
4151

4252
#endif // FORTRAN_LOWER_CUDA_H

flang/lib/Lower/Allocatable.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
#include "flang/Lower/Allocatable.h"
1414
#include "flang/Evaluate/tools.h"
1515
#include "flang/Lower/AbstractConverter.h"
16+
#include "flang/Lower/CUDA.h"
1617
#include "flang/Lower/ConvertType.h"
1718
#include "flang/Lower/ConvertVariable.h"
18-
#include "flang/Lower/Cuda.h"
1919
#include "flang/Lower/IterationSpace.h"
2020
#include "flang/Lower/Mangler.h"
2121
#include "flang/Lower/OpenACC.h"
@@ -445,10 +445,14 @@ class AllocateStmtHelper {
445445
/*mustBeHeap=*/true);
446446
}
447447

448-
void postAllocationAction(const Allocation &alloc) {
448+
void postAllocationAction(const Allocation &alloc,
449+
const fir::MutableBoxValue &box) {
449450
if (alloc.getSymbol().test(Fortran::semantics::Symbol::Flag::AccDeclare))
450451
Fortran::lower::attachDeclarePostAllocAction(converter, builder,
451452
alloc.getSymbol());
453+
if (Fortran::semantics::HasCUDAComponent(alloc.getSymbol()))
454+
Fortran::lower::initializeDeviceComponentAllocator(
455+
converter, alloc.getSymbol(), box);
452456
}
453457

454458
void setPinnedToFalse() {
@@ -481,7 +485,7 @@ class AllocateStmtHelper {
481485
// Pointers must use PointerAllocate so that their deallocations
482486
// can be validated.
483487
genInlinedAllocation(alloc, box);
484-
postAllocationAction(alloc);
488+
postAllocationAction(alloc, box);
485489
setPinnedToFalse();
486490
return;
487491
}
@@ -504,7 +508,7 @@ class AllocateStmtHelper {
504508
genCudaAllocate(builder, loc, box, errorManager, alloc.getSymbol());
505509
}
506510
fir::factory::syncMutableBoxFromIRBox(builder, loc, box);
507-
postAllocationAction(alloc);
511+
postAllocationAction(alloc, box);
508512
errorManager.assignStat(builder, loc, stat);
509513
}
510514

@@ -647,7 +651,7 @@ class AllocateStmtHelper {
647651
setPinnedToFalse();
648652
}
649653
fir::factory::syncMutableBoxFromIRBox(builder, loc, box);
650-
postAllocationAction(alloc);
654+
postAllocationAction(alloc, box);
651655
errorManager.assignStat(builder, loc, stat);
652656
}
653657

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
#include "flang/Lower/Bridge.h"
1414

1515
#include "flang/Lower/Allocatable.h"
16+
#include "flang/Lower/CUDA.h"
1617
#include "flang/Lower/CallInterface.h"
1718
#include "flang/Lower/Coarray.h"
1819
#include "flang/Lower/ConvertCall.h"
1920
#include "flang/Lower/ConvertExpr.h"
2021
#include "flang/Lower/ConvertExprToHLFIR.h"
2122
#include "flang/Lower/ConvertType.h"
2223
#include "flang/Lower/ConvertVariable.h"
23-
#include "flang/Lower/Cuda.h"
2424
#include "flang/Lower/DirectivesCommon.h"
2525
#include "flang/Lower/HostAssociations.h"
2626
#include "flang/Lower/IO.h"

flang/lib/Lower/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_flang_library(FortranLower
1515
ConvertProcedureDesignator.cpp
1616
ConvertType.cpp
1717
ConvertVariable.cpp
18+
CUDA.cpp
1819
CustomIntrinsicCall.cpp
1920
HlfirIntrinsics.cpp
2021
HostAssociations.cpp

flang/lib/Lower/CUDA.cpp

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
//===-- CUDA.cpp -- CUDA Fortran specific lowering ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "flang/Lower/CUDA.h"
14+
15+
#define DEBUG_TYPE "flang-lower-cuda"
16+
17+
/// Emit the operations that set the allocator index on every device
/// allocatable ultimate component of the derived-type entity \p sym.
///
/// Nothing is generated unless \p sym is an object entity of derived type
/// that has at least one CUDA device allocatable ultimate component, and
/// \p box describes an allocatable or a pointer.
///
/// For a scalar entity, each component descriptor is addressed directly from
/// the base address of the loaded box. For an array entity, a `fir.do_loop`
/// nest (one loop per dimension) is generated and the component descriptors
/// of each element are addressed inside the innermost loop.
///
/// The builder insertion point is restored on return, so the caller keeps
/// emitting code after the generated operations rather than inside the loop
/// nest.
///
/// \param converter provides the FIR builder and the current location.
/// \param sym       symbol of the entity that was allocated.
/// \param box       mutable box of the entity.
void Fortran::lower::initializeDeviceComponentAllocator(
    Fortran::lower::AbstractConverter &converter,
    const Fortran::semantics::Symbol &sym, const fir::MutableBoxValue &box) {
  const auto *details{
      sym.GetUltimate().detailsIf<Fortran::semantics::ObjectEntityDetails>()};
  if (!details)
    return;

  const Fortran::semantics::DeclTypeSpec *type{details->type()};
  const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
                                                          : nullptr};
  if (!derived)
    return;
  if (!FindCUDADeviceAllocatableUltimateComponent(*derived))
    return; // No device components.

  fir::FirOpBuilder &builder = converter.getFirOpBuilder();
  mlir::Location loc = converter.getCurrentLocation();

  mlir::Type baseTy = fir::unwrapRefType(box.getAddr().getType());

  // Only pointers and allocatables need post-allocation initialization of
  // their component descriptors.
  if (!fir::isAllocatableType(baseTy) && !fir::isPointerType(baseTy))
    return;

  // Extract the derived type.
  mlir::Type ty = fir::getDerivedType(baseTy);
  auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
  assert(recTy && "expected fir::RecordType");

  // Peel the box and reference layers to see whether the entity is an array.
  if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(baseTy))
    baseTy = boxTy.getEleTy();
  baseTy = fir::unwrapRefType(baseTy);

  Fortran::semantics::UltimateComponentIterator components{*derived};
  mlir::Value loadedBox = fir::LoadOp::create(builder, loc, box.getAddr());

  // The array case below moves the insertion point into the innermost loop
  // body. Restore it when leaving this function so that the code lowered
  // afterwards is not emitted inside the loop nest.
  mlir::OpBuilder::InsertionGuard guard(builder);

  mlir::Value addr;
  if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(baseTy)) {
    mlir::Type idxTy = builder.getIndexType();
    mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
    mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
    llvm::SmallVector<mlir::Value> indices;
    llvm::SmallVector<mlir::Value> extents;
    for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
      mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
      auto dimInfo = fir::BoxDimsOp::create(builder, loc, idxTy, idxTy, idxTy,
                                            loadedBox, dim);
      // NOTE(review): fir.box_dims is documented as returning
      // (lower bound, extent, byte stride). The value pushed to `extents`
      // below is result(0) + result(1) - 1 clamped to zero, and the loop uses
      // result(1) as its upper bound and result(2) as its step. Confirm this
      // is the intended iteration space for non-default lower bounds.
      mlir::Value lbub = mlir::arith::AddIOp::create(
          builder, loc, dimInfo.getResult(0), dimInfo.getResult(1));
      mlir::Value ext = mlir::arith::SubIOp::create(builder, loc, lbub, one);
      mlir::Value cmp = mlir::arith::CmpIOp::create(
          builder, loc, mlir::arith::CmpIPredicate::sgt, ext, zero);
      ext = mlir::arith::SelectOp::create(builder, loc, cmp, ext, zero);
      extents.push_back(ext);

      auto loop = fir::DoLoopOp::create(
          builder, loc, dimInfo.getResult(0), dimInfo.getResult(1),
          dimInfo.getResult(2), /*isUnordered=*/true,
          /*finalCount=*/false, mlir::ValueRange{});
      indices.push_back(loop.getInductionVar());
      // Nest the next loop (and finally the element access) in this body.
      builder.setInsertionPointToStart(loop.getBody());
    }
    // Address of the current element, computed inside the innermost loop.
    mlir::Value boxAddr = fir::BoxAddrOp::create(builder, loc, loadedBox);
    auto shape = fir::ShapeOp::create(builder, loc, extents);
    addr = fir::ArrayCoorOp::create(
        builder, loc, fir::ReferenceType::get(recTy), boxAddr, shape,
        /*slice=*/mlir::Value{}, indices, /*typeparms=*/mlir::ValueRange{});
  } else {
    addr = fir::BoxAddrOp::create(builder, loc, loadedBox);
  }

  for (const auto &compSym : components) {
    if (!Fortran::semantics::IsDeviceAllocatable(compSym))
      continue;
    llvm::SmallVector<mlir::Value> coord;
    mlir::Type fieldTy = gatherDeviceComponentCoordinatesAndType(
        builder, loc, compSym, recTy, coord);
    assert(coord.size() == 1 && "expect one coordinate");
    mlir::Value comp = fir::CoordinateOp::create(
        builder, loc, builder.getRefType(fieldTy), addr, coord[0]);
    cuf::DataAttributeAttr dataAttr =
        Fortran::lower::translateSymbolCUFDataAttribute(builder.getContext(),
                                                        compSym);
    cuf::SetAllocatorIndexOp::create(builder, loc, comp, dataAttr);
  }
}
105+
106+
/// Build the `fir.field_index` value(s) needed to address the component
/// \p sym inside the record type \p recTy and append them to \p coordinates.
///
/// The component is first looked up by name among the fields of \p recTy; in
/// that case one coordinate is appended. Otherwise the record-typed fields of
/// \p recTy are searched one level deep, and two coordinates are appended: the
/// enclosing field followed by the component itself. The first match wins.
///
/// If the component is found in neither place, a TODO diagnostic is emitted.
///
/// \param builder     builder used to create the `fir.field_index` ops.
/// \param loc         location attached to the created ops.
/// \param sym         component symbol to look up (by name).
/// \param recTy       record type in which the lookup starts.
/// \param coordinates output; the field indices are appended to it.
/// \return the type of the component's field.
mlir::Type Fortran::lower::gatherDeviceComponentCoordinatesAndType(
    fir::FirOpBuilder &builder, mlir::Location loc,
    const Fortran::semantics::Symbol &sym, fir::RecordType recTy,
    llvm::SmallVector<mlir::Value> &coordinates) {
  // Value returned by getFieldIndex when the name is not a field.
  constexpr unsigned kNotFound = std::numeric_limits<unsigned>::max();
  // Computed once; it is compared against every candidate record below.
  const std::string compName = sym.name().ToString();

  mlir::Type fieldTy;
  unsigned fieldIdx = recTy.getFieldIndex(compName);
  if (fieldIdx != kNotFound) {
    // Field found in the base record type.
    const auto &field = recTy.getTypeList()[fieldIdx];
    fieldTy = field.second;
    mlir::Value fieldIndex = fir::FieldIndexOp::create(
        builder, loc, fir::FieldType::get(fieldTy.getContext()), field.first,
        recTy,
        /*typeParams=*/mlir::ValueRange{});
    coordinates.push_back(fieldIndex);
  } else {
    // Field not found in base record type, search in potential
    // record type components.
    for (const auto &component : recTy.getTypeList()) {
      auto childRecTy = mlir::dyn_cast<fir::RecordType>(component.second);
      if (!childRecTy)
        continue;
      fieldIdx = childRecTy.getFieldIndex(compName);
      if (fieldIdx == kNotFound)
        continue;
      mlir::Value parentFieldIndex = fir::FieldIndexOp::create(
          builder, loc, fir::FieldType::get(childRecTy.getContext()),
          component.first, recTy,
          /*typeParams=*/mlir::ValueRange{});
      coordinates.push_back(parentFieldIndex);
      const auto &field = childRecTy.getTypeList()[fieldIdx];
      fieldTy = field.second;
      mlir::Value childFieldIndex = fir::FieldIndexOp::create(
          builder, loc, fir::FieldType::get(fieldTy.getContext()), field.first,
          childRecTy,
          /*typeParams=*/mlir::ValueRange{});
      coordinates.push_back(childFieldIndex);
      break;
    }
  }
  // fieldTy is only set when a coordinate was appended. Testing it rather
  // than coordinates.empty() keeps the check valid even if the caller passes
  // a vector that already holds values.
  if (!fieldTy)
    TODO(loc, "device resident component in complex derived-type hierarchy");
  return fieldTy;
}

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 9 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
#include "flang/Lower/AbstractConverter.h"
1515
#include "flang/Lower/Allocatable.h"
1616
#include "flang/Lower/BoxAnalyzer.h"
17+
#include "flang/Lower/CUDA.h"
1718
#include "flang/Lower/CallInterface.h"
1819
#include "flang/Lower/ConvertConstant.h"
1920
#include "flang/Lower/ConvertExpr.h"
2021
#include "flang/Lower/ConvertExprToHLFIR.h"
2122
#include "flang/Lower/ConvertProcedureDesignator.h"
22-
#include "flang/Lower/Cuda.h"
2323
#include "flang/Lower/Mangler.h"
2424
#include "flang/Lower/PFTBuilder.h"
2525
#include "flang/Lower/StatementContext.h"
@@ -814,81 +814,24 @@ initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter,
814814
baseTy = boxTy.getEleTy();
815815
baseTy = fir::unwrapRefType(baseTy);
816816

817-
if (mlir::isa<fir::SequenceType>(baseTy) &&
818-
(fir::isAllocatableType(fir::getBase(exv).getType()) ||
819-
fir::isPointerType(fir::getBase(exv).getType())))
817+
if (fir::isAllocatableType(fir::getBase(exv).getType()) ||
818+
fir::isPointerType(fir::getBase(exv).getType()))
820819
return; // Allocator index need to be set after allocation.
821820

822821
auto recTy =
823822
mlir::dyn_cast<fir::RecordType>(fir::unwrapSequenceType(baseTy));
824823
assert(recTy && "expected fir::RecordType");
825824

826-
llvm::SmallVector<mlir::Value> coordinates;
827825
Fortran::semantics::UltimateComponentIterator components{*derived};
828826
for (const auto &sym : components) {
829827
if (Fortran::semantics::IsDeviceAllocatable(sym)) {
830-
unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
831-
mlir::Type fieldTy;
832-
llvm::SmallVector<mlir::Value> coordinates;
833-
834-
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
835-
// Field found in the base record type.
836-
auto fieldName = recTy.getTypeList()[fieldIdx].first;
837-
fieldTy = recTy.getTypeList()[fieldIdx].second;
838-
mlir::Value fieldIndex = fir::FieldIndexOp::create(
839-
builder, loc, fir::FieldType::get(fieldTy.getContext()),
840-
fieldName, recTy,
841-
/*typeParams=*/mlir::ValueRange{});
842-
coordinates.push_back(fieldIndex);
843-
} else {
844-
// Field not found in base record type, search in potential
845-
// record type components.
846-
for (auto component : recTy.getTypeList()) {
847-
if (auto childRecTy =
848-
mlir::dyn_cast<fir::RecordType>(component.second)) {
849-
fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
850-
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
851-
mlir::Value parentFieldIndex = fir::FieldIndexOp::create(
852-
builder, loc,
853-
fir::FieldType::get(childRecTy.getContext()),
854-
component.first, recTy,
855-
/*typeParams=*/mlir::ValueRange{});
856-
coordinates.push_back(parentFieldIndex);
857-
auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
858-
fieldTy = childRecTy.getTypeList()[fieldIdx].second;
859-
mlir::Value childFieldIndex = fir::FieldIndexOp::create(
860-
builder, loc, fir::FieldType::get(fieldTy.getContext()),
861-
fieldName, childRecTy,
862-
/*typeParams=*/mlir::ValueRange{});
863-
coordinates.push_back(childFieldIndex);
864-
break;
865-
}
866-
}
867-
}
868-
}
869-
870-
if (coordinates.empty())
871-
TODO(loc, "device resident component in complex derived-type "
872-
"hierarchy");
873-
828+
llvm::SmallVector<mlir::Value> coord;
829+
mlir::Type fieldTy =
830+
Fortran::lower::gatherDeviceComponentCoordinatesAndType(
831+
builder, loc, sym, recTy, coord);
874832
mlir::Value base = fir::getBase(exv);
875-
mlir::Value comp;
876-
if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(base.getType()))) {
877-
mlir::Value box = fir::LoadOp::create(builder, loc, base);
878-
mlir::Value addr = fir::BoxAddrOp::create(builder, loc, box);
879-
llvm::SmallVector<mlir::Value> lenParams;
880-
assert(coordinates.size() == 1 && "expect one coordinate");
881-
auto field = mlir::dyn_cast<fir::FieldIndexOp>(
882-
coordinates[0].getDefiningOp());
883-
comp = hlfir::DesignateOp::create(
884-
builder, loc, builder.getRefType(fieldTy), addr,
885-
/*component=*/field.getFieldName(),
886-
/*componentShape=*/mlir::Value{},
887-
hlfir::DesignateOp::Subscripts{});
888-
} else {
889-
comp = fir::CoordinateOp::create(
890-
builder, loc, builder.getRefType(fieldTy), base, coordinates);
891-
}
833+
mlir::Value comp = fir::CoordinateOp::create(
834+
builder, loc, builder.getRefType(fieldTy), base, coord);
892835
cuf::DataAttributeAttr dataAttr =
893836
Fortran::lower::translateSymbolCUFDataAttribute(
894837
builder.getContext(), sym);

0 commit comments

Comments
 (0)