-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[flang] Inline hlfir.reshape as hlfir.elemental. #124683
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -951,6 +951,218 @@ class DotProductConversion | |
| } | ||
| }; | ||
|
|
||
| class ReshapeAsElementalConversion | ||
| : public mlir::OpRewritePattern<hlfir::ReshapeOp> { | ||
| public: | ||
| using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern; | ||
|
|
||
/// Rewrite hlfir.reshape into an hlfir.elemental whose kernel computes a
/// zero-based linear index for each result element and loads the element
/// from ARRAY, or from PAD once the linear index reaches the size of ARRAY.
///
/// The match fails (leaving the operation untouched) when:
///   - ORDER is present;
///   - the element types of ARRAY, PAD and the result are not all equal;
///   - PAD is present and the element type is not trivial.
| llvm::LogicalResult | ||
| matchAndRewrite(hlfir::ReshapeOp reshape, | ||
| mlir::PatternRewriter &rewriter) const override { | ||
| // Do not inline RESHAPE with ORDER yet. The runtime implementation | ||
| // may be good enough, unless the temporary creation overhead | ||
| // is high. | ||
| // TODO: If ORDER is constant, then we can still easily inline. | ||
| // TODO: If the result's rank is 1, then we can assume ORDER == (/1/). | ||
| if (reshape.getOrder()) | ||
| return rewriter.notifyMatchFailure(reshape, | ||
| "RESHAPE with ORDER argument"); | ||
|
|
||
| // Verify that the element types of ARRAY, PAD and the result | ||
| // match before doing any transformations. For example, | ||
| // the character types of different lengths may appear in the dead | ||
| // code, and it just does not make sense to inline hlfir.reshape | ||
| // in this case (a runtime call might have less code size footprint). | ||
| hlfir::Entity result = hlfir::Entity{reshape}; | ||
| hlfir::Entity array = hlfir::Entity{reshape.getArray()}; | ||
| mlir::Type elementType = array.getFortranElementType(); | ||
| if (result.getFortranElementType() != elementType) | ||
| return rewriter.notifyMatchFailure( | ||
| reshape, "ARRAY and result have different types"); | ||
| mlir::Value pad = reshape.getPad(); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. PAD is dynamically optional. If its actual argument is an OPTIONAL/POINTER/ALLOCATABLE, its presence should be checked at runtime. You probably need to do something about that here (or at least to detect and do not do the transformation for now).
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks! I moved the reads from PAD under the check of whether we have to read from it or not. |
||
// Only the static type of PAD is inspected here; nothing is read from PAD
// at this point.
| if (pad && hlfir::getFortranElementType(pad.getType()) != elementType) | ||
| return rewriter.notifyMatchFailure(reshape, | ||
| "ARRAY and PAD have different types"); | ||
tblah marked this conversation as resolved.
Show resolved
Hide resolved
jeanPerier marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| // TODO: selecting between ARRAY and PAD of non-trivial element types | ||
| // requires more work. We have to select between two references | ||
| // to elements in ARRAY and PAD. This requires conditional | ||
| // bufferization of the element, if ARRAY/PAD is an expression. | ||
| if (pad && !fir::isa_trivial(elementType)) | ||
| return rewriter.notifyMatchFailure(reshape, | ||
| "PAD present with non-trivial type"); | ||
|
|
||
| mlir::Location loc = reshape.getLoc(); | ||
| fir::FirOpBuilder builder{rewriter, reshape.getOperation()}; | ||
| // Assume that all the indices arithmetic does not overflow | ||
| // the IndexType. | ||
| builder.setIntegerOverflowFlags(mlir::arith::IntegerOverflowFlags::nuw); | ||
tblah marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
// The length parameters are taken from ARRAY; its element type was checked
// above to be equal to the element type of the result.
| llvm::SmallVector<mlir::Value, 1> typeParams; | ||
| hlfir::genLengthParameters(loc, builder, array, typeParams); | ||
|
|
||
| // Fetch the extents of ARRAY, PAD and result beforehand. | ||
| llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents = | ||
| hlfir::genExtentsVector(loc, builder, array); | ||
|
|
||
| // If PAD is present, we have to use array size to start taking | ||
| // elements from the PAD array. | ||
// arraySize is left null when there is no PAD; it is only used inside
// the 'if (pad)' part of the kernel below.
| mlir::Value arraySize = | ||
| pad ? computeArraySize(loc, builder, arrayExtents) : nullptr; | ||
| hlfir::Entity shape = hlfir::Entity{reshape.getShape()}; | ||
| llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents; | ||
| mlir::Type indexType = builder.getIndexType(); | ||
// The result extents are the elements of SHAPE, loaded one by one
// at indices 1..rank of the result.
| for (int idx = 0; idx < result.getRank(); ++idx) | ||
| resultExtents.push_back(hlfir::loadElementAt( | ||
| loc, builder, shape, | ||
| builder.createIntegerConstant(loc, indexType, idx + 1))); | ||
| auto resultShape = builder.create<fir::ShapeOp>(loc, resultExtents); | ||
|
|
||
// Kernel of the elemental operation: given the indices of a result
// element, produce the value of that element.
| auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder, | ||
| mlir::ValueRange inputIndices) -> hlfir::Entity { | ||
| mlir::Value linearIndex = | ||
| computeLinearIndex(loc, builder, resultExtents, inputIndices); | ||
// ifOp stays null when there is no PAD; it is tested after the ARRAY
// load to decide whether a fir.result terminator is needed.
| fir::IfOp ifOp; | ||
| if (pad) { | ||
| // PAD is present. Check if this element comes from the PAD array. | ||
| mlir::Value isInsideArray = builder.create<mlir::arith::CmpIOp>( | ||
| loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize); | ||
| ifOp = builder.create<fir::IfOp>(loc, elementType, isInsideArray, | ||
| /*withElseRegion=*/true); | ||
|
|
||
| // In the 'else' block, return an element from the PAD. | ||
| builder.setInsertionPointToStart(&ifOp.getElseRegion().front()); | ||
| // PAD is dynamically optional, but we can unconditionally access it | ||
| // in the 'else' block. If we have to start taking elements from it, | ||
| // then it must be present in a valid program. | ||
| llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents = | ||
| hlfir::genExtentsVector(loc, builder, hlfir::Entity{pad}); | ||
| // Subtract the ARRAY size from the zero-based linear index | ||
| // to get the zero-based linear index into PAD. | ||
| mlir::Value padLinearIndex = | ||
| builder.create<mlir::arith::SubIOp>(loc, linearIndex, arraySize); | ||
// wrapAround is requested for PAD only (the ARRAY access below
// passes wrapAround=false).
| llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices = | ||
| delinearizeIndex(loc, builder, padExtents, padLinearIndex, | ||
| /*wrapAround=*/true); | ||
| mlir::Value padElement = | ||
| hlfir::loadElementAt(loc, builder, hlfir::Entity{pad}, padIndices); | ||
| builder.create<fir::ResultOp>(loc, padElement); | ||
|
|
||
| // In the 'then' block, return an element from the ARRAY. | ||
// The insertion point is left inside the 'then' block, so the ARRAY
// load generated after this 'if (pad)' statement is emitted there.
| builder.setInsertionPointToStart(&ifOp.getThenRegion().front()); | ||
| } | ||
|
|
||
| llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices = | ||
| delinearizeIndex(loc, builder, arrayExtents, linearIndex, | ||
| /*wrapAround=*/false); | ||
| mlir::Value arrayElement = | ||
| hlfir::loadElementAt(loc, builder, array, arrayIndices); | ||
|
|
||
// With PAD, terminate the 'then' block with the ARRAY element, move
// the insertion point after the fir.if and use the fir.if result as
// the value of the element.
| if (ifOp) { | ||
| builder.create<fir::ResultOp>(loc, arrayElement); | ||
| builder.setInsertionPointAfter(ifOp); | ||
| arrayElement = ifOp.getResult(0); | ||
| } | ||
|
|
||
| return hlfir::Entity{arrayElement}; | ||
| }; | ||
// ARRAY is passed as the polymorphic mold only when the result is
// polymorphic; otherwise a null value is passed.
| hlfir::ElementalOp elementalOp = hlfir::genElementalOp( | ||
| loc, builder, elementType, resultShape, typeParams, genKernel, | ||
| /*isUnordered=*/true, | ||
| /*polymorphicMold=*/result.isPolymorphic() ? array : mlir::Value{}, | ||
| reshape.getResult().getType()); | ||
| assert(elementalOp.getResult().getType() == reshape.getResult().getType()); | ||
| rewriter.replaceOp(reshape, elementalOp); | ||
| return mlir::success(); | ||
| } | ||
|
|
||
| private: | ||
| /// Compute zero-based linear index given an array extents | ||
| /// and one-based indices: | ||
| /// \p extents: [e0, e1, ..., en] | ||
| /// \p indices: [i0, i1, ..., in] | ||
| /// | ||
| /// linear-index := | ||
| /// (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1) | ||
| static mlir::Value computeLinearIndex(mlir::Location loc, | ||
| fir::FirOpBuilder &builder, | ||
| mlir::ValueRange extents, | ||
| mlir::ValueRange indices) { | ||
| std::size_t rank = extents.size(); | ||
| assert(rank = indices.size()); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: Should be ==? GCC was giving a warning.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you!
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fixed by 381416a |
||
| mlir::Type indexType = builder.getIndexType(); | ||
| mlir::Value zero = builder.createIntegerConstant(loc, indexType, 0); | ||
| mlir::Value one = builder.createIntegerConstant(loc, indexType, 1); | ||
| mlir::Value linearIndex = zero; | ||
| for (auto idx : llvm::enumerate(llvm::reverse(indices))) { | ||
| mlir::Value tmp = builder.create<mlir::arith::SubIOp>( | ||
| loc, builder.createConvert(loc, indexType, idx.value()), one); | ||
| tmp = builder.create<mlir::arith::AddIOp>(loc, linearIndex, tmp); | ||
| if (idx.index() + 1 < rank) | ||
| tmp = builder.create<mlir::arith::MulIOp>( | ||
| loc, tmp, | ||
| builder.createConvert(loc, indexType, | ||
| extents[rank - idx.index() - 2])); | ||
|
|
||
| linearIndex = tmp; | ||
| } | ||
| return linearIndex; | ||
| } | ||
|
|
||
| /// Compute one-based array indices from the given zero-based \p linearIndex | ||
| /// and the array \p extents [e0, e1, ..., en]. | ||
| /// i0 := linearIndex % e0 + 1 | ||
| /// linearIndex := linearIndex / e0 | ||
| /// i1 := linearIndex % e1 + 1 | ||
| /// linearIndex := linearIndex / e1 | ||
| /// ... | ||
| /// i(n-1) := linearIndex % e(n-1) + 1 | ||
| /// linearIndex := linearIndex / e(n-1) | ||
| /// if (wrapAround) { | ||
| /// // If the index is allowed to wrap around, then | ||
| /// // we need to modulo it by the last dimension's extent. | ||
| /// in := linearIndex % en + 1 | ||
| /// } else { | ||
| /// in := linearIndex + 1 | ||
| /// } | ||
| static llvm::SmallVector<mlir::Value, Fortran::common::maxRank> | ||
| delinearizeIndex(mlir::Location loc, fir::FirOpBuilder &builder, | ||
| mlir::ValueRange extents, mlir::Value linearIndex, | ||
| bool wrapAround) { | ||
| llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices; | ||
| mlir::Type indexType = builder.getIndexType(); | ||
| mlir::Value one = builder.createIntegerConstant(loc, indexType, 1); | ||
| linearIndex = builder.createConvert(loc, indexType, linearIndex); | ||
|
|
||
| for (std::size_t dim = 0; dim < extents.size(); ++dim) { | ||
| mlir::Value extent = builder.createConvert(loc, indexType, extents[dim]); | ||
| // Avoid the modulo for the last index, unless wrap around is allowed. | ||
| mlir::Value currentIndex = linearIndex; | ||
| if (dim != extents.size() - 1 || wrapAround) | ||
| currentIndex = | ||
| builder.create<mlir::arith::RemUIOp>(loc, linearIndex, extent); | ||
| // The result of the last division is unused, so it will be DCEd. | ||
| linearIndex = | ||
| builder.create<mlir::arith::DivUIOp>(loc, linearIndex, extent); | ||
| indices.push_back( | ||
| builder.create<mlir::arith::AddIOp>(loc, currentIndex, one)); | ||
| } | ||
| return indices; | ||
| } | ||
|
|
||
| /// Return size of an array given its extents. | ||
| static mlir::Value computeArraySize(mlir::Location loc, | ||
| fir::FirOpBuilder &builder, | ||
| mlir::ValueRange extents) { | ||
| mlir::Type indexType = builder.getIndexType(); | ||
| mlir::Value size = builder.createIntegerConstant(loc, indexType, 1); | ||
| for (auto extent : extents) | ||
| size = builder.create<mlir::arith::MulIOp>( | ||
| loc, size, builder.createConvert(loc, indexType, extent)); | ||
| return size; | ||
| } | ||
| }; | ||
|
|
||
| class SimplifyHLFIRIntrinsics | ||
| : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> { | ||
| public: | ||
|
|
@@ -987,6 +1199,7 @@ class SimplifyHLFIRIntrinsics | |
| patterns.insert<MatmulConversion<hlfir::MatmulOp>>(context); | ||
|
|
||
| patterns.insert<DotProductConversion>(context); | ||
| patterns.insert<ReshapeAsElementalConversion>(context); | ||
|
|
||
| if (mlir::failed(mlir::applyPatternsGreedily( | ||
| getOperation(), std::move(patterns), config))) { | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.