Skip to content

Commit 2c27dd5

Browse files
committed
[flang] Simplify index intrinsic for a few limited cases
Primarily targets the case where the substring is a single character, which is simplified by inlining a search loop (with an exception where the runtime function performs better). A few trivial simplifications are also covered.
1 parent 27fa1d0 commit 2c27dd5

File tree

2 files changed

+554
-0
lines changed

2 files changed

+554
-0
lines changed

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2284,6 +2284,214 @@ class CmpCharOpConversion : public mlir::OpRewritePattern<hlfir::CmpCharOp> {
22842284
}
22852285
};
22862286

2287+
static std::pair<mlir::Value, hlfir::AssociateOp>
2288+
getVariable(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value val) {
2289+
// If it is an expression - create a variable from it, or forward
2290+
// the value otherwise.
2291+
hlfir::AssociateOp associate;
2292+
if (!mlir::isa<hlfir::ExprType>(val.getType()))
2293+
return {val, associate};
2294+
hlfir::Entity entity{val};
2295+
mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder);
2296+
associate = hlfir::genAssociateExpr(loc, builder, entity, entity.getType(),
2297+
"", byRefAttr);
2298+
return {associate.getBase(), associate};
2299+
}
2300+
2301+
/// Rewrite pattern that simplifies hlfir::IndexOp in a limited set of cases
/// (see the list at the top of matchAndRewrite). In every other case the
/// match fails and the operation is left untouched.
class IndexOpConversion : public mlir::OpRewritePattern<hlfir::IndexOp> {
public:
  using mlir::OpRewritePattern<hlfir::IndexOp>::OpRewritePattern;

  llvm::LogicalResult
  matchAndRewrite(hlfir::IndexOp op,
                  mlir::PatternRewriter &rewriter) const override {
    // We simplify only limited cases:
    //  1) a substring length shall be known at compile time
    //  2) if a substring length is 0 then replace with 1 for forward search,
    //     or otherwise with the string length + 1 (builder shall const-fold if
    //     lookup direction is known at compile time).
    //  3) for known string length at compile time, if it is
    //     shorter than substring => replace with zero.
    //  4) if a substring length is one => inline as simple search loop
    //  5) for forward search with input strings of kind=1 runtime is faster.
    // Do not simplify in all the other cases relying on a runtime call.

    fir::FirOpBuilder builder{rewriter, op.getOperation()};
    mlir::Location loc = op->getLoc();

    auto resultTy = op.getType();
    mlir::Value back = op.getBack();
    mlir::Value substrLen =
        hlfir::genCharLength(loc, builder, hlfir::Entity{op.getSubstr()});

    auto substrLenCst = fir::getIntIfConstant(substrLen);
    if (!substrLenCst) {
      return rewriter.notifyMatchFailure(
          op, "substring length unknown at compile time");
    }
    mlir::Value strLen =
        hlfir::genCharLength(loc, builder, hlfir::Entity{op.getStr()});
    auto idxTy = builder.getIndexType();
    if (*substrLenCst == 0) {
      mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
      // Zero length substring. For back search replace with
      // strLen+1, or otherwise with 1.
      mlir::Value strEnd = genStrEnd(builder, loc, strLen, oneIdx);
      mlir::Value backI1 = genBackAsI1(builder, loc, back);
      mlir::Value result =
          mlir::arith::SelectOp::create(builder, loc, backI1, strEnd, oneIdx);

      rewriter.replaceOp(op, builder.createConvert(loc, resultTy, result));
      return mlir::success();
    }

    // The substring length is non-zero from here on, so the case of both
    // strings being empty has already been handled above and needs no
    // dedicated check.
    if (auto strLenCst = fir::getIntIfConstant(strLen)) {
      if (*strLenCst < *substrLenCst) {
        // The string is shorter than the substring: no match is possible.
        rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 0));
        return mlir::success();
      }
    }
    if (*substrLenCst != 1) {
      return rewriter.notifyMatchFailure(
          op, "rely on runtime implementation if substring length > 1");
    }
    // For forward search and character kind=1 the runtime uses memchr
    // which is well optimized. But it looks like memchr idiom is not
    // recognized in LLVM yet. On a micro-kernel test with strings of length 40
    // runtime had ~2x less execution time vs inlined code. For unknown search
    // direction at compile time pessimistically assume "forward".
    std::optional<bool> isBack;
    if (back) {
      if (auto backCst = fir::getIntIfConstant(back))
        isBack = *backCst != 0;
    } else {
      isBack = false;
    }
    auto charTy = mlir::cast<fir::CharacterType>(
        hlfir::getFortranElementType(op.getSubstr().getType()));
    unsigned kind = charTy.getFKind();
    if (kind == 1 && (!isBack || !*isBack)) {
      return rewriter.notifyMatchFailure(
          op, "rely on runtime implementation for character kind 1");
    }

    // All checks are passed here. Generate single character search loop.
    auto [strV, strAssociate] = getVariable(builder, loc, op.getStr());
    auto [substrV, substrAssociate] =
        getVariable(builder, loc, op.getSubstr());
    hlfir::Entity str{strV};
    hlfir::Entity substr{substrV};
    mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);

    // Designates the one-character substring [index, index] of charStr,
    // loads it and extracts its element 0 as an integer whose width is the
    // character bitsize of `kind`.
    auto genExtractAndConvertToInt = [&charTy, &idxTy, &oneIdx,
                                      kind](mlir::Location loc,
                                            fir::FirOpBuilder &builder,
                                            hlfir::Entity &charStr,
                                            mlir::Value index) {
      auto bits = builder.getKindMap().getCharacterBitsize(kind);
      auto intTy = builder.getIntegerType(bits);
      auto charLen1Ty =
          fir::CharacterType::getSingleton(builder.getContext(), kind);
      mlir::Type designatorTy =
          fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy));
      auto idxAttr = builder.getIntegerAttr(idxTy, 0);

      auto singleChr = hlfir::DesignateOp::create(
          builder, loc, designatorTy, charStr, /*component=*/{},
          /*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
          /*substring=*/mlir::ValueRange{index, index},
          /*complexPart=*/std::nullopt,
          /*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{oneIdx},
          fir::FortranVariableFlagsAttr{});
      auto chrVal = fir::LoadOp::create(builder, loc, singleChr);
      mlir::Value intVal = fir::ExtractValueOp::create(
          builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr));
      return intVal;
    };

    auto wantChar = genExtractAndConvertToInt(loc, builder, substr, oneIdx);

    // Generate search loop body with the following C equivalent:
    //   idx_t result = 0;
    //   idx_t end = strlen + 1;
    //   char want = substr[0];
    //   for (idx_t idx = 1; idx < end; ++idx) {
    //     if (result == 0) {
    //       idx_t at = back ? end - idx: idx;
    //       result = str[at-1] == want ? at : result;
    //     }
    //   }
    back = genBackAsI1(builder, loc, back);
    mlir::Value strEnd = genStrEnd(builder, loc, strLen, oneIdx);
    mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
    auto genSearchBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
                             mlir::ValueRange index,
                             mlir::ValueRange reductionArgs)
        -> llvm::SmallVector<mlir::Value, 1> {
      assert(index.size() == 1 && "expected single loop");
      assert(reductionArgs.size() == 1 && "expected single reduction value");
      mlir::Value inRes = reductionArgs[0];
      // Only keep searching while no match has been recorded yet.
      auto resEQzero = mlir::arith::CmpIOp::create(
          builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroIdx);

      mlir::Value res =
          builder
              .genIfOp(loc, {idxTy}, resEQzero,
                       /*withElseRegion=*/true)
              .genThen([&]() {
                mlir::Value idx = builder.createConvert(loc, idxTy, index[0]);
                // offset = back ? end - idx : idx;
                mlir::Value offset = mlir::arith::SelectOp::create(
                    builder, loc, back,
                    mlir::arith::SubIOp::create(builder, loc, strEnd, idx),
                    idx);

                auto haveChar =
                    genExtractAndConvertToInt(loc, builder, str, offset);
                auto charsEQ = mlir::arith::CmpIOp::create(
                    builder, loc, mlir::arith::CmpIPredicate::eq, haveChar,
                    wantChar);
                mlir::Value newVal = mlir::arith::SelectOp::create(
                    builder, loc, charsEQ, offset, inRes);

                fir::ResultOp::create(builder, loc, newVal);
              })
              .genElse([&]() { fir::ResultOp::create(builder, loc, inRes); })
              .getResults()[0];
      return {res};
    };

    llvm::SmallVector<mlir::Value, 1> loopOut =
        hlfir::genLoopNestWithReductions(loc, builder, {strLen},
                                         /*reductionInits=*/{zeroIdx},
                                         genSearchBody,
                                         /*isUnordered=*/false);
    mlir::Value result = builder.createConvert(loc, resultTy, loopOut[0]);

    // Close the associations created by getVariable, if any.
    if (strAssociate)
      hlfir::EndAssociateOp::create(builder, loc, strAssociate);
    if (substrAssociate)
      hlfir::EndAssociateOp::create(builder, loc, substrAssociate);

    rewriter.replaceOp(op, result);
    return mlir::success();
  }

private:
  /// Returns the BACK argument converted to i1. A null \p back (argument
  /// absent) yields an i1 constant 0, i.e. forward search.
  static mlir::Value genBackAsI1(fir::FirOpBuilder &builder,
                                 mlir::Location loc, mlir::Value back) {
    mlir::Type i1Ty = builder.getI1Type();
    if (back)
      return builder.createConvert(loc, i1Ty, back);
    return builder.createIntegerConstant(loc, i1Ty, 0);
  }

  /// Returns \p strLen converted to index type plus \p oneIdx.
  static mlir::Value genStrEnd(fir::FirOpBuilder &builder, mlir::Location loc,
                               mlir::Value strLen, mlir::Value oneIdx) {
    return mlir::arith::AddIOp::create(
        builder, loc,
        builder.createConvert(loc, builder.getIndexType(), strLen), oneIdx);
  }
};
2494+
22872495
template <typename Op>
22882496
class MatmulConversion : public mlir::OpRewritePattern<Op> {
22892497
public:
@@ -2955,6 +3163,7 @@ class SimplifyHLFIRIntrinsics
29553163
patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
29563164
patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
29573165
patterns.insert<CmpCharOpConversion>(context);
3166+
patterns.insert<IndexOpConversion>(context);
29583167
patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
29593168
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
29603169
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);

0 commit comments

Comments
 (0)