|
17 | 17 |
|
18 | 18 | #define DEBUG_TYPE "flang-lower-cuda" |
19 | 19 |
|
20 | | -void Fortran::lower::initializeDeviceComponentAllocator( |
21 | | - Fortran::lower::AbstractConverter &converter, |
22 | | - const Fortran::semantics::Symbol &sym, const fir::MutableBoxValue &box) { |
23 | | - if (const auto *details{ |
24 | | - sym.GetUltimate() |
25 | | - .detailsIf<Fortran::semantics::ObjectEntityDetails>()}) { |
26 | | - const Fortran::semantics::DeclTypeSpec *type{details->type()}; |
27 | | - const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived() |
28 | | - : nullptr}; |
29 | | - if (derived) { |
30 | | - if (!FindCUDADeviceAllocatableUltimateComponent(*derived)) |
31 | | - return; // No device components. |
32 | | - |
33 | | - fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
34 | | - mlir::Location loc = converter.getCurrentLocation(); |
35 | | - |
36 | | - mlir::Type baseTy = fir::unwrapRefType(box.getAddr().getType()); |
37 | | - |
38 | | - // Only pointer and allocatable needs post allocation initialization |
39 | | - // of components descriptors. |
40 | | - if (!fir::isAllocatableType(baseTy) && !fir::isPointerType(baseTy)) |
41 | | - return; |
42 | | - |
43 | | - // Extract the derived type. |
44 | | - mlir::Type ty = fir::getDerivedType(baseTy); |
45 | | - auto recTy = mlir::dyn_cast<fir::RecordType>(ty); |
46 | | - assert(recTy && "expected fir::RecordType"); |
47 | | - |
48 | | - if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(baseTy)) |
49 | | - baseTy = boxTy.getEleTy(); |
50 | | - baseTy = fir::unwrapRefType(baseTy); |
51 | | - |
52 | | - Fortran::semantics::UltimateComponentIterator components{*derived}; |
53 | | - mlir::Value loadedBox = fir::LoadOp::create(builder, loc, box.getAddr()); |
54 | | - mlir::Value addr; |
55 | | - if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(baseTy)) { |
56 | | - mlir::Type idxTy = builder.getIndexType(); |
57 | | - mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
58 | | - mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0); |
59 | | - llvm::SmallVector<fir::DoLoopOp> loops; |
60 | | - llvm::SmallVector<mlir::Value> indices; |
61 | | - llvm::SmallVector<mlir::Value> extents; |
62 | | - for (unsigned i = 0; i < seqTy.getDimension(); ++i) { |
63 | | - mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i); |
64 | | - auto dimInfo = fir::BoxDimsOp::create(builder, loc, idxTy, idxTy, |
65 | | - idxTy, loadedBox, dim); |
66 | | - mlir::Value lbub = mlir::arith::AddIOp::create( |
67 | | - builder, loc, dimInfo.getResult(0), dimInfo.getResult(1)); |
68 | | - mlir::Value ext = |
69 | | - mlir::arith::SubIOp::create(builder, loc, lbub, one); |
70 | | - mlir::Value cmp = mlir::arith::CmpIOp::create( |
71 | | - builder, loc, mlir::arith::CmpIPredicate::sgt, ext, zero); |
72 | | - ext = mlir::arith::SelectOp::create(builder, loc, cmp, ext, zero); |
73 | | - extents.push_back(ext); |
74 | | - |
75 | | - auto loop = fir::DoLoopOp::create( |
76 | | - builder, loc, dimInfo.getResult(0), dimInfo.getResult(1), |
77 | | - dimInfo.getResult(2), /*isUnordered=*/true, |
78 | | - /*finalCount=*/false, mlir::ValueRange{}); |
79 | | - loops.push_back(loop); |
80 | | - indices.push_back(loop.getInductionVar()); |
81 | | - builder.setInsertionPointToStart(loop.getBody()); |
82 | | - } |
83 | | - mlir::Value boxAddr = fir::BoxAddrOp::create(builder, loc, loadedBox); |
84 | | - auto shape = fir::ShapeOp::create(builder, loc, extents); |
85 | | - addr = fir::ArrayCoorOp::create( |
86 | | - builder, loc, fir::ReferenceType::get(recTy), boxAddr, shape, |
87 | | - /*slice=*/mlir::Value{}, indices, /*typeparms=*/mlir::ValueRange{}); |
88 | | - } else { |
89 | | - addr = fir::BoxAddrOp::create(builder, loc, loadedBox); |
90 | | - } |
91 | | - for (const auto &compSym : components) { |
92 | | - if (Fortran::semantics::IsDeviceAllocatable(compSym)) { |
93 | | - llvm::SmallVector<mlir::Value> coord; |
94 | | - mlir::Type fieldTy = gatherDeviceComponentCoordinatesAndType( |
95 | | - builder, loc, compSym, recTy, coord); |
96 | | - assert(coord.size() == 1 && "expect one coordinate"); |
97 | | - mlir::Value comp = fir::CoordinateOp::create( |
98 | | - builder, loc, builder.getRefType(fieldTy), addr, coord[0]); |
99 | | - cuf::DataAttributeAttr dataAttr = |
100 | | - Fortran::lower::translateSymbolCUFDataAttribute( |
101 | | - builder.getContext(), compSym); |
102 | | - cuf::SetAllocatorIndexOp::create(builder, loc, comp, dataAttr); |
103 | | - } |
104 | | - } |
105 | | - } |
106 | | - } |
107 | | -} |
108 | | - |
109 | 20 | mlir::Type Fortran::lower::gatherDeviceComponentCoordinatesAndType( |
110 | 21 | fir::FirOpBuilder &builder, mlir::Location loc, |
111 | 22 | const Fortran::semantics::Symbol &sym, fir::RecordType recTy, |
|
0 commit comments