Skip to content

Commit eb0ddba

Browse files
authored
Reland "[flang][cuda] Set the allocator of derived type component after allocation" (llvm#152418)
Reviewed in llvm#152379 - Move the allocator index set up after the allocate statement otherwise the derived type descriptor is not allocated. - Support array of derived-type with device component
1 parent 9a592d9 commit eb0ddba

File tree

8 files changed

+227
-102
lines changed

8 files changed

+227
-102
lines changed

flang/include/flang/Lower/Cuda.h renamed to flang/include/flang/Lower/CUDA.h

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
//===-- Lower/Cuda.h -- Cuda Fortran utilities ------------------*- C++ -*-===//
1+
//===-- Lower/CUDA.h -- CUDA Fortran utilities ------------------*- C++ -*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
@@ -14,13 +14,23 @@
1414
#define FORTRAN_LOWER_CUDA_H
1515

1616
#include "flang/Optimizer/Builder/FIRBuilder.h"
17+
#include "flang/Optimizer/Builder/MutableBox.h"
1718
#include "flang/Optimizer/Dialect/CUF/CUFOps.h"
19+
#include "flang/Runtime/allocator-registry-consts.h"
1820
#include "flang/Semantics/tools.h"
1921
#include "mlir/Dialect/Func/IR/FuncOps.h"
2022
#include "mlir/Dialect/OpenACC/OpenACC.h"
2123

24+
namespace mlir {
25+
class Value;
26+
class Location;
27+
class MLIRContext;
28+
} // namespace mlir
29+
2230
namespace Fortran::lower {
2331

32+
class AbstractConverter;
33+
2434
static inline unsigned getAllocatorIdx(const Fortran::semantics::Symbol &sym) {
2535
std::optional<Fortran::common::CUDADataAttr> cudaAttr =
2636
Fortran::semantics::GetCUDADataAttr(&sym.GetUltimate());
@@ -37,6 +47,21 @@ static inline unsigned getAllocatorIdx(const Fortran::semantics::Symbol &sym) {
3747
return kDefaultAllocator;
3848
}
3949

50+
void initializeDeviceComponentAllocator(
51+
Fortran::lower::AbstractConverter &converter,
52+
const Fortran::semantics::Symbol &sym, const fir::MutableBoxValue &box);
53+
54+
mlir::Type gatherDeviceComponentCoordinatesAndType(
55+
fir::FirOpBuilder &builder, mlir::Location loc,
56+
const Fortran::semantics::Symbol &sym, fir::RecordType recTy,
57+
llvm::SmallVector<mlir::Value> &coordinates);
58+
59+
/// Translate the CUDA Fortran attributes of \p sym into the FIR CUDA attribute
60+
/// representation.
61+
cuf::DataAttributeAttr
62+
translateSymbolCUFDataAttribute(mlir::MLIRContext *mlirContext,
63+
const Fortran::semantics::Symbol &sym);
64+
4065
} // end namespace Fortran::lower
4166

4267
#endif // FORTRAN_LOWER_CUDA_H

flang/include/flang/Lower/ConvertVariable.h

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,12 +162,6 @@ translateSymbolAttributes(mlir::MLIRContext *mlirContext,
162162
fir::FortranVariableFlagsEnum extraFlags =
163163
fir::FortranVariableFlagsEnum::None);
164164

165-
/// Translate the CUDA Fortran attributes of \p sym into the FIR CUDA attribute
166-
/// representation.
167-
cuf::DataAttributeAttr
168-
translateSymbolCUFDataAttribute(mlir::MLIRContext *mlirContext,
169-
const Fortran::semantics::Symbol &sym);
170-
171165
/// Map a symbol to a given fir::ExtendedValue. This will generate an
172166
/// hlfir.declare when lowering to HLFIR and map the hlfir.declare result to the
173167
/// symbol.

flang/lib/Lower/Allocatable.cpp

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
#include "flang/Lower/Allocatable.h"
1414
#include "flang/Evaluate/tools.h"
1515
#include "flang/Lower/AbstractConverter.h"
16+
#include "flang/Lower/CUDA.h"
1617
#include "flang/Lower/ConvertType.h"
1718
#include "flang/Lower/ConvertVariable.h"
18-
#include "flang/Lower/Cuda.h"
1919
#include "flang/Lower/IterationSpace.h"
2020
#include "flang/Lower/Mangler.h"
2121
#include "flang/Lower/OpenACC.h"
@@ -445,10 +445,14 @@ class AllocateStmtHelper {
445445
/*mustBeHeap=*/true);
446446
}
447447

448-
void postAllocationAction(const Allocation &alloc) {
448+
void postAllocationAction(const Allocation &alloc,
449+
const fir::MutableBoxValue &box) {
449450
if (alloc.getSymbol().test(Fortran::semantics::Symbol::Flag::AccDeclare))
450451
Fortran::lower::attachDeclarePostAllocAction(converter, builder,
451452
alloc.getSymbol());
453+
if (Fortran::semantics::HasCUDAComponent(alloc.getSymbol()))
454+
Fortran::lower::initializeDeviceComponentAllocator(
455+
converter, alloc.getSymbol(), box);
452456
}
453457

454458
void setPinnedToFalse() {
@@ -481,7 +485,7 @@ class AllocateStmtHelper {
481485
// Pointers must use PointerAllocate so that their deallocations
482486
// can be validated.
483487
genInlinedAllocation(alloc, box);
484-
postAllocationAction(alloc);
488+
postAllocationAction(alloc, box);
485489
setPinnedToFalse();
486490
return;
487491
}
@@ -504,7 +508,7 @@ class AllocateStmtHelper {
504508
genCudaAllocate(builder, loc, box, errorManager, alloc.getSymbol());
505509
}
506510
fir::factory::syncMutableBoxFromIRBox(builder, loc, box);
507-
postAllocationAction(alloc);
511+
postAllocationAction(alloc, box);
508512
errorManager.assignStat(builder, loc, stat);
509513
}
510514

@@ -647,7 +651,7 @@ class AllocateStmtHelper {
647651
setPinnedToFalse();
648652
}
649653
fir::factory::syncMutableBoxFromIRBox(builder, loc, box);
650-
postAllocationAction(alloc);
654+
postAllocationAction(alloc, box);
651655
errorManager.assignStat(builder, loc, stat);
652656
}
653657

flang/lib/Lower/Bridge.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,14 @@
1313
#include "flang/Lower/Bridge.h"
1414

1515
#include "flang/Lower/Allocatable.h"
16+
#include "flang/Lower/CUDA.h"
1617
#include "flang/Lower/CallInterface.h"
1718
#include "flang/Lower/Coarray.h"
1819
#include "flang/Lower/ConvertCall.h"
1920
#include "flang/Lower/ConvertExpr.h"
2021
#include "flang/Lower/ConvertExprToHLFIR.h"
2122
#include "flang/Lower/ConvertType.h"
2223
#include "flang/Lower/ConvertVariable.h"
23-
#include "flang/Lower/Cuda.h"
2424
#include "flang/Lower/DirectivesCommon.h"
2525
#include "flang/Lower/HostAssociations.h"
2626
#include "flang/Lower/IO.h"

flang/lib/Lower/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ add_flang_library(FortranLower
1515
ConvertProcedureDesignator.cpp
1616
ConvertType.cpp
1717
ConvertVariable.cpp
18+
CUDA.cpp
1819
CustomIntrinsicCall.cpp
1920
HlfirIntrinsics.cpp
2021
HostAssociations.cpp

flang/lib/Lower/CUDA.cpp

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
//===-- CUDA.cpp -- CUDA Fortran specific lowering ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "flang/Lower/CUDA.h"
14+
15+
#define DEBUG_TYPE "flang-lower-cuda"
16+
17+
void Fortran::lower::initializeDeviceComponentAllocator(
18+
Fortran::lower::AbstractConverter &converter,
19+
const Fortran::semantics::Symbol &sym, const fir::MutableBoxValue &box) {
20+
if (const auto *details{
21+
sym.GetUltimate()
22+
.detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
23+
const Fortran::semantics::DeclTypeSpec *type{details->type()};
24+
const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
25+
: nullptr};
26+
if (derived) {
27+
if (!FindCUDADeviceAllocatableUltimateComponent(*derived))
28+
return; // No device components.
29+
30+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
31+
mlir::Location loc = converter.getCurrentLocation();
32+
33+
mlir::Type baseTy = fir::unwrapRefType(box.getAddr().getType());
34+
35+
// Only pointer and allocatable needs post allocation initialization
36+
// of components descriptors.
37+
if (!fir::isAllocatableType(baseTy) && !fir::isPointerType(baseTy))
38+
return;
39+
40+
// Extract the derived type.
41+
mlir::Type ty = fir::getDerivedType(baseTy);
42+
auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
43+
assert(recTy && "expected fir::RecordType");
44+
45+
if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(baseTy))
46+
baseTy = boxTy.getEleTy();
47+
baseTy = fir::unwrapRefType(baseTy);
48+
49+
Fortran::semantics::UltimateComponentIterator components{*derived};
50+
mlir::Value loadedBox = fir::LoadOp::create(builder, loc, box.getAddr());
51+
mlir::Value addr;
52+
if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(baseTy)) {
53+
mlir::Type idxTy = builder.getIndexType();
54+
mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
55+
mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
56+
llvm::SmallVector<fir::DoLoopOp> loops;
57+
llvm::SmallVector<mlir::Value> indices;
58+
llvm::SmallVector<mlir::Value> extents;
59+
for (unsigned i = 0; i < seqTy.getDimension(); ++i) {
60+
mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i);
61+
auto dimInfo = fir::BoxDimsOp::create(builder, loc, idxTy, idxTy,
62+
idxTy, loadedBox, dim);
63+
mlir::Value lbub = mlir::arith::AddIOp::create(
64+
builder, loc, dimInfo.getResult(0), dimInfo.getResult(1));
65+
mlir::Value ext =
66+
mlir::arith::SubIOp::create(builder, loc, lbub, one);
67+
mlir::Value cmp = mlir::arith::CmpIOp::create(
68+
builder, loc, mlir::arith::CmpIPredicate::sgt, ext, zero);
69+
ext = mlir::arith::SelectOp::create(builder, loc, cmp, ext, zero);
70+
extents.push_back(ext);
71+
72+
auto loop = fir::DoLoopOp::create(
73+
builder, loc, dimInfo.getResult(0), dimInfo.getResult(1),
74+
dimInfo.getResult(2), /*isUnordered=*/true,
75+
/*finalCount=*/false, mlir::ValueRange{});
76+
loops.push_back(loop);
77+
indices.push_back(loop.getInductionVar());
78+
builder.setInsertionPointToStart(loop.getBody());
79+
}
80+
mlir::Value boxAddr = fir::BoxAddrOp::create(builder, loc, loadedBox);
81+
auto shape = fir::ShapeOp::create(builder, loc, extents);
82+
addr = fir::ArrayCoorOp::create(
83+
builder, loc, fir::ReferenceType::get(recTy), boxAddr, shape,
84+
/*slice=*/mlir::Value{}, indices, /*typeparms=*/mlir::ValueRange{});
85+
} else {
86+
addr = fir::BoxAddrOp::create(builder, loc, loadedBox);
87+
}
88+
for (const auto &compSym : components) {
89+
if (Fortran::semantics::IsDeviceAllocatable(compSym)) {
90+
llvm::SmallVector<mlir::Value> coord;
91+
mlir::Type fieldTy = gatherDeviceComponentCoordinatesAndType(
92+
builder, loc, compSym, recTy, coord);
93+
assert(coord.size() == 1 && "expect one coordinate");
94+
mlir::Value comp = fir::CoordinateOp::create(
95+
builder, loc, builder.getRefType(fieldTy), addr, coord[0]);
96+
cuf::DataAttributeAttr dataAttr =
97+
Fortran::lower::translateSymbolCUFDataAttribute(
98+
builder.getContext(), compSym);
99+
cuf::SetAllocatorIndexOp::create(builder, loc, comp, dataAttr);
100+
}
101+
}
102+
}
103+
}
104+
}
105+
106+
mlir::Type Fortran::lower::gatherDeviceComponentCoordinatesAndType(
107+
fir::FirOpBuilder &builder, mlir::Location loc,
108+
const Fortran::semantics::Symbol &sym, fir::RecordType recTy,
109+
llvm::SmallVector<mlir::Value> &coordinates) {
110+
unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
111+
mlir::Type fieldTy;
112+
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
113+
// Field found in the base record type.
114+
auto fieldName = recTy.getTypeList()[fieldIdx].first;
115+
fieldTy = recTy.getTypeList()[fieldIdx].second;
116+
mlir::Value fieldIndex = fir::FieldIndexOp::create(
117+
builder, loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
118+
recTy,
119+
/*typeParams=*/mlir::ValueRange{});
120+
coordinates.push_back(fieldIndex);
121+
} else {
122+
// Field not found in base record type, search in potential
123+
// record type components.
124+
for (auto component : recTy.getTypeList()) {
125+
if (auto childRecTy = mlir::dyn_cast<fir::RecordType>(component.second)) {
126+
fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
127+
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
128+
mlir::Value parentFieldIndex = fir::FieldIndexOp::create(
129+
builder, loc, fir::FieldType::get(childRecTy.getContext()),
130+
component.first, recTy,
131+
/*typeParams=*/mlir::ValueRange{});
132+
coordinates.push_back(parentFieldIndex);
133+
auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
134+
fieldTy = childRecTy.getTypeList()[fieldIdx].second;
135+
mlir::Value childFieldIndex = fir::FieldIndexOp::create(
136+
builder, loc, fir::FieldType::get(fieldTy.getContext()),
137+
fieldName, childRecTy,
138+
/*typeParams=*/mlir::ValueRange{});
139+
coordinates.push_back(childFieldIndex);
140+
break;
141+
}
142+
}
143+
}
144+
}
145+
if (coordinates.empty())
146+
TODO(loc, "device resident component in complex derived-type hierarchy");
147+
return fieldTy;
148+
}
149+
150+
cuf::DataAttributeAttr Fortran::lower::translateSymbolCUFDataAttribute(
151+
mlir::MLIRContext *mlirContext, const Fortran::semantics::Symbol &sym) {
152+
std::optional<Fortran::common::CUDADataAttr> cudaAttr =
153+
Fortran::semantics::GetCUDADataAttr(&sym.GetUltimate());
154+
return cuf::getDataAttribute(mlirContext, cudaAttr);
155+
}

0 commit comments

Comments
 (0)