@@ -2078,6 +2078,212 @@ class ArrayShiftConversion : public mlir::OpRewritePattern<Op> {
20782078 }
20792079};
20802080
2081+ class CmpCharOpConversion : public mlir ::OpRewritePattern<hlfir::CmpCharOp> {
2082+ public:
2083+ using mlir::OpRewritePattern<hlfir::CmpCharOp>::OpRewritePattern;
2084+
2085+ llvm::LogicalResult
2086+ matchAndRewrite (hlfir::CmpCharOp cmp,
2087+ mlir::PatternRewriter &rewriter) const override {
2088+
2089+ fir::FirOpBuilder builder{rewriter, cmp.getOperation ()};
2090+ const mlir::Location &loc = cmp->getLoc ();
2091+
2092+ auto toVariable =
2093+ [&builder,
2094+ &loc](mlir::Value val) -> std::pair<mlir::Value, hlfir::AssociateOp> {
2095+ mlir::Value opnd;
2096+ hlfir::AssociateOp associate;
2097+ if (mlir::isa<hlfir::ExprType>(val.getType ())) {
2098+ hlfir::Entity entity{val};
2099+ mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr (builder);
2100+ associate = hlfir::genAssociateExpr (loc, builder, entity,
2101+ entity.getType (), " " , byRefAttr);
2102+ opnd = associate.getBase ();
2103+ } else {
2104+ opnd = val;
2105+ }
2106+ return {opnd, associate};
2107+ };
2108+
2109+ auto [lhsOpnd, lhsAssociate] = toVariable (cmp.getLchr ());
2110+ auto [rhsOpnd, rhsAssociate] = toVariable (cmp.getRchr ());
2111+
2112+ hlfir::Entity lhs{lhsOpnd};
2113+ hlfir::Entity rhs{rhsOpnd};
2114+
2115+ auto charTy = mlir::cast<fir::CharacterType>(lhs.getFortranElementType ());
2116+ unsigned kind = charTy.getFKind ();
2117+
2118+ auto bits = builder.getKindMap ().getCharacterBitsize (kind);
2119+ auto intTy = builder.getIntegerType (bits);
2120+
2121+ auto idxTy = builder.getIndexType ();
2122+ auto charLen1Ty =
2123+ fir::CharacterType::getSingleton (builder.getContext (), kind);
2124+ mlir::Type designatorType =
2125+ fir::ReferenceType::get (charLen1Ty, fir::isa_volatile_type (charTy));
2126+ auto idxAttr = builder.getIntegerAttr (idxTy, 0 );
2127+
2128+ auto genExtractAndConvertToInt =
2129+ [&idxAttr, &intTy, &designatorType](
2130+ mlir::Location loc, fir::FirOpBuilder &builder,
2131+ hlfir::Entity &charStr, mlir::Value index, mlir::Value length) {
2132+ auto singleChr = hlfir::DesignateOp::create (
2133+ builder, loc, designatorType, charStr, /* component=*/ {},
2134+ /* compShape=*/ mlir::Value{}, hlfir::DesignateOp::Subscripts{},
2135+ /* substring=*/ mlir::ValueRange{index, index},
2136+ /* complexPart=*/ std::nullopt ,
2137+ /* shape=*/ mlir::Value{}, /* typeParams=*/ mlir::ValueRange{length},
2138+ fir::FortranVariableFlagsAttr{});
2139+ auto chrVal = fir::LoadOp::create (builder, loc, singleChr);
2140+ mlir::Value intVal = fir::ExtractValueOp::create (
2141+ builder, loc, intTy, chrVal, builder.getArrayAttr (idxAttr));
2142+ return intVal;
2143+ };
2144+
2145+ mlir::arith::CmpIPredicate predicate = cmp.getPredicate ();
2146+ mlir::Value oneIdx = builder.createIntegerConstant (loc, idxTy, 1 );
2147+
2148+ mlir::Value lhsLen = builder.createConvert (
2149+ loc, idxTy, hlfir::genCharLength (loc, builder, lhs));
2150+ mlir::Value rhsLen = builder.createConvert (
2151+ loc, idxTy, hlfir::genCharLength (loc, builder, rhs));
2152+
2153+ enum class GenCmp { LeftToRight, LeftToBlank, BlankToRight };
2154+
2155+ mlir::Value zeroInt = builder.createIntegerConstant (loc, intTy, 0 );
2156+ mlir::Value oneInt = builder.createIntegerConstant (loc, intTy, 1 );
2157+ mlir::Value negOneInt = builder.createIntegerConstant (loc, intTy, -1 );
2158+ mlir::Value blankInt = builder.createIntegerConstant (loc, intTy, ' ' );
2159+
2160+ auto step = GenCmp::LeftToRight;
2161+ auto genCmp = [&](mlir::Location loc, fir::FirOpBuilder &builder,
2162+ mlir::ValueRange index, mlir::ValueRange reductionArgs)
2163+ -> llvm::SmallVector<mlir::Value, 1 > {
2164+ assert (index.size () == 1 && " expected single loop" );
2165+ assert (reductionArgs.size () == 1 && " expected single reduction value" );
2166+ mlir::Value inRes = reductionArgs[0 ];
2167+ auto accEQzero = mlir::arith::CmpIOp::create (
2168+ builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroInt);
2169+
2170+ mlir::Value res =
2171+ builder
2172+ .genIfOp (loc, {intTy}, accEQzero,
2173+ /* withElseRegion=*/ true )
2174+ .genThen ([&]() {
2175+ mlir::Value offset =
2176+ builder.createConvert (loc, idxTy, index[0 ]);
2177+ mlir::Value lhsInt;
2178+ mlir::Value rhsInt;
2179+ if (step == GenCmp::LeftToRight) {
2180+ lhsInt = genExtractAndConvertToInt (loc, builder, lhs, offset,
2181+ oneIdx);
2182+ rhsInt = genExtractAndConvertToInt (loc, builder, rhs, offset,
2183+ oneIdx);
2184+ } else if (step == GenCmp::LeftToBlank) {
2185+ // lhsLen > rhsLen
2186+ offset =
2187+ mlir::arith::AddIOp::create (builder, loc, rhsLen, offset);
2188+
2189+ lhsInt = genExtractAndConvertToInt (loc, builder, lhs, offset,
2190+ oneIdx);
2191+ rhsInt = blankInt;
2192+ } else if (step == GenCmp::BlankToRight) {
2193+ // rhsLen > lhsLen
2194+ offset =
2195+ mlir::arith::AddIOp::create (builder, loc, lhsLen, offset);
2196+
2197+ lhsInt = blankInt;
2198+ rhsInt = genExtractAndConvertToInt (loc, builder, rhs, offset,
2199+ oneIdx);
2200+ } else {
2201+ llvm_unreachable (
2202+ " unknown compare step for CmpCharOp lowering" );
2203+ }
2204+
2205+ mlir::Value newVal = mlir::arith::SelectOp::create (
2206+ builder, loc,
2207+ mlir::arith::CmpIOp::create (builder, loc,
2208+ mlir::arith::CmpIPredicate::ult,
2209+ lhsInt, rhsInt),
2210+ negOneInt, inRes);
2211+ newVal = mlir::arith::SelectOp::create (
2212+ builder, loc,
2213+ mlir::arith::CmpIOp::create (builder, loc,
2214+ mlir::arith::CmpIPredicate::ugt,
2215+ lhsInt, rhsInt),
2216+ oneInt, newVal);
2217+ fir::ResultOp::create (builder, loc, newVal);
2218+ })
2219+ .genElse ([&]() { fir::ResultOp::create (builder, loc, inRes); })
2220+ .getResults ()[0 ];
2221+
2222+ return {res};
2223+ };
2224+
2225+ // First generate comparison of two strings for the legth of the shorter
2226+ // one.
2227+ mlir::Value minLen = mlir::arith::SelectOp::create (
2228+ builder, loc,
2229+ mlir::arith::CmpIOp::create (
2230+ builder, loc, mlir::arith::CmpIPredicate::slt, lhsLen, rhsLen),
2231+ lhsLen, rhsLen);
2232+
2233+ llvm::SmallVector<mlir::Value, 1 > loopOut =
2234+ hlfir::genLoopNestWithReductions (loc, builder, {minLen},
2235+ /* reductionInits=*/ {zeroInt}, genCmp,
2236+ /* isUnordered=*/ false );
2237+ mlir::Value partRes = loopOut[0 ];
2238+
2239+ auto lhsLonger = mlir::arith::CmpIOp::create (
2240+ builder, loc, mlir::arith::CmpIPredicate::sgt, lhsLen, rhsLen);
2241+ mlir::Value tempRes =
2242+ builder
2243+ .genIfOp (loc, {intTy}, lhsLonger,
2244+ /* withElseRegion=*/ true )
2245+ .genThen ([&]() {
2246+ // If left is the longer string generate compare left to blank.
2247+ step = GenCmp::LeftToBlank;
2248+ auto lenDiff =
2249+ mlir::arith::SubIOp::create (builder, loc, lhsLen, rhsLen);
2250+
2251+ llvm::SmallVector<mlir::Value, 1 > output =
2252+ hlfir::genLoopNestWithReductions (loc, builder, {lenDiff},
2253+ /* reductionInits=*/ {partRes},
2254+ genCmp,
2255+ /* isUnordered=*/ false );
2256+ mlir::Value res = output[0 ];
2257+ fir::ResultOp::create (builder, loc, res);
2258+ })
2259+ .genElse ([&]() {
2260+ // If right is the longer string generate compare blank to
2261+ // right.
2262+ step = GenCmp::BlankToRight;
2263+ auto lenDiff =
2264+ mlir::arith::SubIOp::create (builder, loc, rhsLen, lhsLen);
2265+ llvm::SmallVector<mlir::Value, 1 > output =
2266+ hlfir::genLoopNestWithReductions (loc, builder, {lenDiff},
2267+ /* reductionInits=*/ {partRes},
2268+ genCmp,
2269+ /* isUnordered=*/ false );
2270+
2271+ mlir::Value res = output[0 ];
2272+ fir::ResultOp::create (builder, loc, res);
2273+ })
2274+ .getResults ()[0 ];
2275+ if (lhsAssociate)
2276+ hlfir::EndAssociateOp::create (builder, loc, lhsAssociate);
2277+ if (rhsAssociate)
2278+ hlfir::EndAssociateOp::create (builder, loc, rhsAssociate);
2279+
2280+ auto finalCmpResult =
2281+ mlir::arith::CmpIOp::create (builder, loc, predicate, tempRes, zeroInt);
2282+ rewriter.replaceOp (cmp, finalCmpResult);
2283+ return mlir::success ();
2284+ }
2285+ };
2286+
20812287template <typename Op>
20822288class MatmulConversion : public mlir ::OpRewritePattern<Op> {
20832289public:
@@ -2748,8 +2954,8 @@ class SimplifyHLFIRIntrinsics
27482954 patterns.insert <ReductionConversion<hlfir::SumOp>>(context);
27492955 patterns.insert <ArrayShiftConversion<hlfir::CShiftOp>>(context);
27502956 patterns.insert <ArrayShiftConversion<hlfir::EOShiftOp>>(context);
2957+ patterns.insert <CmpCharOpConversion>(context);
27512958 patterns.insert <MatmulConversion<hlfir::MatmulTransposeOp>>(context);
2752-
27532959 patterns.insert <ReductionConversion<hlfir::CountOp>>(context);
27542960 patterns.insert <ReductionConversion<hlfir::AnyOp>>(context);
27552961 patterns.insert <ReductionConversion<hlfir::AllOp>>(context);
0 commit comments