|
| 1 | +//===-- CUDA.cpp -- CUDA Fortran specific lowering ------------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
| 10 | +// |
| 11 | +//===----------------------------------------------------------------------===// |
| 12 | + |
| 13 | +#include "flang/Lower/CUDA.h" |
| 14 | + |
| 15 | +#define DEBUG_TYPE "flang-lower-cuda" |
| 16 | + |
| 17 | +void Fortran::lower::initializeDeviceComponentAllocator( |
| 18 | + Fortran::lower::AbstractConverter &converter, |
| 19 | + const Fortran::semantics::Symbol &sym, const fir::MutableBoxValue &box) { |
| 20 | + if (const auto *details{ |
| 21 | + sym.GetUltimate() |
| 22 | + .detailsIf<Fortran::semantics::ObjectEntityDetails>()}) { |
| 23 | + const Fortran::semantics::DeclTypeSpec *type{details->type()}; |
| 24 | + const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived() |
| 25 | + : nullptr}; |
| 26 | + if (derived) { |
| 27 | + if (!FindCUDADeviceAllocatableUltimateComponent(*derived)) |
| 28 | + return; // No device components. |
| 29 | + |
| 30 | + fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
| 31 | + mlir::Location loc = converter.getCurrentLocation(); |
| 32 | + |
| 33 | + mlir::Type baseTy = fir::unwrapRefType(box.getAddr().getType()); |
| 34 | + |
| 35 | + // Only pointer and allocatable needs post allocation initialization |
| 36 | + // of components descriptors. |
| 37 | + if (!fir::isAllocatableType(baseTy) && !fir::isPointerType(baseTy)) |
| 38 | + return; |
| 39 | + |
| 40 | + // Extract the derived type. |
| 41 | + mlir::Type ty = fir::getDerivedType(baseTy); |
| 42 | + auto recTy = mlir::dyn_cast<fir::RecordType>(ty); |
| 43 | + assert(recTy && "expected fir::RecordType"); |
| 44 | + |
| 45 | + if (auto boxTy = mlir::dyn_cast<fir::BaseBoxType>(baseTy)) |
| 46 | + baseTy = boxTy.getEleTy(); |
| 47 | + baseTy = fir::unwrapRefType(baseTy); |
| 48 | + |
| 49 | + Fortran::semantics::UltimateComponentIterator components{*derived}; |
| 50 | + mlir::Value loadedBox = fir::LoadOp::create(builder, loc, box.getAddr()); |
| 51 | + mlir::Value addr; |
| 52 | + if (auto seqTy = mlir::dyn_cast<fir::SequenceType>(baseTy)) { |
| 53 | + mlir::Type idxTy = builder.getIndexType(); |
| 54 | + mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
| 55 | + mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0); |
| 56 | + llvm::SmallVector<fir::DoLoopOp> loops; |
| 57 | + llvm::SmallVector<mlir::Value> indices; |
| 58 | + llvm::SmallVector<mlir::Value> extents; |
| 59 | + for (unsigned i = 0; i < seqTy.getDimension(); ++i) { |
| 60 | + mlir::Value dim = builder.createIntegerConstant(loc, idxTy, i); |
| 61 | + auto dimInfo = fir::BoxDimsOp::create(builder, loc, idxTy, idxTy, |
| 62 | + idxTy, loadedBox, dim); |
| 63 | + mlir::Value lbub = mlir::arith::AddIOp::create( |
| 64 | + builder, loc, dimInfo.getResult(0), dimInfo.getResult(1)); |
| 65 | + mlir::Value ext = |
| 66 | + mlir::arith::SubIOp::create(builder, loc, lbub, one); |
| 67 | + mlir::Value cmp = mlir::arith::CmpIOp::create( |
| 68 | + builder, loc, mlir::arith::CmpIPredicate::sgt, ext, zero); |
| 69 | + ext = mlir::arith::SelectOp::create(builder, loc, cmp, ext, zero); |
| 70 | + extents.push_back(ext); |
| 71 | + |
| 72 | + auto loop = fir::DoLoopOp::create( |
| 73 | + builder, loc, dimInfo.getResult(0), dimInfo.getResult(1), |
| 74 | + dimInfo.getResult(2), /*isUnordered=*/true, |
| 75 | + /*finalCount=*/false, mlir::ValueRange{}); |
| 76 | + loops.push_back(loop); |
| 77 | + indices.push_back(loop.getInductionVar()); |
| 78 | + builder.setInsertionPointToStart(loop.getBody()); |
| 79 | + } |
| 80 | + mlir::Value boxAddr = fir::BoxAddrOp::create(builder, loc, loadedBox); |
| 81 | + auto shape = fir::ShapeOp::create(builder, loc, extents); |
| 82 | + addr = fir::ArrayCoorOp::create( |
| 83 | + builder, loc, fir::ReferenceType::get(recTy), boxAddr, shape, |
| 84 | + /*slice=*/mlir::Value{}, indices, /*typeparms=*/mlir::ValueRange{}); |
| 85 | + } else { |
| 86 | + addr = fir::BoxAddrOp::create(builder, loc, loadedBox); |
| 87 | + } |
| 88 | + for (const auto &compSym : components) { |
| 89 | + if (Fortran::semantics::IsDeviceAllocatable(compSym)) { |
| 90 | + llvm::SmallVector<mlir::Value> coord; |
| 91 | + mlir::Type fieldTy = gatherDeviceComponentCoordinatesAndType( |
| 92 | + builder, loc, compSym, recTy, coord); |
| 93 | + assert(coord.size() == 1 && "expect one coordinate"); |
| 94 | + mlir::Value comp = fir::CoordinateOp::create( |
| 95 | + builder, loc, builder.getRefType(fieldTy), addr, coord[0]); |
| 96 | + cuf::DataAttributeAttr dataAttr = |
| 97 | + Fortran::lower::translateSymbolCUFDataAttribute( |
| 98 | + builder.getContext(), compSym); |
| 99 | + cuf::SetAllocatorIndexOp::create(builder, loc, comp, dataAttr); |
| 100 | + } |
| 101 | + } |
| 102 | + } |
| 103 | + } |
| 104 | +} |
| 105 | + |
| 106 | +mlir::Type Fortran::lower::gatherDeviceComponentCoordinatesAndType( |
| 107 | + fir::FirOpBuilder &builder, mlir::Location loc, |
| 108 | + const Fortran::semantics::Symbol &sym, fir::RecordType recTy, |
| 109 | + llvm::SmallVector<mlir::Value> &coordinates) { |
| 110 | + unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString()); |
| 111 | + mlir::Type fieldTy; |
| 112 | + if (fieldIdx != std::numeric_limits<unsigned>::max()) { |
| 113 | + // Field found in the base record type. |
| 114 | + auto fieldName = recTy.getTypeList()[fieldIdx].first; |
| 115 | + fieldTy = recTy.getTypeList()[fieldIdx].second; |
| 116 | + mlir::Value fieldIndex = fir::FieldIndexOp::create( |
| 117 | + builder, loc, fir::FieldType::get(fieldTy.getContext()), fieldName, |
| 118 | + recTy, |
| 119 | + /*typeParams=*/mlir::ValueRange{}); |
| 120 | + coordinates.push_back(fieldIndex); |
| 121 | + } else { |
| 122 | + // Field not found in base record type, search in potential |
| 123 | + // record type components. |
| 124 | + for (auto component : recTy.getTypeList()) { |
| 125 | + if (auto childRecTy = mlir::dyn_cast<fir::RecordType>(component.second)) { |
| 126 | + fieldIdx = childRecTy.getFieldIndex(sym.name().ToString()); |
| 127 | + if (fieldIdx != std::numeric_limits<unsigned>::max()) { |
| 128 | + mlir::Value parentFieldIndex = fir::FieldIndexOp::create( |
| 129 | + builder, loc, fir::FieldType::get(childRecTy.getContext()), |
| 130 | + component.first, recTy, |
| 131 | + /*typeParams=*/mlir::ValueRange{}); |
| 132 | + coordinates.push_back(parentFieldIndex); |
| 133 | + auto fieldName = childRecTy.getTypeList()[fieldIdx].first; |
| 134 | + fieldTy = childRecTy.getTypeList()[fieldIdx].second; |
| 135 | + mlir::Value childFieldIndex = fir::FieldIndexOp::create( |
| 136 | + builder, loc, fir::FieldType::get(fieldTy.getContext()), |
| 137 | + fieldName, childRecTy, |
| 138 | + /*typeParams=*/mlir::ValueRange{}); |
| 139 | + coordinates.push_back(childFieldIndex); |
| 140 | + break; |
| 141 | + } |
| 142 | + } |
| 143 | + } |
| 144 | + } |
| 145 | + if (coordinates.empty()) |
| 146 | + TODO(loc, "device resident component in complex derived-type hierarchy"); |
| 147 | + return fieldTy; |
| 148 | +} |
0 commit comments