Skip to content

Commit b4e2272

Browse files
authored
[flang][cuda] Move cuf.set_allocator_idx after derived-type init (#148936)
Derived type initialization overwrite the component descriptor. Place the `cuf.set_allocator_idx` after the initialization is performed.
1 parent 42d2ae1 commit b4e2272

File tree

2 files changed

+100
-76
lines changed

2 files changed

+100
-76
lines changed

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 94 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -771,79 +771,9 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
771771
return builder.create<cuf::SharedMemoryOp>(loc, ty, nm, symNm, lenParams,
772772
indices);
773773

774-
if (!cuf::isCUDADeviceContext(builder.getRegion())) {
775-
mlir::Value alloc = builder.create<cuf::AllocOp>(
776-
loc, ty, nm, symNm, dataAttr, lenParams, indices);
777-
if (const auto *details{
778-
ultimateSymbol
779-
.detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
780-
const Fortran::semantics::DeclTypeSpec *type{details->type()};
781-
const Fortran::semantics::DerivedTypeSpec *derived{
782-
type ? type->AsDerived() : nullptr};
783-
if (derived) {
784-
Fortran::semantics::UltimateComponentIterator components{*derived};
785-
auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
786-
787-
llvm::SmallVector<mlir::Value> coordinates;
788-
for (const auto &sym : components) {
789-
if (Fortran::semantics::IsDeviceAllocatable(sym)) {
790-
unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
791-
mlir::Type fieldTy;
792-
std::vector<mlir::Value> coordinates;
793-
794-
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
795-
// Field found in the base record type.
796-
auto fieldName = recTy.getTypeList()[fieldIdx].first;
797-
fieldTy = recTy.getTypeList()[fieldIdx].second;
798-
mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
799-
loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
800-
recTy,
801-
/*typeParams=*/mlir::ValueRange{});
802-
coordinates.push_back(fieldIndex);
803-
} else {
804-
// Field not found in base record type, search in potential
805-
// record type components.
806-
for (auto component : recTy.getTypeList()) {
807-
if (auto childRecTy =
808-
mlir::dyn_cast<fir::RecordType>(component.second)) {
809-
fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
810-
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
811-
mlir::Value parentFieldIndex =
812-
builder.create<fir::FieldIndexOp>(
813-
loc, fir::FieldType::get(childRecTy.getContext()),
814-
component.first, recTy,
815-
/*typeParams=*/mlir::ValueRange{});
816-
coordinates.push_back(parentFieldIndex);
817-
auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
818-
fieldTy = childRecTy.getTypeList()[fieldIdx].second;
819-
mlir::Value childFieldIndex =
820-
builder.create<fir::FieldIndexOp>(
821-
loc, fir::FieldType::get(fieldTy.getContext()),
822-
fieldName, childRecTy,
823-
/*typeParams=*/mlir::ValueRange{});
824-
coordinates.push_back(childFieldIndex);
825-
break;
826-
}
827-
}
828-
}
829-
}
830-
831-
if (coordinates.empty())
832-
TODO(loc, "device resident component in complex derived-type "
833-
"hierarchy");
834-
835-
mlir::Value comp = builder.create<fir::CoordinateOp>(
836-
loc, builder.getRefType(fieldTy), alloc, coordinates);
837-
cuf::DataAttributeAttr dataAttr =
838-
Fortran::lower::translateSymbolCUFDataAttribute(
839-
builder.getContext(), sym);
840-
builder.create<cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
841-
}
842-
}
843-
}
844-
}
845-
return alloc;
846-
}
774+
if (!cuf::isCUDADeviceContext(builder.getRegion()))
775+
return builder.create<cuf::AllocOp>(loc, ty, nm, symNm, dataAttr,
776+
lenParams, indices);
847777
}
848778

849779
// Let the builder do all the heavy lifting.
@@ -857,6 +787,91 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
857787
return res;
858788
}
859789

790+
/// Device allocatable components in a derived-type don't have the correct
791+
/// allocator index in their descriptor when they are created. After
792+
/// initialization, cuf.set_allocator_idx operations are inserted to set the
793+
/// correct allocator index for each device component.
794+
static void
795+
initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter,
796+
const Fortran::semantics::Symbol &symbol,
797+
Fortran::lower::SymMap &symMap) {
798+
if (const auto *details{
799+
symbol.GetUltimate()
800+
.detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
801+
const Fortran::semantics::DeclTypeSpec *type{details->type()};
802+
const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
803+
: nullptr};
804+
if (derived) {
805+
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
806+
mlir::Location loc = converter.getCurrentLocation();
807+
808+
fir::ExtendedValue exv =
809+
converter.getSymbolExtendedValue(symbol.GetUltimate(), &symMap);
810+
auto recTy = mlir::dyn_cast<fir::RecordType>(
811+
fir::unwrapRefType(fir::getBase(exv).getType()));
812+
assert(recTy && "expected fir::RecordType");
813+
814+
llvm::SmallVector<mlir::Value> coordinates;
815+
Fortran::semantics::UltimateComponentIterator components{*derived};
816+
for (const auto &sym : components) {
817+
if (Fortran::semantics::IsDeviceAllocatable(sym)) {
818+
unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
819+
mlir::Type fieldTy;
820+
std::vector<mlir::Value> coordinates;
821+
822+
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
823+
// Field found in the base record type.
824+
auto fieldName = recTy.getTypeList()[fieldIdx].first;
825+
fieldTy = recTy.getTypeList()[fieldIdx].second;
826+
mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
827+
loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
828+
recTy,
829+
/*typeParams=*/mlir::ValueRange{});
830+
coordinates.push_back(fieldIndex);
831+
} else {
832+
// Field not found in base record type, search in potential
833+
// record type components.
834+
for (auto component : recTy.getTypeList()) {
835+
if (auto childRecTy =
836+
mlir::dyn_cast<fir::RecordType>(component.second)) {
837+
fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
838+
if (fieldIdx != std::numeric_limits<unsigned>::max()) {
839+
mlir::Value parentFieldIndex =
840+
builder.create<fir::FieldIndexOp>(
841+
loc, fir::FieldType::get(childRecTy.getContext()),
842+
component.first, recTy,
843+
/*typeParams=*/mlir::ValueRange{});
844+
coordinates.push_back(parentFieldIndex);
845+
auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
846+
fieldTy = childRecTy.getTypeList()[fieldIdx].second;
847+
mlir::Value childFieldIndex =
848+
builder.create<fir::FieldIndexOp>(
849+
loc, fir::FieldType::get(fieldTy.getContext()),
850+
fieldName, childRecTy,
851+
/*typeParams=*/mlir::ValueRange{});
852+
coordinates.push_back(childFieldIndex);
853+
break;
854+
}
855+
}
856+
}
857+
}
858+
859+
if (coordinates.empty())
860+
TODO(loc, "device resident component in complex derived-type "
861+
"hierarchy");
862+
863+
mlir::Value comp = builder.create<fir::CoordinateOp>(
864+
loc, builder.getRefType(fieldTy), fir::getBase(exv), coordinates);
865+
cuf::DataAttributeAttr dataAttr =
866+
Fortran::lower::translateSymbolCUFDataAttribute(
867+
builder.getContext(), sym);
868+
builder.create<cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
869+
}
870+
}
871+
}
872+
}
873+
}
874+
860875
/// Must \p var be default initialized at runtime when entering its scope.
861876
static bool
862877
mustBeDefaultInitializedAtRuntime(const Fortran::lower::pft::Variable &var) {
@@ -1179,6 +1194,9 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
11791194
if (mustBeDefaultInitializedAtRuntime(var))
11801195
Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
11811196
symMap);
1197+
if (converter.getFoldingContext().languageFeatures().IsEnabled(
1198+
Fortran::common::LanguageFeature::CUDA))
1199+
initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap);
11821200
auto *builder = &converter.getFirOpBuilder();
11831201
if (needCUDAAlloc(var.getSymbol()) &&
11841202
!cuf::isCUDADeviceContext(builder->getRegion())) {
@@ -1437,6 +1455,9 @@ static void instantiateAlias(Fortran::lower::AbstractConverter &converter,
14371455
if (mustBeDefaultInitializedAtRuntime(var))
14381456
Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
14391457
symMap);
1458+
if (converter.getFoldingContext().languageFeatures().IsEnabled(
1459+
Fortran::common::LanguageFeature::CUDA))
1460+
initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap);
14401461
}
14411462

14421463
//===--------------------------------------------------------------===//

flang/test/Lower/CUDA/cuda-set-allocator.cuf

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,13 @@ contains
1212
end subroutine
1313

1414
! CHECK-LABEL: func.func @_QMm1Psub1()
15-
! CHECK: %[[DT:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
16-
! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]], x : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
15+
! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
16+
! CHECK: %[[DT:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>, !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>)
17+
! CHECK: fir.address_of(@_QQ_QMm1Tty_device.DerivedInit)
18+
! CHECK: fir.copy
19+
! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]]#0, x : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
1720
! CHECK: cuf.set_allocator_idx %[[X]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
18-
! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]], z : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
21+
! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]]#0, z : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
1922
! CHECK: cuf.set_allocator_idx %[[Z]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
2023

2124
end module

0 commit comments

Comments
 (0)