@@ -951,6 +951,213 @@ class DotProductConversion
951951 }
952952};
953953
954+ class ReshapeAsElementalConversion
955+ : public mlir::OpRewritePattern<hlfir::ReshapeOp> {
956+ public:
957+ using mlir::OpRewritePattern<hlfir::ReshapeOp>::OpRewritePattern;
958+
959+ llvm::LogicalResult
960+ matchAndRewrite (hlfir::ReshapeOp reshape,
961+ mlir::PatternRewriter &rewriter) const override {
962+ // Do not inline RESHAPE with ORDER yet. The runtime implementation
963+ // may be good enough, unless the temporary creation overhead
964+ // is high.
965+ // TODO: If ORDER is constant, then we can still easily inline.
966+ // TODO: If the result's rank is 1, then we can assume ORDER == (/1/).
967+ if (reshape.getOrder ())
968+ return rewriter.notifyMatchFailure (reshape,
969+ " RESHAPE with ORDER argument" );
970+
971+ // Verify that the element types of ARRAY, PAD and the result
972+ // match before doing any transformations.
973+ hlfir::Entity result = hlfir::Entity{reshape};
974+ hlfir::Entity array = hlfir::Entity{reshape.getArray ()};
975+ mlir::Type elementType = array.getFortranElementType ();
976+ if (result.getFortranElementType () != elementType)
977+ return rewriter.notifyMatchFailure (
978+ reshape, " ARRAY and result have different types" );
979+ mlir::Value pad = reshape.getPad ();
980+ if (pad && hlfir::getFortranElementType (pad.getType ()) != elementType)
981+ return rewriter.notifyMatchFailure (reshape,
982+ " ARRAY and PAD have different types" );
983+
984+ // TODO: selecting between ARRAY and PAD of non-trivial element types
985+ // requires more work. We have to select between two references
986+ // to elements in ARRAY and PAD. This requires conditional
987+ // bufferization of the element, if ARRAY/PAD is an expression.
988+ if (pad && !fir::isa_trivial (elementType))
989+ return rewriter.notifyMatchFailure (reshape,
990+ " PAD present with non-trivial type" );
991+
992+ mlir::Location loc = reshape.getLoc ();
993+ fir::FirOpBuilder builder{rewriter, reshape.getOperation ()};
994+ // Assume that all the indices arithmetic does not overflow
995+ // the IndexType.
996+ builder.setIntegerOverflowFlags (mlir::arith::IntegerOverflowFlags::nuw);
997+
998+ llvm::SmallVector<mlir::Value, 1 > typeParams;
999+ hlfir::genLengthParameters (loc, builder, array, typeParams);
1000+
1001+ // Fetch the extents of ARRAY, PAD and result beforehand.
1002+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayExtents =
1003+ hlfir::genExtentsVector (loc, builder, array);
1004+
1005+ mlir::Value arraySize, padSize;
1006+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padExtents;
1007+ if (pad) {
1008+ // If PAD is present, we have to use array size to start taking
1009+ // elements from the PAD array.
1010+ arraySize = computeArraySize (loc, builder, arrayExtents);
1011+
1012+ padExtents = hlfir::genExtentsVector (loc, builder, hlfir::Entity{pad});
1013+ // PAD size is needed to wrap around the linear index addressing
1014+ // the PAD array.
1015+ padSize = computeArraySize (loc, builder, padExtents);
1016+ }
1017+ hlfir::Entity shape = hlfir::Entity{reshape.getShape ()};
1018+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> resultExtents;
1019+ mlir::Type indexType = builder.getIndexType ();
1020+ for (int idx = 0 ; idx < result.getRank (); ++idx)
1021+ resultExtents.push_back (hlfir::loadElementAt (
1022+ loc, builder, shape,
1023+ builder.createIntegerConstant (loc, indexType, idx + 1 )));
1024+ auto resultShape = builder.create <fir::ShapeOp>(loc, resultExtents);
1025+
1026+ auto genKernel = [&](mlir::Location loc, fir::FirOpBuilder &builder,
1027+ mlir::ValueRange inputIndices) -> hlfir::Entity {
1028+ mlir::Value linearIndex =
1029+ computeLinearIndex (loc, builder, resultExtents, inputIndices);
1030+ fir::IfOp ifOp;
1031+ if (pad) {
1032+ // PAD is present. Check if this element comes from the PAD array.
1033+ mlir::Value isInsideArray = builder.create <mlir::arith::CmpIOp>(
1034+ loc, mlir::arith::CmpIPredicate::ult, linearIndex, arraySize);
1035+ ifOp = builder.create <fir::IfOp>(loc, elementType, isInsideArray,
1036+ /* withElseRegion=*/ true );
1037+
1038+ // In the 'else' block, return an element from the PAD.
1039+ builder.setInsertionPointToStart (&ifOp.getElseRegion ().front ());
1040+ // Subtract the ARRAY size from the zero-based linear index
1041+ // to get the zero-based linear index into PAD.
1042+ mlir::Value padLinearIndex =
1043+ builder.create <mlir::arith::SubIOp>(loc, linearIndex, arraySize);
1044+ // PAD wraps around, when additional elements are needed.
1045+ padLinearIndex =
1046+ builder.create <mlir::arith::RemUIOp>(loc, padLinearIndex, padSize);
1047+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> padIndices =
1048+ delinearizeIndex (loc, builder, padExtents, padLinearIndex);
1049+ mlir::Value padElement =
1050+ hlfir::loadElementAt (loc, builder, hlfir::Entity{pad}, padIndices);
1051+ builder.create <fir::ResultOp>(loc, padElement);
1052+
1053+ // In the 'then' block, return an element from the ARRAY.
1054+ builder.setInsertionPointToStart (&ifOp.getThenRegion ().front ());
1055+ }
1056+
1057+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> arrayIndices =
1058+ delinearizeIndex (loc, builder, arrayExtents, linearIndex);
1059+ mlir::Value arrayElement =
1060+ hlfir::loadElementAt (loc, builder, array, arrayIndices);
1061+
1062+ if (ifOp) {
1063+ builder.create <fir::ResultOp>(loc, arrayElement);
1064+ builder.setInsertionPointAfter (ifOp);
1065+ arrayElement = ifOp.getResult (0 );
1066+ }
1067+
1068+ return hlfir::Entity{arrayElement};
1069+ };
1070+ hlfir::ElementalOp elementalOp = hlfir::genElementalOp (
1071+ loc, builder, elementType, resultShape, typeParams, genKernel,
1072+ /* isUnordered=*/ true ,
1073+ /* polymorphicMold=*/ result.isPolymorphic () ? array : mlir::Value{},
1074+ reshape.getResult ().getType ());
1075+ assert (elementalOp.getResult ().getType () == reshape.getResult ().getType ());
1076+ rewriter.replaceOp (reshape, elementalOp);
1077+ return mlir::success ();
1078+ }
1079+
1080+ private:
1081+ // / Compute zero-based linear index given an array extents
1082+ // / and one-based indices:
1083+ // / \p extents: [e0, e1, ..., en]
1084+ // / \p indices: [i0, i1, ..., in]
1085+ // /
1086+ // / linear-index :=
1087+ // / (...((in-1)*e(n-1)+(i(n-1)-1))*e(n-2)+...)*e0+(i0-1)
1088+ static mlir::Value computeLinearIndex (mlir::Location loc,
1089+ fir::FirOpBuilder &builder,
1090+ mlir::ValueRange extents,
1091+ mlir::ValueRange indices) {
1092+ std::size_t rank = extents.size ();
1093+ assert (rank = indices.size ());
1094+ mlir::Type indexType = builder.getIndexType ();
1095+ mlir::Value zero = builder.createIntegerConstant (loc, indexType, 0 );
1096+ mlir::Value one = builder.createIntegerConstant (loc, indexType, 1 );
1097+ mlir::Value linearIndex = zero;
1098+ for (auto idx : llvm::enumerate (llvm::reverse (indices))) {
1099+ mlir::Value tmp = builder.create <mlir::arith::SubIOp>(
1100+ loc, builder.createConvert (loc, indexType, idx.value ()), one);
1101+ tmp = builder.create <mlir::arith::AddIOp>(loc, linearIndex, tmp);
1102+ if (idx.index () + 1 < rank)
1103+ tmp = builder.create <mlir::arith::MulIOp>(
1104+ loc, tmp,
1105+ builder.createConvert (loc, indexType,
1106+ extents[rank - idx.index () - 2 ]));
1107+
1108+ linearIndex = tmp;
1109+ }
1110+ return linearIndex;
1111+ }
1112+
1113+ // / Compute one-based array indices from the given zero-based \p linearIndex
1114+ // / and the array \p extents [e0, e1, ..., en].
1115+ // / i0 := linearIndex % e0 + 1
1116+ // / linearIndex := linearIndex / e0
1117+ // / i1 := linearIndex % e1 + 1
1118+ // / linearIndex := linearIndex / e1
1119+ // / ...
1120+ // / i(n-1) := linearIndex % e(n-1) + 1
1121+ // / linearIndex := linearIndex / e(n-1)
1122+ // / in := linearIndex + 1
1123+ static llvm::SmallVector<mlir::Value, Fortran::common::maxRank>
1124+ delinearizeIndex (mlir::Location loc, fir::FirOpBuilder &builder,
1125+ mlir::ValueRange extents, mlir::Value linearIndex) {
1126+ llvm::SmallVector<mlir::Value, Fortran::common::maxRank> indices;
1127+ mlir::Type indexType = builder.getIndexType ();
1128+ mlir::Value one = builder.createIntegerConstant (loc, indexType, 1 );
1129+ linearIndex = builder.createConvert (loc, indexType, linearIndex);
1130+
1131+ for (std::size_t dim = 0 ; dim < extents.size (); ++dim) {
1132+ mlir::Value currentIndex;
1133+ if (dim == extents.size () - 1 ) {
1134+ currentIndex = linearIndex;
1135+ } else {
1136+ mlir::Value extent =
1137+ builder.createConvert (loc, indexType, extents[dim]);
1138+ currentIndex =
1139+ builder.create <mlir::arith::RemUIOp>(loc, linearIndex, extent);
1140+ linearIndex =
1141+ builder.create <mlir::arith::DivUIOp>(loc, linearIndex, extent);
1142+ }
1143+ indices.push_back (
1144+ builder.create <mlir::arith::AddIOp>(loc, currentIndex, one));
1145+ }
1146+ return indices;
1147+ }
1148+
1149+ static mlir::Value computeArraySize (mlir::Location loc,
1150+ fir::FirOpBuilder &builder,
1151+ mlir::ValueRange extents) {
1152+ mlir::Type indexType = builder.getIndexType ();
1153+ mlir::Value size = builder.createIntegerConstant (loc, indexType, 1 );
1154+ for (auto extent : extents)
1155+ size = builder.create <mlir::arith::MulIOp>(
1156+ loc, size, builder.createConvert (loc, indexType, extent));
1157+ return size;
1158+ }
1159+ };
1160+
9541161class SimplifyHLFIRIntrinsics
9551162 : public hlfir::impl::SimplifyHLFIRIntrinsicsBase<SimplifyHLFIRIntrinsics> {
9561163public:
@@ -987,6 +1194,7 @@ class SimplifyHLFIRIntrinsics
9871194 patterns.insert <MatmulConversion<hlfir::MatmulOp>>(context);
9881195
9891196 patterns.insert <DotProductConversion>(context);
1197+ patterns.insert <ReshapeAsElementalConversion>(context);
9901198
9911199 if (mlir::failed (mlir::applyPatternsGreedily (
9921200 getOperation (), std::move (patterns), config))) {
0 commit comments