Skip to content

Commit a7d7018

Browse files
authored
[flang] Lower hlfir.cmpchar into inline implementation in simplify-hlfir-intrinsics (#155461)
1 parent 0df4463 commit a7d7018

File tree

2 files changed

+817
-1
lines changed

2 files changed

+817
-1
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 207 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2078,6 +2078,212 @@ class ArrayShiftConversion : public mlir::OpRewritePattern<Op> {
20782078
}
20792079
};
20802080

2081+
class CmpCharOpConversion : public mlir::OpRewritePattern<hlfir::CmpCharOp> {
2082+
public:
2083+
using mlir::OpRewritePattern<hlfir::CmpCharOp>::OpRewritePattern;
2084+
2085+
llvm::LogicalResult
2086+
matchAndRewrite(hlfir::CmpCharOp cmp,
2087+
mlir::PatternRewriter &rewriter) const override {
2088+
2089+
fir::FirOpBuilder builder{rewriter, cmp.getOperation()};
2090+
const mlir::Location &loc = cmp->getLoc();
2091+
2092+
auto toVariable =
2093+
[&builder,
2094+
&loc](mlir::Value val) -> std::pair<mlir::Value, hlfir::AssociateOp> {
2095+
mlir::Value opnd;
2096+
hlfir::AssociateOp associate;
2097+
if (mlir::isa<hlfir::ExprType>(val.getType())) {
2098+
hlfir::Entity entity{val};
2099+
mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder);
2100+
associate = hlfir::genAssociateExpr(loc, builder, entity,
2101+
entity.getType(), "", byRefAttr);
2102+
opnd = associate.getBase();
2103+
} else {
2104+
opnd = val;
2105+
}
2106+
return {opnd, associate};
2107+
};
2108+
2109+
auto [lhsOpnd, lhsAssociate] = toVariable(cmp.getLchr());
2110+
auto [rhsOpnd, rhsAssociate] = toVariable(cmp.getRchr());
2111+
2112+
hlfir::Entity lhs{lhsOpnd};
2113+
hlfir::Entity rhs{rhsOpnd};
2114+
2115+
auto charTy = mlir::cast<fir::CharacterType>(lhs.getFortranElementType());
2116+
unsigned kind = charTy.getFKind();
2117+
2118+
auto bits = builder.getKindMap().getCharacterBitsize(kind);
2119+
auto intTy = builder.getIntegerType(bits);
2120+
2121+
auto idxTy = builder.getIndexType();
2122+
auto charLen1Ty =
2123+
fir::CharacterType::getSingleton(builder.getContext(), kind);
2124+
mlir::Type designatorType =
2125+
fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy));
2126+
auto idxAttr = builder.getIntegerAttr(idxTy, 0);
2127+
2128+
auto genExtractAndConvertToInt =
2129+
[&idxAttr, &intTy, &designatorType](
2130+
mlir::Location loc, fir::FirOpBuilder &builder,
2131+
hlfir::Entity &charStr, mlir::Value index, mlir::Value length) {
2132+
auto singleChr = hlfir::DesignateOp::create(
2133+
builder, loc, designatorType, charStr, /*component=*/{},
2134+
/*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
2135+
/*substring=*/mlir::ValueRange{index, index},
2136+
/*complexPart=*/std::nullopt,
2137+
/*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{length},
2138+
fir::FortranVariableFlagsAttr{});
2139+
auto chrVal = fir::LoadOp::create(builder, loc, singleChr);
2140+
mlir::Value intVal = fir::ExtractValueOp::create(
2141+
builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr));
2142+
return intVal;
2143+
};
2144+
2145+
mlir::arith::CmpIPredicate predicate = cmp.getPredicate();
2146+
mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
2147+
2148+
mlir::Value lhsLen = builder.createConvert(
2149+
loc, idxTy, hlfir::genCharLength(loc, builder, lhs));
2150+
mlir::Value rhsLen = builder.createConvert(
2151+
loc, idxTy, hlfir::genCharLength(loc, builder, rhs));
2152+
2153+
enum class GenCmp { LeftToRight, LeftToBlank, BlankToRight };
2154+
2155+
mlir::Value zeroInt = builder.createIntegerConstant(loc, intTy, 0);
2156+
mlir::Value oneInt = builder.createIntegerConstant(loc, intTy, 1);
2157+
mlir::Value negOneInt = builder.createIntegerConstant(loc, intTy, -1);
2158+
mlir::Value blankInt = builder.createIntegerConstant(loc, intTy, ' ');
2159+
2160+
auto step = GenCmp::LeftToRight;
2161+
auto genCmp = [&](mlir::Location loc, fir::FirOpBuilder &builder,
2162+
mlir::ValueRange index, mlir::ValueRange reductionArgs)
2163+
-> llvm::SmallVector<mlir::Value, 1> {
2164+
assert(index.size() == 1 && "expected single loop");
2165+
assert(reductionArgs.size() == 1 && "expected single reduction value");
2166+
mlir::Value inRes = reductionArgs[0];
2167+
auto accEQzero = mlir::arith::CmpIOp::create(
2168+
builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroInt);
2169+
2170+
mlir::Value res =
2171+
builder
2172+
.genIfOp(loc, {intTy}, accEQzero,
2173+
/*withElseRegion=*/true)
2174+
.genThen([&]() {
2175+
mlir::Value offset =
2176+
builder.createConvert(loc, idxTy, index[0]);
2177+
mlir::Value lhsInt;
2178+
mlir::Value rhsInt;
2179+
if (step == GenCmp::LeftToRight) {
2180+
lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset,
2181+
oneIdx);
2182+
rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset,
2183+
oneIdx);
2184+
} else if (step == GenCmp::LeftToBlank) {
2185+
// lhsLen > rhsLen
2186+
offset =
2187+
mlir::arith::AddIOp::create(builder, loc, rhsLen, offset);
2188+
2189+
lhsInt = genExtractAndConvertToInt(loc, builder, lhs, offset,
2190+
oneIdx);
2191+
rhsInt = blankInt;
2192+
} else if (step == GenCmp::BlankToRight) {
2193+
// rhsLen > lhsLen
2194+
offset =
2195+
mlir::arith::AddIOp::create(builder, loc, lhsLen, offset);
2196+
2197+
lhsInt = blankInt;
2198+
rhsInt = genExtractAndConvertToInt(loc, builder, rhs, offset,
2199+
oneIdx);
2200+
} else {
2201+
llvm_unreachable(
2202+
"unknown compare step for CmpCharOp lowering");
2203+
}
2204+
2205+
mlir::Value newVal = mlir::arith::SelectOp::create(
2206+
builder, loc,
2207+
mlir::arith::CmpIOp::create(builder, loc,
2208+
mlir::arith::CmpIPredicate::ult,
2209+
lhsInt, rhsInt),
2210+
negOneInt, inRes);
2211+
newVal = mlir::arith::SelectOp::create(
2212+
builder, loc,
2213+
mlir::arith::CmpIOp::create(builder, loc,
2214+
mlir::arith::CmpIPredicate::ugt,
2215+
lhsInt, rhsInt),
2216+
oneInt, newVal);
2217+
fir::ResultOp::create(builder, loc, newVal);
2218+
})
2219+
.genElse([&]() { fir::ResultOp::create(builder, loc, inRes); })
2220+
.getResults()[0];
2221+
2222+
return {res};
2223+
};
2224+
2225+
// First generate comparison of two strings for the legth of the shorter
2226+
// one.
2227+
mlir::Value minLen = mlir::arith::SelectOp::create(
2228+
builder, loc,
2229+
mlir::arith::CmpIOp::create(
2230+
builder, loc, mlir::arith::CmpIPredicate::slt, lhsLen, rhsLen),
2231+
lhsLen, rhsLen);
2232+
2233+
llvm::SmallVector<mlir::Value, 1> loopOut =
2234+
hlfir::genLoopNestWithReductions(loc, builder, {minLen},
2235+
/*reductionInits=*/{zeroInt}, genCmp,
2236+
/*isUnordered=*/false);
2237+
mlir::Value partRes = loopOut[0];
2238+
2239+
auto lhsLonger = mlir::arith::CmpIOp::create(
2240+
builder, loc, mlir::arith::CmpIPredicate::sgt, lhsLen, rhsLen);
2241+
mlir::Value tempRes =
2242+
builder
2243+
.genIfOp(loc, {intTy}, lhsLonger,
2244+
/*withElseRegion=*/true)
2245+
.genThen([&]() {
2246+
// If left is the longer string generate compare left to blank.
2247+
step = GenCmp::LeftToBlank;
2248+
auto lenDiff =
2249+
mlir::arith::SubIOp::create(builder, loc, lhsLen, rhsLen);
2250+
2251+
llvm::SmallVector<mlir::Value, 1> output =
2252+
hlfir::genLoopNestWithReductions(loc, builder, {lenDiff},
2253+
/*reductionInits=*/{partRes},
2254+
genCmp,
2255+
/*isUnordered=*/false);
2256+
mlir::Value res = output[0];
2257+
fir::ResultOp::create(builder, loc, res);
2258+
})
2259+
.genElse([&]() {
2260+
// If right is the longer string generate compare blank to
2261+
// right.
2262+
step = GenCmp::BlankToRight;
2263+
auto lenDiff =
2264+
mlir::arith::SubIOp::create(builder, loc, rhsLen, lhsLen);
2265+
llvm::SmallVector<mlir::Value, 1> output =
2266+
hlfir::genLoopNestWithReductions(loc, builder, {lenDiff},
2267+
/*reductionInits=*/{partRes},
2268+
genCmp,
2269+
/*isUnordered=*/false);
2270+
2271+
mlir::Value res = output[0];
2272+
fir::ResultOp::create(builder, loc, res);
2273+
})
2274+
.getResults()[0];
2275+
if (lhsAssociate)
2276+
hlfir::EndAssociateOp::create(builder, loc, lhsAssociate);
2277+
if (rhsAssociate)
2278+
hlfir::EndAssociateOp::create(builder, loc, rhsAssociate);
2279+
2280+
auto finalCmpResult =
2281+
mlir::arith::CmpIOp::create(builder, loc, predicate, tempRes, zeroInt);
2282+
rewriter.replaceOp(cmp, finalCmpResult);
2283+
return mlir::success();
2284+
}
2285+
};
2286+
20812287
template <typename Op>
20822288
class MatmulConversion : public mlir::OpRewritePattern<Op> {
20832289
public:
@@ -2748,8 +2954,8 @@ class SimplifyHLFIRIntrinsics
27482954
patterns.insert<ReductionConversion<hlfir::SumOp>>(context);
27492955
patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
27502956
patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
2957+
patterns.insert<CmpCharOpConversion>(context);
27512958
patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
2752-
27532959
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
27542960
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);
27552961
patterns.insert<ReductionConversion<hlfir::AllOp>>(context);

0 commit comments

Comments
 (0)