@@ -951,6 +951,218 @@ class DotProductConversion
951951 }
952952};
953953
954+ class ReshapeAsElementalConversion
955+ : public mlir::OpRewritePattern<hlfir::ReshapeOp> {
956+ public:
957+ using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern;
958+
959+ llvm::LogicalResult
960+ matchAndRewrite (hlfir::ReshapeOp reshape,
961+ mlir::PatternRewriter &rewriter) const override {
962+ // Do not inline RESHAPE with ORDER yet. The runtime implementation
963+ // may be good enough, unless the temporary creation overhead
964+ // is high.
965+ // TODO: If ORDER is constant, then we can still easily inline.
966+ // TODO: If the result's rank is 1, then we can assume ORDER == (/1/).
967+ if (reshape.getOrder ())
968+ return rewriter.notifyMatchFailure (reshape,
969+ " RESHAPE with ORDER argument" );
970+
971+ // Verify that the element types of ARRAY, PAD and the result
972+ // match before doing any transformations. For example,
973+ // the character types of different lengths may appear in the dead
974+ // code, and it just does not make sense to inline hlfir.reshape
975+ // in this case (a runtime call might have less code size footprint).
976+ hlfir::Entity result = hlfir::Entity{reshape};
977+ hlfir::Entity array = hlfir::Entity{reshape.getArray ()};
978+ mlir::Type elementType = array.getFortranElementType ();
979+ if (result.getFortranElementType () != elementType)
980+ return rewriter.notifyMatchFailure (
981+ reshape, " ARRAY and result have different types" );
982+ mlir::Value pad = reshape.getPad ();
983+ if (pad && hlfir::getFortranElementType (pad.getType ()) != elementType)
984+ return rewriter.notifyMatchFailure (reshape,
985+ " ARRAY and PAD have different types" );
986+
987+ // TODO: selecting between ARRAY and PAD of non-trivial element types
988+ // requires more work. We have to select between two references
989+ // to elements in ARRAY and PAD. This requires conditional
990+ // bufferization of the element, if ARRAY/PAD is an expression.
991+ if (pad && !fir::isa_trivial (elementType))
992+ return rewriter.notifyMatchFailure (reshape,
993+ " PAD present with non-trivial type" );
994+
995+ mlir::Location loc = reshape.getLoc ();
996+ fir::FirOpBuilder builder{rewriter, reshape.getOperation ()};
997+ // Assume that all the indices arithmetic does not overflow
998+ // the IndexType.
999+ builder.setIntegerOverflowFlags (mlir::arith::IntegerOverflowFlags::nuw);
1000+
1001+ llvm::SmallVector<mlir::Value, 1 > typeParams;
1002+ hlfir::genLengthParameters (loc, builder, array, typeParams);
1003+
1004+ // Fetch the extents of ARRAY, PAD and result beforehand.
1005+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
1006+ hlfir::genExtentsVector (loc, builder, array);
1007+
1008+ // If PAD is present, we have to use array size to start taking
1009+ // elements from the PAD array.
1010+ mlir::Value arraySize =
1011+ pad ? computeArraySize (loc, builder, arrayExtents) : nullptr ;
1012+ hlfir::Entity shape = hlfir::Entity{reshape.getShape ()};
1013+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
1014+ mlir::Type indexType = builder.getIndexType ();
1015+ for (int idx = 0 ; idx < result.getRank (); ++idx)
1016+ resultExtents.push_back (hlfir::loadElementAt (
1017+ loc, builder, shape,
1018+ builder.createIntegerConstant (loc, indexType, idx + 1 )));
1019+ auto resultShape = builder.create <fir::ShapeOp>(loc, resultExtents);
1020+
1021+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
1022+ mlir::ValueRange inputIndices) -> hlfir::Entity {
1023+ mlir::Value linearIndex =
1024+ computeLinearIndex (loc, builder, resultExtents, inputIndices);
1025+ fir::IfOp ifOp;
1026+ if (pad) {
1027+ // PAD is present. Check if this element comes from the PAD array.
1028+ mlir::Value isInsideArray = builder.create <mlir::arith::CmpIOp>(
1029+ loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize);
1030+ ifOp = builder.create <fir::IfOp>(loc, elementType, isInsideArray,
1031+ /* withElseRegion=*/ true );
1032+
1033+ // In the 'else' block, return an element from the PAD.
1034+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
1035+ // PAD is dynamically optional, but we can unconditionally access it
1036+ // in the 'else' block. If we have to start taking elements from it,
1037+ // then it must be present in a valid program.
1038+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents =
1039+ hlfir::genExtentsVector (loc, builder, hlfir::Entity{pad});
1040+ // Subtract the ARRAY size from the zero-based linear index
1041+ // to get the zero-based linear index into PAD.
1042+ mlir::Value padLinearIndex =
1043+ builder.create <mlir::arith::SubIOp>(loc, linearIndex, arraySize);
1044+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
1045+ delinearizeIndex (loc, builder, padExtents, padLinearIndex,
1046+ /* wrapAround=*/ true );
1047+ mlir::Value padElement =
1048+ hlfir::loadElementAt (loc, builder, hlfir::Entity{pad}, padIndices);
1049+ builder.create <fir::ResultOp>(loc, padElement);
1050+
1051+ // In the 'then' block, return an element from the ARRAY.
1052+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
1053+ }
1054+
1055+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
1056+ delinearizeIndex (loc, builder, arrayExtents, linearIndex,
1057+ /* wrapAround=*/ false );
1058+ mlir::Value arrayElement =
1059+ hlfir::loadElementAt (loc, builder, array, arrayIndices);
1060+
1061+ if (ifOp) {
1062+ builder.create <fir::ResultOp>(loc, arrayElement);
1063+ builder.setInsertionPointAfter (ifOp);
1064+ arrayElement = ifOp.getResult (0 );
1065+ }
1066+
1067+ return hlfir::Entity{arrayElement};
1068+ };
1069+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp (
1070+ loc, builder, elementType, resultShape, typeParams, genKernel,
1071+ /* isUnordered=*/ true ,
1072+ /* polymorphicMold=*/ result.isPolymorphic () ? array : mlir::Value{},
1073+ reshape.getResult ().getType ());
1074+ assert (elementalOp.getResult ().getType () == reshape.getResult ().getType ());
1075+ rewriter.replaceOp (reshape, elementalOp);
1076+ return mlir::success ();
1077+ }
1078+
1079+ private:
1080+ // / Compute zero-based linear index given an array extents
1081+ // / and one-based indices:
1082+ // / \p extents: [e0, e1, ..., en]
1083+ // / \p indices: [i0, i1, ..., in]
1084+ // /
1085+ // / linear-index :=
1086+ // / (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1)
1087+ static mlir::Value computeLinearIndex (mlir::Location loc,
1088+ fir::FirOpBuilder &builder,
1089+ mlir::ValueRange extents,
1090+ mlir::ValueRange indices) {
1091+ std::size_t rank = extents.size ();
1092+ assert (rank = indices.size ());
1093+ mlir::Type indexType = builder.getIndexType ();
1094+ mlir::Value zero = builder.createIntegerConstant (loc, indexType, 0 );
1095+ mlir::Value one = builder.createIntegerConstant (loc, indexType, 1 );
1096+ mlir::Value linearIndex = zero;
1097+ for (auto idx : llvm::enumerate (llvm::reverse (indices))) {
1098+ mlir::Value tmp = builder.create <mlir::arith::SubIOp>(
1099+ loc, builder.createConvert (loc, indexType, idx.value ()), one);
1100+ tmp = builder.create <mlir::arith::AddIOp>(loc, linearIndex, tmp);
1101+ if (idx.index () + 1 < rank)
1102+ tmp = builder.create <mlir::arith::MulIOp>(
1103+ loc, tmp,
1104+ builder.createConvert (loc, indexType,
1105+ extents[rank - idx.index () - 2 ]));
1106+
1107+ linearIndex = tmp;
1108+ }
1109+ return linearIndex;
1110+ }
1111+
1112+ // / Compute one-based array indices from the given zero-based \p linearIndex
1113+ // / and the array \p extents [e0, e1, ..., en].
1114+ // / i0 := linearIndex % e0 + 1
1115+ // / linearIndex := linearIndex / e0
1116+ // / i1 := linearIndex % e1 + 1
1117+ // / linearIndex := linearIndex / e1
1118+ // / ...
1119+ // / i(n-1) := linearIndex % e(n-1) + 1
1120+ // / linearIndex := linearIndex / e(n-1)
1121+ // / if (wrapAround) {
1122+ // / // If the index is allowed to wrap around, then
1123+ // / // we need to modulo it by the last dimension's extent.
1124+ // / in := linearIndex % en + 1
1125+ // / } else {
1126+ // / in := linearIndex + 1
1127+ // / }
1128+ static llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
1129+ delinearizeIndex (mlir::Location loc, fir::FirOpBuilder &builder,
1130+ mlir::ValueRange extents, mlir::Value linearIndex,
1131+ bool wrapAround) {
1132+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
1133+ mlir::Type indexType = builder.getIndexType ();
1134+ mlir::Value one = builder.createIntegerConstant (loc, indexType, 1 );
1135+ linearIndex = builder.createConvert (loc, indexType, linearIndex);
1136+
1137+ for (std::size_t dim = 0 ; dim < extents.size (); ++dim) {
1138+ mlir::Value extent = builder.createConvert (loc, indexType, extents[dim]);
1139+ // Avoid the modulo for the last index, unless wrap around is allowed.
1140+ mlir::Value currentIndex = linearIndex;
1141+ if (dim != extents.size () - 1 || wrapAround)
1142+ currentIndex =
1143+ builder.create <mlir::arith::RemUIOp>(loc, linearIndex, extent);
1144+ // The result of the last division is unused, so it will be DCEd.
1145+ linearIndex =
1146+ builder.create <mlir::arith::DivUIOp>(loc, linearIndex, extent);
1147+ indices.push_back (
1148+ builder.create <mlir::arith::AddIOp>(loc, currentIndex, one));
1149+ }
1150+ return indices;
1151+ }
1152+
1153+ // / Return size of an array given its extents.
1154+ static mlir::Value computeArraySize (mlir::Location loc,
1155+ fir::FirOpBuilder &builder,
1156+ mlir::ValueRange extents) {
1157+ mlir::Type indexType = builder.getIndexType ();
1158+ mlir::Value size = builder.createIntegerConstant (loc, indexType, 1 );
1159+ for (auto extent : extents)
1160+ size = builder.create <mlir::arith::MulIOp>(
1161+ loc, size, builder.createConvert (loc, indexType, extent));
1162+ return size;
1163+ }
1164+ };
1165+
9541166class SimplifyHLFIRIntrinsics
9551167 : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
9561168public:
@@ -987,6 +1199,7 @@ class SimplifyHLFIRIntrinsics
9871199 patterns.insert <MatmulConversion<hlfir::MatmulOp>>(context);
9881200
9891201 patterns.insert <DotProductConversion>(context);
1202+ patterns.insert <ReshapeAsElementalConversion>(context);
9901203
9911204 if (mlir::failed (mlir::applyPatternsGreedily (
9921205 getOperation (), std::move (patterns), config))) {
0 commit comments