Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
167 changes: 94 additions & 73 deletions flang/lib/Lower/ConvertVariable.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -771,79 +771,9 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
return builder.create<cuf::SharedMemoryOp>(loc, ty, nm, symNm, lenParams,
indices);

if (!cuf::isCUDADeviceContext(builder.getRegion())) {
mlir::Value alloc = builder.create<cuf::AllocOp>(
loc, ty, nm, symNm, dataAttr, lenParams, indices);
if (const auto *details{
ultimateSymbol
.detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
const Fortran::semantics::DeclTypeSpec *type{details->type()};
const Fortran::semantics::DerivedTypeSpec *derived{
type ? type->AsDerived() : nullptr};
if (derived) {
Fortran::semantics::UltimateComponentIterator components{*derived};
auto recTy = mlir::dyn_cast<fir::RecordType>(ty);

llvm::SmallVector<mlir::Value> coordinates;
for (const auto &sym : components) {
if (Fortran::semantics::IsDeviceAllocatable(sym)) {
unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
mlir::Type fieldTy;
std::vector<mlir::Value> coordinates;

if (fieldIdx != std::numeric_limits<unsigned>::max()) {
// Field found in the base record type.
auto fieldName = recTy.getTypeList()[fieldIdx].first;
fieldTy = recTy.getTypeList()[fieldIdx].second;
mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
recTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(fieldIndex);
} else {
// Field not found in base record type, search in potential
// record type components.
for (auto component : recTy.getTypeList()) {
if (auto childRecTy =
mlir::dyn_cast<fir::RecordType>(component.second)) {
fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
mlir::Value parentFieldIndex =
builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(childRecTy.getContext()),
component.first, recTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(parentFieldIndex);
auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
fieldTy = childRecTy.getTypeList()[fieldIdx].second;
mlir::Value childFieldIndex =
builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(fieldTy.getContext()),
fieldName, childRecTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(childFieldIndex);
break;
}
}
}
}

if (coordinates.empty())
TODO(loc, "device resident component in complex derived-type "
"hierarchy");

mlir::Value comp = builder.create<fir::CoordinateOp>(
loc, builder.getRefType(fieldTy), alloc, coordinates);
cuf::DataAttributeAttr dataAttr =
Fortran::lower::translateSymbolCUFDataAttribute(
builder.getContext(), sym);
builder.create<cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
}
}
}
}
return alloc;
}
if (!cuf::isCUDADeviceContext(builder.getRegion()))
return builder.create<cuf::AllocOp>(loc, ty, nm, symNm, dataAttr,
lenParams, indices);
}

// Let the builder do all the heavy lifting.
Expand All @@ -857,6 +787,91 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
return res;
}

/// Device allocatable component in a derived-type don't have the correct
/// allocator index in their descriptor when they are created. After
/// initialization, cuf.set_allocator_idx operations are inserted to set the
/// correct allocator index for each device component.
static void
initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter,
const Fortran::semantics::Symbol &symbol,
Fortran::lower::SymMap &symMap) {
if (const auto *details{
symbol.GetUltimate()
.detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
const Fortran::semantics::DeclTypeSpec *type{details->type()};
const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
: nullptr};
if (derived) {
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
mlir::Location loc = converter.getCurrentLocation();

fir::ExtendedValue exv =
converter.getSymbolExtendedValue(symbol.GetUltimate(), &symMap);
auto recTy = mlir::dyn_cast<fir::RecordType>(
fir::unwrapRefType(fir::getBase(exv).getType()));
assert(recTy && "expected fir::RecordType");

llvm::SmallVector<mlir::Value> coordinates;
Fortran::semantics::UltimateComponentIterator components{*derived};
for (const auto &sym : components) {
if (Fortran::semantics::IsDeviceAllocatable(sym)) {
unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
mlir::Type fieldTy;
std::vector<mlir::Value> coordinates;

if (fieldIdx != std::numeric_limits<unsigned>::max()) {
// Field found in the base record type.
auto fieldName = recTy.getTypeList()[fieldIdx].first;
fieldTy = recTy.getTypeList()[fieldIdx].second;
mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
recTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(fieldIndex);
} else {
// Field not found in base record type, search in potential
// record type components.
for (auto component : recTy.getTypeList()) {
if (auto childRecTy =
mlir::dyn_cast<fir::RecordType>(component.second)) {
fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
mlir::Value parentFieldIndex =
builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(childRecTy.getContext()),
component.first, recTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(parentFieldIndex);
auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
fieldTy = childRecTy.getTypeList()[fieldIdx].second;
mlir::Value childFieldIndex =
builder.create<fir::FieldIndexOp>(
loc, fir::FieldType::get(fieldTy.getContext()),
fieldName, childRecTy,
/*typeParams=*/mlir::ValueRange{});
coordinates.push_back(childFieldIndex);
break;
}
}
}
}

if (coordinates.empty())
TODO(loc, "device resident component in complex derived-type "
"hierarchy");

mlir::Value comp = builder.create<fir::CoordinateOp>(
loc, builder.getRefType(fieldTy), fir::getBase(exv), coordinates);
cuf::DataAttributeAttr dataAttr =
Fortran::lower::translateSymbolCUFDataAttribute(
builder.getContext(), sym);
builder.create<cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
}
}
}
}
}

/// Must \p var be default initialized at runtime when entering its scope.
static bool
mustBeDefaultInitializedAtRuntime(const Fortran::lower::pft::Variable &var) {
Expand Down Expand Up @@ -1179,6 +1194,9 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
if (mustBeDefaultInitializedAtRuntime(var))
Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
symMap);
if (converter.getFoldingContext().languageFeatures().IsEnabled(
Fortran::common::LanguageFeature::CUDA))
initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap);
auto *builder = &converter.getFirOpBuilder();
if (needCUDAAlloc(var.getSymbol()) &&
!cuf::isCUDADeviceContext(builder->getRegion())) {
Expand Down Expand Up @@ -1437,6 +1455,9 @@ static void instantiateAlias(Fortran::lower::AbstractConverter &converter,
if (mustBeDefaultInitializedAtRuntime(var))
Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
symMap);
if (converter.getFoldingContext().languageFeatures().IsEnabled(
Fortran::common::LanguageFeature::CUDA))
initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap);
}

//===--------------------------------------------------------------===//
Expand Down
14 changes: 11 additions & 3 deletions flang/test/Lower/CUDA/cuda-set-allocator.cuf
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,18 @@ contains
end subroutine

! CHECK-LABEL: func.func @_QMm1Psub1()
! CHECK: %[[DT:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]], x : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
! CHECK: %[[DT:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>, !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>)
! CHECK: fir.address_of(@_QQ_QMm1Tty_device.DerivedInit)
! CHECK: fir.copy
! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]]#0, x : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
! CHECK: cuf.set_allocator_idx %[[X]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]], z : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]]#0, z : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
! CHECK: cuf.set_allocator_idx %[[Z]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}

end module





Loading