|
14 | 14 | #include "flang/Lower/AbstractConverter.h"
|
15 | 15 | #include "flang/Lower/Allocatable.h"
|
16 | 16 | #include "flang/Lower/BoxAnalyzer.h"
|
17 |
| -#include "flang/Lower/CUDA.h" |
18 | 17 | #include "flang/Lower/CallInterface.h"
|
19 | 18 | #include "flang/Lower/ConvertConstant.h"
|
20 | 19 | #include "flang/Lower/ConvertExpr.h"
|
21 | 20 | #include "flang/Lower/ConvertExprToHLFIR.h"
|
22 | 21 | #include "flang/Lower/ConvertProcedureDesignator.h"
|
| 22 | +#include "flang/Lower/Cuda.h" |
23 | 23 | #include "flang/Lower/Mangler.h"
|
24 | 24 | #include "flang/Lower/PFTBuilder.h"
|
25 | 25 | #include "flang/Lower/StatementContext.h"
|
@@ -814,24 +814,81 @@ initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter,
|
814 | 814 | baseTy = boxTy.getEleTy();
|
815 | 815 | baseTy = fir::unwrapRefType(baseTy);
|
816 | 816 |
|
817 |
| - if (fir::isAllocatableType(fir::getBase(exv).getType()) || |
818 |
| - fir::isPointerType(fir::getBase(exv).getType())) |
| 817 | + if (mlir::isa<fir::SequenceType>(baseTy) && |
| 818 | + (fir::isAllocatableType(fir::getBase(exv).getType()) || |
| 819 | + fir::isPointerType(fir::getBase(exv).getType()))) |
819 | 820 | return; // Allocator index need to be set after allocation.
|
820 | 821 |
|
821 | 822 | auto recTy =
|
822 | 823 | mlir::dyn_cast<fir::RecordType>(fir::unwrapSequenceType(baseTy));
|
823 | 824 | assert(recTy && "expected fir::RecordType");
|
824 | 825 |
|
| 826 | + llvm::SmallVector<mlir::Value> coordinates; |
825 | 827 | Fortran::semantics::UltimateComponentIterator components{*derived};
|
826 | 828 | for (const auto &sym : components) {
|
827 | 829 | if (Fortran::semantics::IsDeviceAllocatable(sym)) {
|
828 |
| - llvm::SmallVector<mlir::Value> coord; |
829 |
| - mlir::Type fieldTy = |
830 |
| - Fortran::lower::gatherDeviceComponentCoordinatesAndType( |
831 |
| - builder, loc, sym, recTy, coord); |
| 830 | + unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString()); |
| 831 | + mlir::Type fieldTy; |
| 832 | + llvm::SmallVector<mlir::Value> coordinates; |
| 833 | + |
| 834 | + if (fieldIdx != std::numeric_limits<unsigned>::max()) { |
| 835 | + // Field found in the base record type. |
| 836 | + auto fieldName = recTy.getTypeList()[fieldIdx].first; |
| 837 | + fieldTy = recTy.getTypeList()[fieldIdx].second; |
| 838 | + mlir::Value fieldIndex = fir::FieldIndexOp::create( |
| 839 | + builder, loc, fir::FieldType::get(fieldTy.getContext()), |
| 840 | + fieldName, recTy, |
| 841 | + /*typeParams=*/mlir::ValueRange{}); |
| 842 | + coordinates.push_back(fieldIndex); |
| 843 | + } else { |
| 844 | + // Field not found in base record type, search in potential |
| 845 | + // record type components. |
| 846 | + for (auto component : recTy.getTypeList()) { |
| 847 | + if (auto childRecTy = |
| 848 | + mlir::dyn_cast<fir::RecordType>(component.second)) { |
| 849 | + fieldIdx = childRecTy.getFieldIndex(sym.name().ToString()); |
| 850 | + if (fieldIdx != std::numeric_limits<unsigned>::max()) { |
| 851 | + mlir::Value parentFieldIndex = fir::FieldIndexOp::create( |
| 852 | + builder, loc, |
| 853 | + fir::FieldType::get(childRecTy.getContext()), |
| 854 | + component.first, recTy, |
| 855 | + /*typeParams=*/mlir::ValueRange{}); |
| 856 | + coordinates.push_back(parentFieldIndex); |
| 857 | + auto fieldName = childRecTy.getTypeList()[fieldIdx].first; |
| 858 | + fieldTy = childRecTy.getTypeList()[fieldIdx].second; |
| 859 | + mlir::Value childFieldIndex = fir::FieldIndexOp::create( |
| 860 | + builder, loc, fir::FieldType::get(fieldTy.getContext()), |
| 861 | + fieldName, childRecTy, |
| 862 | + /*typeParams=*/mlir::ValueRange{}); |
| 863 | + coordinates.push_back(childFieldIndex); |
| 864 | + break; |
| 865 | + } |
| 866 | + } |
| 867 | + } |
| 868 | + } |
| 869 | + |
| 870 | + if (coordinates.empty()) |
| 871 | + TODO(loc, "device resident component in complex derived-type " |
| 872 | + "hierarchy"); |
| 873 | + |
832 | 874 | mlir::Value base = fir::getBase(exv);
|
833 |
| - mlir::Value comp = fir::CoordinateOp::create( |
834 |
| - builder, loc, builder.getRefType(fieldTy), base, coord); |
| 875 | + mlir::Value comp; |
| 876 | + if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(base.getType()))) { |
| 877 | + mlir::Value box = fir::LoadOp::create(builder, loc, base); |
| 878 | + mlir::Value addr = fir::BoxAddrOp::create(builder, loc, box); |
| 879 | + llvm::SmallVector<mlir::Value> lenParams; |
| 880 | + assert(coordinates.size() == 1 && "expect one coordinate"); |
| 881 | + auto field = mlir::dyn_cast<fir::FieldIndexOp>( |
| 882 | + coordinates[0].getDefiningOp()); |
| 883 | + comp = hlfir::DesignateOp::create( |
| 884 | + builder, loc, builder.getRefType(fieldTy), addr, |
| 885 | + /*component=*/field.getFieldName(), |
| 886 | + /*componentShape=*/mlir::Value{}, |
| 887 | + hlfir::DesignateOp::Subscripts{}); |
| 888 | + } else { |
| 889 | + comp = fir::CoordinateOp::create( |
| 890 | + builder, loc, builder.getRefType(fieldTy), base, coordinates); |
| 891 | + } |
835 | 892 | cuf::DataAttributeAttr dataAttr =
|
836 | 893 | Fortran::lower::translateSymbolCUFDataAttribute(
|
837 | 894 | builder.getContext(), sym);
|
|
0 commit comments