@@ -771,79 +771,9 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
771771 return builder.create <cuf::SharedMemoryOp>(loc, ty, nm, symNm, lenParams,
772772 indices);
773773
774- if (!cuf::isCUDADeviceContext (builder.getRegion ())) {
775- mlir::Value alloc = builder.create <cuf::AllocOp>(
776- loc, ty, nm, symNm, dataAttr, lenParams, indices);
777- if (const auto *details{
778- ultimateSymbol
779- .detailsIf <Fortran::semantics::ObjectEntityDetails>()}) {
780- const Fortran::semantics::DeclTypeSpec *type{details->type ()};
781- const Fortran::semantics::DerivedTypeSpec *derived{
782- type ? type->AsDerived () : nullptr };
783- if (derived) {
784- Fortran::semantics::UltimateComponentIterator components{*derived};
785- auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
786-
787- llvm::SmallVector<mlir::Value> coordinates;
788- for (const auto &sym : components) {
789- if (Fortran::semantics::IsDeviceAllocatable (sym)) {
790- unsigned fieldIdx = recTy.getFieldIndex (sym.name ().ToString ());
791- mlir::Type fieldTy;
792- std::vector<mlir::Value> coordinates;
793-
794- if (fieldIdx != std::numeric_limits<unsigned >::max ()) {
795- // Field found in the base record type.
796- auto fieldName = recTy.getTypeList ()[fieldIdx].first ;
797- fieldTy = recTy.getTypeList ()[fieldIdx].second ;
798- mlir::Value fieldIndex = builder.create <fir::FieldIndexOp>(
799- loc, fir::FieldType::get (fieldTy.getContext ()), fieldName,
800- recTy,
801- /* typeParams=*/ mlir::ValueRange{});
802- coordinates.push_back (fieldIndex);
803- } else {
804- // Field not found in base record type, search in potential
805- // record type components.
806- for (auto component : recTy.getTypeList ()) {
807- if (auto childRecTy =
808- mlir::dyn_cast<fir::RecordType>(component.second )) {
809- fieldIdx = childRecTy.getFieldIndex (sym.name ().ToString ());
810- if (fieldIdx != std::numeric_limits<unsigned >::max ()) {
811- mlir::Value parentFieldIndex =
812- builder.create <fir::FieldIndexOp>(
813- loc, fir::FieldType::get (childRecTy.getContext ()),
814- component.first , recTy,
815- /* typeParams=*/ mlir::ValueRange{});
816- coordinates.push_back (parentFieldIndex);
817- auto fieldName = childRecTy.getTypeList ()[fieldIdx].first ;
818- fieldTy = childRecTy.getTypeList ()[fieldIdx].second ;
819- mlir::Value childFieldIndex =
820- builder.create <fir::FieldIndexOp>(
821- loc, fir::FieldType::get (fieldTy.getContext ()),
822- fieldName, childRecTy,
823- /* typeParams=*/ mlir::ValueRange{});
824- coordinates.push_back (childFieldIndex);
825- break ;
826- }
827- }
828- }
829- }
830-
831- if (coordinates.empty ())
832- TODO (loc, " device resident component in complex derived-type "
833- " hierarchy" );
834-
835- mlir::Value comp = builder.create <fir::CoordinateOp>(
836- loc, builder.getRefType (fieldTy), alloc, coordinates);
837- cuf::DataAttributeAttr dataAttr =
838- Fortran::lower::translateSymbolCUFDataAttribute (
839- builder.getContext (), sym);
840- builder.create <cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
841- }
842- }
843- }
844- }
845- return alloc;
846- }
774+ if (!cuf::isCUDADeviceContext (builder.getRegion ()))
775+ return builder.create <cuf::AllocOp>(loc, ty, nm, symNm, dataAttr,
776+ lenParams, indices);
847777 }
848778
849779 // Let the builder do all the heavy lifting.
@@ -857,6 +787,91 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
857787 return res;
858788}
859789
790+ // / Device allocatable components in a derived-type don't have the correct
791+ // / allocator index in their descriptor when they are created. After
792+ // / initialization, cuf.set_allocator_idx operations are inserted to set the
793+ // / correct allocator index for each device component.
794+ static void
795+ initializeDeviceComponentAllocator (Fortran::lower::AbstractConverter &converter,
796+ const Fortran::semantics::Symbol &symbol,
797+ Fortran::lower::SymMap &symMap) {
798+ if (const auto *details{
799+ symbol.GetUltimate ()
800+ .detailsIf <Fortran::semantics::ObjectEntityDetails>()}) {
801+ const Fortran::semantics::DeclTypeSpec *type{details->type ()};
802+ const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived ()
803+ : nullptr };
804+ if (derived) {
805+ fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
806+ mlir::Location loc = converter.getCurrentLocation ();
807+
808+ fir::ExtendedValue exv =
809+ converter.getSymbolExtendedValue (symbol.GetUltimate (), &symMap);
810+ auto recTy = mlir::dyn_cast<fir::RecordType>(
811+ fir::unwrapRefType (fir::getBase (exv).getType ()));
812+ assert (recTy && " expected fir::RecordType" );
813+
814+ llvm::SmallVector<mlir::Value> coordinates;
815+ Fortran::semantics::UltimateComponentIterator components{*derived};
816+ for (const auto &sym : components) {
817+ if (Fortran::semantics::IsDeviceAllocatable (sym)) {
818+ unsigned fieldIdx = recTy.getFieldIndex (sym.name ().ToString ());
819+ mlir::Type fieldTy;
820+ std::vector<mlir::Value> coordinates;
821+
822+ if (fieldIdx != std::numeric_limits<unsigned >::max ()) {
823+ // Field found in the base record type.
824+ auto fieldName = recTy.getTypeList ()[fieldIdx].first ;
825+ fieldTy = recTy.getTypeList ()[fieldIdx].second ;
826+ mlir::Value fieldIndex = builder.create <fir::FieldIndexOp>(
827+ loc, fir::FieldType::get (fieldTy.getContext ()), fieldName,
828+ recTy,
829+ /* typeParams=*/ mlir::ValueRange{});
830+ coordinates.push_back (fieldIndex);
831+ } else {
832+ // Field not found in base record type, search in potential
833+ // record type components.
834+ for (auto component : recTy.getTypeList ()) {
835+ if (auto childRecTy =
836+ mlir::dyn_cast<fir::RecordType>(component.second )) {
837+ fieldIdx = childRecTy.getFieldIndex (sym.name ().ToString ());
838+ if (fieldIdx != std::numeric_limits<unsigned >::max ()) {
839+ mlir::Value parentFieldIndex =
840+ builder.create <fir::FieldIndexOp>(
841+ loc, fir::FieldType::get (childRecTy.getContext ()),
842+ component.first , recTy,
843+ /* typeParams=*/ mlir::ValueRange{});
844+ coordinates.push_back (parentFieldIndex);
845+ auto fieldName = childRecTy.getTypeList ()[fieldIdx].first ;
846+ fieldTy = childRecTy.getTypeList ()[fieldIdx].second ;
847+ mlir::Value childFieldIndex =
848+ builder.create <fir::FieldIndexOp>(
849+ loc, fir::FieldType::get (fieldTy.getContext ()),
850+ fieldName, childRecTy,
851+ /* typeParams=*/ mlir::ValueRange{});
852+ coordinates.push_back (childFieldIndex);
853+ break ;
854+ }
855+ }
856+ }
857+ }
858+
859+ if (coordinates.empty ())
860+ TODO (loc, " device resident component in complex derived-type "
861+ " hierarchy" );
862+
863+ mlir::Value comp = builder.create <fir::CoordinateOp>(
864+ loc, builder.getRefType (fieldTy), fir::getBase (exv), coordinates);
865+ cuf::DataAttributeAttr dataAttr =
866+ Fortran::lower::translateSymbolCUFDataAttribute (
867+ builder.getContext (), sym);
868+ builder.create <cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
869+ }
870+ }
871+ }
872+ }
873+ }
874+
860875// / Must \p var be default initialized at runtime when entering its scope.
861876static bool
862877mustBeDefaultInitializedAtRuntime (const Fortran::lower::pft::Variable &var) {
@@ -1179,6 +1194,9 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
11791194 if (mustBeDefaultInitializedAtRuntime (var))
11801195 Fortran::lower::defaultInitializeAtRuntime (converter, var.getSymbol (),
11811196 symMap);
1197+ if (converter.getFoldingContext ().languageFeatures ().IsEnabled (
1198+ Fortran::common::LanguageFeature::CUDA))
1199+ initializeDeviceComponentAllocator (converter, var.getSymbol (), symMap);
11821200 auto *builder = &converter.getFirOpBuilder ();
11831201 if (needCUDAAlloc (var.getSymbol ()) &&
11841202 !cuf::isCUDADeviceContext (builder->getRegion ())) {
@@ -1437,6 +1455,9 @@ static void instantiateAlias(Fortran::lower::AbstractConverter &converter,
14371455 if (mustBeDefaultInitializedAtRuntime (var))
14381456 Fortran::lower::defaultInitializeAtRuntime (converter, var.getSymbol (),
14391457 symMap);
1458+ if (converter.getFoldingContext ().languageFeatures ().IsEnabled (
1459+ Fortran::common::LanguageFeature::CUDA))
1460+ initializeDeviceComponentAllocator (converter, var.getSymbol (), symMap);
14401461}
14411462
14421463// ===--------------------------------------------------------------===//
0 commit comments