@@ -2078,6 +2078,212 @@ class ArrayShiftConversion : public mlir::OpRewritePattern<Op> {
2078
2078
}
2079
2079
};
2080
2080
2081
+ class CmpCharOpConversion : public mlir ::OpRewritePattern<hlfir::CmpCharOp> {
2082
+ public:
2083
+ using mlir::OpRewritePattern<hlfir::CmpCharOp>::OpRewritePattern;
2084
+
2085
+ llvm::LogicalResult
2086
+ matchAndRewrite (hlfir::CmpCharOp cmp,
2087
+ mlir::PatternRewriter &rewriter) const override {
2088
+
2089
+ fir::FirOpBuilder builder{rewriter, cmp.getOperation ()};
2090
+ const mlir::Location &loc = cmp->getLoc ();
2091
+
2092
+ auto toVariable =
2093
+ [&builder,
2094
+ &loc](mlir::Value val) -> std::pair<mlir::Value, hlfir::AssociateOp> {
2095
+ mlir::Value opnd;
2096
+ hlfir::AssociateOp associate;
2097
+ if (mlir::isa<hlfir::ExprType>(val.getType ())) {
2098
+ hlfir::Entity entity{val};
2099
+ mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr (builder);
2100
+ associate = hlfir::genAssociateExpr (loc, builder, entity,
2101
+ entity.getType (), " " , byRefAttr);
2102
+ opnd = associate.getBase ();
2103
+ } else {
2104
+ opnd = val;
2105
+ }
2106
+ return {opnd, associate};
2107
+ };
2108
+
2109
+ auto [lhsOpnd, lhsAssociate] = toVariable (cmp.getLchr ());
2110
+ auto [rhsOpnd, rhsAssociate] = toVariable (cmp.getRchr ());
2111
+
2112
+ hlfir::Entity lhs{lhsOpnd};
2113
+ hlfir::Entity rhs{rhsOpnd};
2114
+
2115
+ auto charTy = mlir::cast<fir::CharacterType>(lhs.getFortranElementType ());
2116
+ unsigned kind = charTy.getFKind ();
2117
+
2118
+ auto bits = builder.getKindMap ().getCharacterBitsize (kind);
2119
+ auto intTy = builder.getIntegerType (bits);
2120
+
2121
+ auto idxTy = builder.getIndexType ();
2122
+ auto charLen1Ty =
2123
+ fir::CharacterType::getSingleton (builder.getContext (), kind);
2124
+ mlir::Type designatorType =
2125
+ fir::ReferenceType::get (charLen1Ty, fir::isa_volatile_type (charTy));
2126
+ auto idxAttr = builder.getIntegerAttr (idxTy, 0 );
2127
+
2128
+ auto genExtractAndConvertToInt =
2129
+ [&idxAttr, &intTy, &designatorType](
2130
+ mlir::Location loc, fir::FirOpBuilder &builder,
2131
+ hlfir::Entity &charStr, mlir::Value index, mlir::Value length) {
2132
+ auto singleChr = hlfir::DesignateOp::create (
2133
+ builder, loc, designatorType, charStr, /* component=*/ {},
2134
+ /* compShape=*/ mlir::Value{}, hlfir::DesignateOp::Subscripts{},
2135
+ /* substring=*/ mlir::ValueRange{index, index},
2136
+ /* complexPart=*/ std::nullopt ,
2137
+ /* shape=*/ mlir::Value{}, /* typeParams=*/ mlir::ValueRange{length},
2138
+ fir::FortranVariableFlagsAttr{});
2139
+ auto chrVal = fir::LoadOp::create (builder, loc, singleChr);
2140
+ mlir::Value intVal = fir::ExtractValueOp::create (
2141
+ builder, loc, intTy, chrVal, builder.getArrayAttr (idxAttr));
2142
+ return intVal;
2143
+ };
2144
+
2145
+ mlir::arith::CmpIPredicate predicate = cmp.getPredicate ();
2146
+ mlir::Value oneIdx = builder.createIntegerConstant (loc, idxTy, 1 );
2147
+
2148
+ mlir::Value lhsLen = builder.createConvert (
2149
+ loc, idxTy, hlfir::genCharLength (loc, builder, lhs));
2150
+ mlir::Value rhsLen = builder.createConvert (
2151
+ loc, idxTy, hlfir::genCharLength (loc, builder, rhs));
2152
+
2153
+ enum class GenCmp { LeftToRight, LeftToBlank, BlankToRight };
2154
+
2155
+ mlir::Value zeroInt = builder.createIntegerConstant (loc, intTy, 0 );
2156
+ mlir::Value oneInt = builder.createIntegerConstant (loc, intTy, 1 );
2157
+ mlir::Value negOneInt = builder.createIntegerConstant (loc, intTy, -1 );
2158
+ mlir::Value blankInt = builder.createIntegerConstant (loc, intTy, ' ' );
2159
+
2160
+ auto step = GenCmp::LeftToRight;
2161
+ auto genCmp = [&](mlir::Location loc, fir::FirOpBuilder &builder,
2162
+ mlir::ValueRange index, mlir::ValueRange reductionArgs)
2163
+ -> llvm::SmallVector<mlir::Value, 1 > {
2164
+ assert (index.size () == 1 && " expected single loop" );
2165
+ assert (reductionArgs.size () == 1 && " expected single reduction value" );
2166
+ mlir::Value inRes = reductionArgs[0 ];
2167
+ auto accEQzero = mlir::arith::CmpIOp::create (
2168
+ builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroInt);
2169
+
2170
+ mlir::Value res =
2171
+ builder
2172
+ .genIfOp (loc, {intTy}, accEQzero,
2173
+ /* withElseRegion=*/ true )
2174
+ .genThen ([&]() {
2175
+ mlir::Value offset =
2176
+ builder.createConvert (loc, idxTy, index[0 ]);
2177
+ mlir::Value lhsInt;
2178
+ mlir::Value rhsInt;
2179
+ if (step == GenCmp::LeftToRight) {
2180
+ lhsInt = genExtractAndConvertToInt (loc, builder, lhs, offset,
2181
+ oneIdx);
2182
+ rhsInt = genExtractAndConvertToInt (loc, builder, rhs, offset,
2183
+ oneIdx);
2184
+ } else if (step == GenCmp::LeftToBlank) {
2185
+ // lhsLen > rhsLen
2186
+ offset =
2187
+ mlir::arith::AddIOp::create (builder, loc, rhsLen, offset);
2188
+
2189
+ lhsInt = genExtractAndConvertToInt (loc, builder, lhs, offset,
2190
+ oneIdx);
2191
+ rhsInt = blankInt;
2192
+ } else if (step == GenCmp::BlankToRight) {
2193
+ // rhsLen > lhsLen
2194
+ offset =
2195
+ mlir::arith::AddIOp::create (builder, loc, lhsLen, offset);
2196
+
2197
+ lhsInt = blankInt;
2198
+ rhsInt = genExtractAndConvertToInt (loc, builder, rhs, offset,
2199
+ oneIdx);
2200
+ } else {
2201
+ llvm_unreachable (
2202
+ " unknown compare step for CmpCharOp lowering" );
2203
+ }
2204
+
2205
+ mlir::Value newVal = mlir::arith::SelectOp::create (
2206
+ builder, loc,
2207
+ mlir::arith::CmpIOp::create (builder, loc,
2208
+ mlir::arith::CmpIPredicate::ult,
2209
+ lhsInt, rhsInt),
2210
+ negOneInt, inRes);
2211
+ newVal = mlir::arith::SelectOp::create (
2212
+ builder, loc,
2213
+ mlir::arith::CmpIOp::create (builder, loc,
2214
+ mlir::arith::CmpIPredicate::ugt,
2215
+ lhsInt, rhsInt),
2216
+ oneInt, newVal);
2217
+ fir::ResultOp::create (builder, loc, newVal);
2218
+ })
2219
+ .genElse ([&]() { fir::ResultOp::create (builder, loc, inRes); })
2220
+ .getResults ()[0 ];
2221
+
2222
+ return {res};
2223
+ };
2224
+
2225
+ // First generate comparison of two strings for the legth of the shorter
2226
+ // one.
2227
+ mlir::Value minLen = mlir::arith::SelectOp::create (
2228
+ builder, loc,
2229
+ mlir::arith::CmpIOp::create (
2230
+ builder, loc, mlir::arith::CmpIPredicate::slt, lhsLen, rhsLen),
2231
+ lhsLen, rhsLen);
2232
+
2233
+ llvm::SmallVector<mlir::Value, 1 > loopOut =
2234
+ hlfir::genLoopNestWithReductions (loc, builder, {minLen},
2235
+ /* reductionInits=*/ {zeroInt}, genCmp,
2236
+ /* isUnordered=*/ false );
2237
+ mlir::Value partRes = loopOut[0 ];
2238
+
2239
+ auto lhsLonger = mlir::arith::CmpIOp::create (
2240
+ builder, loc, mlir::arith::CmpIPredicate::sgt, lhsLen, rhsLen);
2241
+ mlir::Value tempRes =
2242
+ builder
2243
+ .genIfOp (loc, {intTy}, lhsLonger,
2244
+ /* withElseRegion=*/ true )
2245
+ .genThen ([&]() {
2246
+ // If left is the longer string generate compare left to blank.
2247
+ step = GenCmp::LeftToBlank;
2248
+ auto lenDiff =
2249
+ mlir::arith::SubIOp::create (builder, loc, lhsLen, rhsLen);
2250
+
2251
+ llvm::SmallVector<mlir::Value, 1 > output =
2252
+ hlfir::genLoopNestWithReductions (loc, builder, {lenDiff},
2253
+ /* reductionInits=*/ {partRes},
2254
+ genCmp,
2255
+ /* isUnordered=*/ false );
2256
+ mlir::Value res = output[0 ];
2257
+ fir::ResultOp::create (builder, loc, res);
2258
+ })
2259
+ .genElse ([&]() {
2260
+ // If right is the longer string generate compare blank to
2261
+ // right.
2262
+ step = GenCmp::BlankToRight;
2263
+ auto lenDiff =
2264
+ mlir::arith::SubIOp::create (builder, loc, rhsLen, lhsLen);
2265
+ llvm::SmallVector<mlir::Value, 1 > output =
2266
+ hlfir::genLoopNestWithReductions (loc, builder, {lenDiff},
2267
+ /* reductionInits=*/ {partRes},
2268
+ genCmp,
2269
+ /* isUnordered=*/ false );
2270
+
2271
+ mlir::Value res = output[0 ];
2272
+ fir::ResultOp::create (builder, loc, res);
2273
+ })
2274
+ .getResults ()[0 ];
2275
+ if (lhsAssociate)
2276
+ hlfir::EndAssociateOp::create (builder, loc, lhsAssociate);
2277
+ if (rhsAssociate)
2278
+ hlfir::EndAssociateOp::create (builder, loc, rhsAssociate);
2279
+
2280
+ auto finalCmpResult =
2281
+ mlir::arith::CmpIOp::create (builder, loc, predicate, tempRes, zeroInt);
2282
+ rewriter.replaceOp (cmp, finalCmpResult);
2283
+ return mlir::success ();
2284
+ }
2285
+ };
2286
+
2081
2287
template <typename Op>
2082
2288
class MatmulConversion : public mlir ::OpRewritePattern<Op> {
2083
2289
public:
@@ -2748,8 +2954,8 @@ class SimplifyHLFIRIntrinsics
2748
2954
patterns.insert <ReductionConversion<hlfir::SumOp>>(context);
2749
2955
patterns.insert <ArrayShiftConversion<hlfir::CShiftOp>>(context);
2750
2956
patterns.insert <ArrayShiftConversion<hlfir::EOShiftOp>>(context);
2957
+ patterns.insert <CmpCharOpConversion>(context);
2751
2958
patterns.insert <MatmulConversion<hlfir::MatmulTransposeOp>>(context);
2752
-
2753
2959
patterns.insert <ReductionConversion<hlfir::CountOp>>(context);
2754
2960
patterns.insert <ReductionConversion<hlfir::AnyOp>>(context);
2755
2961
patterns.insert <ReductionConversion<hlfir::AllOp>>(context);
0 commit comments