|
14 | 14 | #include "flang/Lower/AbstractConverter.h" |
15 | 15 | #include "flang/Lower/Allocatable.h" |
16 | 16 | #include "flang/Lower/BoxAnalyzer.h" |
17 | | -#include "flang/Lower/CUDA.h" |
18 | 17 | #include "flang/Lower/CallInterface.h" |
19 | 18 | #include "flang/Lower/ConvertConstant.h" |
20 | 19 | #include "flang/Lower/ConvertExpr.h" |
21 | 20 | #include "flang/Lower/ConvertExprToHLFIR.h" |
22 | 21 | #include "flang/Lower/ConvertProcedureDesignator.h" |
| 22 | +#include "flang/Lower/Cuda.h" |
23 | 23 | #include "flang/Lower/Mangler.h" |
24 | 24 | #include "flang/Lower/PFTBuilder.h" |
25 | 25 | #include "flang/Lower/StatementContext.h" |
@@ -814,24 +814,81 @@ initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter, |
814 | 814 | baseTy = boxTy.getEleTy(); |
815 | 815 | baseTy = fir::unwrapRefType(baseTy); |
816 | 816 |
|
817 | | - if (fir::isAllocatableType(fir::getBase(exv).getType()) || |
818 | | - fir::isPointerType(fir::getBase(exv).getType())) |
| 817 | + if (mlir::isa<fir::SequenceType>(baseTy) && |
| 818 | + (fir::isAllocatableType(fir::getBase(exv).getType()) || |
| 819 | + fir::isPointerType(fir::getBase(exv).getType()))) |
819 | 820 | return; // Allocator index need to be set after allocation. |
820 | 821 |
|
821 | 822 | auto recTy = |
822 | 823 | mlir::dyn_cast<fir::RecordType>(fir::unwrapSequenceType(baseTy)); |
823 | 824 | assert(recTy && "expected fir::RecordType"); |
824 | 825 |
|
| 826 | + llvm::SmallVector<mlir::Value> coordinates; |
825 | 827 | Fortran::semantics::UltimateComponentIterator components{*derived}; |
826 | 828 | for (const auto &sym : components) { |
827 | 829 | if (Fortran::semantics::IsDeviceAllocatable(sym)) { |
828 | | - llvm::SmallVector<mlir::Value> coord; |
829 | | - mlir::Type fieldTy = |
830 | | - Fortran::lower::gatherDeviceComponentCoordinatesAndType( |
831 | | - builder, loc, sym, recTy, coord); |
| 830 | + unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString()); |
| 831 | + mlir::Type fieldTy; |
| 832 | + llvm::SmallVector<mlir::Value> coordinates; |
| 833 | + |
| 834 | + if (fieldIdx != std::numeric_limits<unsigned>::max()) { |
| 835 | + // Field found in the base record type. |
| 836 | + auto fieldName = recTy.getTypeList()[fieldIdx].first; |
| 837 | + fieldTy = recTy.getTypeList()[fieldIdx].second; |
| 838 | + mlir::Value fieldIndex = fir::FieldIndexOp::create( |
| 839 | + builder, loc, fir::FieldType::get(fieldTy.getContext()), |
| 840 | + fieldName, recTy, |
| 841 | + /*typeParams=*/mlir::ValueRange{}); |
| 842 | + coordinates.push_back(fieldIndex); |
| 843 | + } else { |
| 844 | + // Field not found in base record type, search in potential |
| 845 | + // record type components. |
| 846 | + for (auto component : recTy.getTypeList()) { |
| 847 | + if (auto childRecTy = |
| 848 | + mlir::dyn_cast<fir::RecordType>(component.second)) { |
| 849 | + fieldIdx = childRecTy.getFieldIndex(sym.name().ToString()); |
| 850 | + if (fieldIdx != std::numeric_limits<unsigned>::max()) { |
| 851 | + mlir::Value parentFieldIndex = fir::FieldIndexOp::create( |
| 852 | + builder, loc, |
| 853 | + fir::FieldType::get(childRecTy.getContext()), |
| 854 | + component.first, recTy, |
| 855 | + /*typeParams=*/mlir::ValueRange{}); |
| 856 | + coordinates.push_back(parentFieldIndex); |
| 857 | + auto fieldName = childRecTy.getTypeList()[fieldIdx].first; |
| 858 | + fieldTy = childRecTy.getTypeList()[fieldIdx].second; |
| 859 | + mlir::Value childFieldIndex = fir::FieldIndexOp::create( |
| 860 | + builder, loc, fir::FieldType::get(fieldTy.getContext()), |
| 861 | + fieldName, childRecTy, |
| 862 | + /*typeParams=*/mlir::ValueRange{}); |
| 863 | + coordinates.push_back(childFieldIndex); |
| 864 | + break; |
| 865 | + } |
| 866 | + } |
| 867 | + } |
| 868 | + } |
| 869 | + |
| 870 | + if (coordinates.empty()) |
| 871 | + TODO(loc, "device resident component in complex derived-type " |
| 872 | + "hierarchy"); |
| 873 | + |
832 | 874 | mlir::Value base = fir::getBase(exv); |
833 | | - mlir::Value comp = fir::CoordinateOp::create( |
834 | | - builder, loc, builder.getRefType(fieldTy), base, coord); |
| 875 | + mlir::Value comp; |
| 876 | + if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(base.getType()))) { |
| 877 | + mlir::Value box = fir::LoadOp::create(builder, loc, base); |
| 878 | + mlir::Value addr = fir::BoxAddrOp::create(builder, loc, box); |
| 879 | + llvm::SmallVector<mlir::Value> lenParams; |
| 880 | + assert(coordinates.size() == 1 && "expect one coordinate"); |
| 881 | + auto field = mlir::dyn_cast<fir::FieldIndexOp>( |
| 882 | + coordinates[0].getDefiningOp()); |
| 883 | + comp = hlfir::DesignateOp::create( |
| 884 | + builder, loc, builder.getRefType(fieldTy), addr, |
| 885 | + /*component=*/field.getFieldName(), |
| 886 | + /*componentShape=*/mlir::Value{}, |
| 887 | + hlfir::DesignateOp::Subscripts{}); |
| 888 | + } else { |
| 889 | + comp = fir::CoordinateOp::create( |
| 890 | + builder, loc, builder.getRefType(fieldTy), base, coordinates); |
| 891 | + } |
835 | 892 | cuf::DataAttributeAttr dataAttr = |
836 | 893 | Fortran::lower::translateSymbolCUFDataAttribute( |
837 | 894 | builder.getContext(), sym); |
|
0 commit comments