@@ -771,79 +771,9 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
771
771
return builder.create <cuf::SharedMemoryOp>(loc, ty, nm, symNm, lenParams,
772
772
indices);
773
773
774
- if (!cuf::isCUDADeviceContext (builder.getRegion ())) {
775
- mlir::Value alloc = builder.create <cuf::AllocOp>(
776
- loc, ty, nm, symNm, dataAttr, lenParams, indices);
777
- if (const auto *details{
778
- ultimateSymbol
779
- .detailsIf <Fortran::semantics::ObjectEntityDetails>()}) {
780
- const Fortran::semantics::DeclTypeSpec *type{details->type ()};
781
- const Fortran::semantics::DerivedTypeSpec *derived{
782
- type ? type->AsDerived () : nullptr };
783
- if (derived) {
784
- Fortran::semantics::UltimateComponentIterator components{*derived};
785
- auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
786
-
787
- llvm::SmallVector<mlir::Value> coordinates;
788
- for (const auto &sym : components) {
789
- if (Fortran::semantics::IsDeviceAllocatable (sym)) {
790
- unsigned fieldIdx = recTy.getFieldIndex (sym.name ().ToString ());
791
- mlir::Type fieldTy;
792
- std::vector<mlir::Value> coordinates;
793
-
794
- if (fieldIdx != std::numeric_limits<unsigned >::max ()) {
795
- // Field found in the base record type.
796
- auto fieldName = recTy.getTypeList ()[fieldIdx].first ;
797
- fieldTy = recTy.getTypeList ()[fieldIdx].second ;
798
- mlir::Value fieldIndex = builder.create <fir::FieldIndexOp>(
799
- loc, fir::FieldType::get (fieldTy.getContext ()), fieldName,
800
- recTy,
801
- /* typeParams=*/ mlir::ValueRange{});
802
- coordinates.push_back (fieldIndex);
803
- } else {
804
- // Field not found in base record type, search in potential
805
- // record type components.
806
- for (auto component : recTy.getTypeList ()) {
807
- if (auto childRecTy =
808
- mlir::dyn_cast<fir::RecordType>(component.second )) {
809
- fieldIdx = childRecTy.getFieldIndex (sym.name ().ToString ());
810
- if (fieldIdx != std::numeric_limits<unsigned >::max ()) {
811
- mlir::Value parentFieldIndex =
812
- builder.create <fir::FieldIndexOp>(
813
- loc, fir::FieldType::get (childRecTy.getContext ()),
814
- component.first , recTy,
815
- /* typeParams=*/ mlir::ValueRange{});
816
- coordinates.push_back (parentFieldIndex);
817
- auto fieldName = childRecTy.getTypeList ()[fieldIdx].first ;
818
- fieldTy = childRecTy.getTypeList ()[fieldIdx].second ;
819
- mlir::Value childFieldIndex =
820
- builder.create <fir::FieldIndexOp>(
821
- loc, fir::FieldType::get (fieldTy.getContext ()),
822
- fieldName, childRecTy,
823
- /* typeParams=*/ mlir::ValueRange{});
824
- coordinates.push_back (childFieldIndex);
825
- break ;
826
- }
827
- }
828
- }
829
- }
830
-
831
- if (coordinates.empty ())
832
- TODO (loc, " device resident component in complex derived-type "
833
- " hierarchy" );
834
-
835
- mlir::Value comp = builder.create <fir::CoordinateOp>(
836
- loc, builder.getRefType (fieldTy), alloc, coordinates);
837
- cuf::DataAttributeAttr dataAttr =
838
- Fortran::lower::translateSymbolCUFDataAttribute (
839
- builder.getContext (), sym);
840
- builder.create <cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
841
- }
842
- }
843
- }
844
- }
845
- return alloc;
846
- }
774
+ if (!cuf::isCUDADeviceContext (builder.getRegion ()))
775
+ return builder.create <cuf::AllocOp>(loc, ty, nm, symNm, dataAttr,
776
+ lenParams, indices);
847
777
}
848
778
849
779
// Let the builder do all the heavy lifting.
@@ -857,6 +787,91 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
857
787
return res;
858
788
}
859
789
790
+ // / Device allocatable components in a derived-type don't have the correct
791
+ // / allocator index in their descriptor when they are created. After
792
+ // / initialization, cuf.set_allocator_idx operations are inserted to set the
793
+ // / correct allocator index for each device component.
794
+ static void
795
+ initializeDeviceComponentAllocator (Fortran::lower::AbstractConverter &converter,
796
+ const Fortran::semantics::Symbol &symbol,
797
+ Fortran::lower::SymMap &symMap) {
798
+ if (const auto *details{
799
+ symbol.GetUltimate ()
800
+ .detailsIf <Fortran::semantics::ObjectEntityDetails>()}) {
801
+ const Fortran::semantics::DeclTypeSpec *type{details->type ()};
802
+ const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived ()
803
+ : nullptr };
804
+ if (derived) {
805
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
806
+ mlir::Location loc = converter.getCurrentLocation ();
807
+
808
+ fir::ExtendedValue exv =
809
+ converter.getSymbolExtendedValue (symbol.GetUltimate (), &symMap);
810
+ auto recTy = mlir::dyn_cast<fir::RecordType>(
811
+ fir::unwrapRefType (fir::getBase (exv).getType ()));
812
+ assert (recTy && " expected fir::RecordType" );
813
+
814
+ llvm::SmallVector<mlir::Value> coordinates;
815
+ Fortran::semantics::UltimateComponentIterator components{*derived};
816
+ for (const auto &sym : components) {
817
+ if (Fortran::semantics::IsDeviceAllocatable (sym)) {
818
+ unsigned fieldIdx = recTy.getFieldIndex (sym.name ().ToString ());
819
+ mlir::Type fieldTy;
820
+ std::vector<mlir::Value> coordinates;
821
+
822
+ if (fieldIdx != std::numeric_limits<unsigned >::max ()) {
823
+ // Field found in the base record type.
824
+ auto fieldName = recTy.getTypeList ()[fieldIdx].first ;
825
+ fieldTy = recTy.getTypeList ()[fieldIdx].second ;
826
+ mlir::Value fieldIndex = builder.create <fir::FieldIndexOp>(
827
+ loc, fir::FieldType::get (fieldTy.getContext ()), fieldName,
828
+ recTy,
829
+ /* typeParams=*/ mlir::ValueRange{});
830
+ coordinates.push_back (fieldIndex);
831
+ } else {
832
+ // Field not found in base record type, search in potential
833
+ // record type components.
834
+ for (auto component : recTy.getTypeList ()) {
835
+ if (auto childRecTy =
836
+ mlir::dyn_cast<fir::RecordType>(component.second )) {
837
+ fieldIdx = childRecTy.getFieldIndex (sym.name ().ToString ());
838
+ if (fieldIdx != std::numeric_limits<unsigned >::max ()) {
839
+ mlir::Value parentFieldIndex =
840
+ builder.create <fir::FieldIndexOp>(
841
+ loc, fir::FieldType::get (childRecTy.getContext ()),
842
+ component.first , recTy,
843
+ /* typeParams=*/ mlir::ValueRange{});
844
+ coordinates.push_back (parentFieldIndex);
845
+ auto fieldName = childRecTy.getTypeList ()[fieldIdx].first ;
846
+ fieldTy = childRecTy.getTypeList ()[fieldIdx].second ;
847
+ mlir::Value childFieldIndex =
848
+ builder.create <fir::FieldIndexOp>(
849
+ loc, fir::FieldType::get (fieldTy.getContext ()),
850
+ fieldName, childRecTy,
851
+ /* typeParams=*/ mlir::ValueRange{});
852
+ coordinates.push_back (childFieldIndex);
853
+ break ;
854
+ }
855
+ }
856
+ }
857
+ }
858
+
859
+ if (coordinates.empty ())
860
+ TODO (loc, " device resident component in complex derived-type "
861
+ " hierarchy" );
862
+
863
+ mlir::Value comp = builder.create <fir::CoordinateOp>(
864
+ loc, builder.getRefType (fieldTy), fir::getBase (exv), coordinates);
865
+ cuf::DataAttributeAttr dataAttr =
866
+ Fortran::lower::translateSymbolCUFDataAttribute (
867
+ builder.getContext (), sym);
868
+ builder.create <cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
869
+ }
870
+ }
871
+ }
872
+ }
873
+ }
874
+
860
875
// / Must \p var be default initialized at runtime when entering its scope.
861
876
static bool
862
877
mustBeDefaultInitializedAtRuntime (const Fortran::lower::pft::Variable &var) {
@@ -1179,6 +1194,9 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
1179
1194
if (mustBeDefaultInitializedAtRuntime (var))
1180
1195
Fortran::lower::defaultInitializeAtRuntime (converter, var.getSymbol (),
1181
1196
symMap);
1197
+ if (converter.getFoldingContext ().languageFeatures ().IsEnabled (
1198
+ Fortran::common::LanguageFeature::CUDA))
1199
+ initializeDeviceComponentAllocator (converter, var.getSymbol (), symMap);
1182
1200
auto *builder = &converter.getFirOpBuilder ();
1183
1201
if (needCUDAAlloc (var.getSymbol ()) &&
1184
1202
!cuf::isCUDADeviceContext (builder->getRegion ())) {
@@ -1437,6 +1455,9 @@ static void instantiateAlias(Fortran::lower::AbstractConverter &converter,
1437
1455
if (mustBeDefaultInitializedAtRuntime (var))
1438
1456
Fortran::lower::defaultInitializeAtRuntime (converter, var.getSymbol (),
1439
1457
symMap);
1458
+ if (converter.getFoldingContext ().languageFeatures ().IsEnabled (
1459
+ Fortran::common::LanguageFeature::CUDA))
1460
+ initializeDeviceComponentAllocator (converter, var.getSymbol (), symMap);
1440
1461
}
1441
1462
1442
1463
// ===--------------------------------------------------------------===//
0 commit comments