Skip to content

Commit bbae6a4

Browse files
authored
[flang] Simplify hlfir.index in a few limited cases. (#161558)
Primarily targeted simplification case of substring being a singleton by inlining a search loop (with an exception where runtime function performs better). Few trivial simplifications also covered. This is a reapply of #157883 with additional fix to avoid generation of new ops during analysis that mess up greedy rewriter if we end up bailing out without any simplification but just leaving few stranded new ops. For technical reasons this patch comes as a new PR.
1 parent 4845b3e commit bbae6a4

File tree

4 files changed

+618
-19
lines changed

4 files changed

+618
-19
lines changed

flang/include/flang/Optimizer/Builder/HLFIRTools.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -324,6 +324,10 @@ void genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
324324
mlir::Value genCharLength(mlir::Location loc, fir::FirOpBuilder &builder,
325325
Entity entity);
326326

327+
/// Return character length if known at compile time. Unlike genCharLength
328+
/// it does not create any new op as specifically is intended for analysis.
329+
std::optional<std::int64_t> getCharLengthIfConst(Entity entity);
330+
327331
mlir::Value genRank(mlir::Location loc, fir::FirOpBuilder &builder,
328332
Entity entity, mlir::Type resultType);
329333

flang/lib/Optimizer/Builder/HLFIRTools.cpp

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,34 @@ mlir::Value hlfir::genLBound(mlir::Location loc, fir::FirOpBuilder &builder,
676676
return dimInfo.getLowerBound();
677677
}
678678

679+
static bool
680+
getExprLengthParameters(mlir::Value expr,
681+
llvm::SmallVectorImpl<mlir::Value> &result) {
682+
if (auto concat = expr.getDefiningOp<hlfir::ConcatOp>()) {
683+
result.push_back(concat.getLength());
684+
return true;
685+
}
686+
if (auto setLen = expr.getDefiningOp<hlfir::SetLengthOp>()) {
687+
result.push_back(setLen.getLength());
688+
return true;
689+
}
690+
if (auto elemental = expr.getDefiningOp<hlfir::ElementalOp>()) {
691+
result.append(elemental.getTypeparams().begin(),
692+
elemental.getTypeparams().end());
693+
return true;
694+
}
695+
if (auto evalInMem = expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
696+
result.append(evalInMem.getTypeparams().begin(),
697+
evalInMem.getTypeparams().end());
698+
return true;
699+
}
700+
if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
701+
result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
702+
return true;
703+
}
704+
return false;
705+
}
706+
679707
void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
680708
Entity entity,
681709
llvm::SmallVectorImpl<mlir::Value> &result) {
@@ -688,29 +716,14 @@ void hlfir::genLengthParameters(mlir::Location loc, fir::FirOpBuilder &builder,
688716
// Going through fir::ExtendedValue would create a temp,
689717
// which is not desired for an inquiry.
690718
// TODO: make this an interface when adding further character producing ops.
691-
if (auto concat = expr.getDefiningOp<hlfir::ConcatOp>()) {
692-
result.push_back(concat.getLength());
693-
return;
694-
} else if (auto concat = expr.getDefiningOp<hlfir::SetLengthOp>()) {
695-
result.push_back(concat.getLength());
696-
return;
697-
} else if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>()) {
719+
720+
if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>()) {
698721
hlfir::genLengthParameters(loc, builder, hlfir::Entity{asExpr.getVar()},
699722
result);
700723
return;
701-
} else if (auto elemental = expr.getDefiningOp<hlfir::ElementalOp>()) {
702-
result.append(elemental.getTypeparams().begin(),
703-
elemental.getTypeparams().end());
704-
return;
705-
} else if (auto evalInMem =
706-
expr.getDefiningOp<hlfir::EvaluateInMemoryOp>()) {
707-
result.append(evalInMem.getTypeparams().begin(),
708-
evalInMem.getTypeparams().end());
709-
return;
710-
} else if (auto apply = expr.getDefiningOp<hlfir::ApplyOp>()) {
711-
result.append(apply.getTypeparams().begin(), apply.getTypeparams().end());
712-
return;
713724
}
725+
if (getExprLengthParameters(expr, result))
726+
return;
714727
if (entity.isCharacter()) {
715728
result.push_back(hlfir::GetLengthOp::create(builder, loc, expr));
716729
return;
@@ -733,6 +746,36 @@ mlir::Value hlfir::genCharLength(mlir::Location loc, fir::FirOpBuilder &builder,
733746
return lenParams[0];
734747
}
735748

749+
std::optional<std::int64_t> hlfir::getCharLengthIfConst(hlfir::Entity entity) {
750+
if (!entity.isCharacter()) {
751+
return std::nullopt;
752+
}
753+
if (mlir::isa<hlfir::ExprType>(entity.getType())) {
754+
mlir::Value expr = entity;
755+
if (auto reassoc = expr.getDefiningOp<hlfir::NoReassocOp>())
756+
expr = reassoc.getVal();
757+
758+
if (auto asExpr = expr.getDefiningOp<hlfir::AsExprOp>())
759+
return getCharLengthIfConst(hlfir::Entity{asExpr.getVar()});
760+
761+
llvm::SmallVector<mlir::Value> param;
762+
if (getExprLengthParameters(expr, param)) {
763+
assert(param.size() == 1 && "characters must have one length parameters");
764+
return fir::getIntIfConstant(param.pop_back_val());
765+
}
766+
return std::nullopt;
767+
}
768+
769+
// entity is a var
770+
if (mlir::Value len = tryGettingNonDeferredCharLen(entity))
771+
return fir::getIntIfConstant(len);
772+
auto charType =
773+
mlir::cast<fir::CharacterType>(entity.getFortranElementType());
774+
if (charType.hasConstantLen())
775+
return charType.getLen();
776+
return std::nullopt;
777+
}
778+
736779
mlir::Value hlfir::genRank(mlir::Location loc, fir::FirOpBuilder &builder,
737780
hlfir::Entity entity, mlir::Type resultType) {
738781
if (!entity.isAssumedRank())

flang/lib/Optimizer/HLFIR/Transforms/SimplifyHLFIRIntrinsics.cpp

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2284,6 +2284,212 @@ class CmpCharOpConversion : public mlir::OpRewritePattern<hlfir::CmpCharOp> {
22842284
}
22852285
};
22862286

2287+
static std::pair<mlir::Value, hlfir::AssociateOp>
2288+
getVariable(fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value val) {
2289+
// If it is an expression - create a variable from it, or forward
2290+
// the value otherwise.
2291+
hlfir::AssociateOp associate;
2292+
if (!mlir::isa<hlfir::ExprType>(val.getType()))
2293+
return {val, associate};
2294+
hlfir::Entity entity{val};
2295+
mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr(builder);
2296+
associate = hlfir::genAssociateExpr(loc, builder, entity, entity.getType(),
2297+
"", byRefAttr);
2298+
return {associate.getBase(), associate};
2299+
}
2300+
2301+
class IndexOpConversion : public mlir::OpRewritePattern<hlfir::IndexOp> {
2302+
public:
2303+
using mlir::OpRewritePattern<hlfir::IndexOp>::OpRewritePattern;
2304+
2305+
llvm::LogicalResult
2306+
matchAndRewrite(hlfir::IndexOp op,
2307+
mlir::PatternRewriter &rewriter) const override {
2308+
// We simplify only limited cases:
2309+
// 1) a substring length shall be known at compile time
2310+
// 2) if a substring length is 0 then replace with 1 for forward search,
2311+
// or otherwise with the string length + 1 (builder shall const-fold if
2312+
// lookup direction is known at compile time).
2313+
// 3) for known string length at compile time, if it is
2314+
// shorter than substring => replace with zero.
2315+
// 4) if a substring length is one => inline as simple search loop
2316+
// 5) for forward search with input strings of kind=1 runtime is faster.
2317+
// Do not simplify in all the other cases relying on a runtime call.
2318+
2319+
fir::FirOpBuilder builder{rewriter, op.getOperation()};
2320+
const mlir::Location &loc = op->getLoc();
2321+
2322+
auto resultTy = op.getType();
2323+
mlir::Value back = op.getBack();
2324+
auto substrLenCst =
2325+
hlfir::getCharLengthIfConst(hlfir::Entity{op.getSubstr()});
2326+
if (!substrLenCst) {
2327+
return rewriter.notifyMatchFailure(
2328+
op, "substring length unknown at compile time");
2329+
}
2330+
hlfir::Entity strEntity{op.getStr()};
2331+
auto i1Ty = builder.getI1Type();
2332+
auto idxTy = builder.getIndexType();
2333+
if (*substrLenCst == 0) {
2334+
mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
2335+
// zero length substring. For back search replace with
2336+
// strLen+1, or otherwise with 1.
2337+
mlir::Value strLen = hlfir::genCharLength(loc, builder, strEntity);
2338+
mlir::Value strEnd = mlir::arith::AddIOp::create(
2339+
builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx);
2340+
if (back)
2341+
back = builder.createConvert(loc, i1Ty, back);
2342+
else
2343+
back = builder.createIntegerConstant(loc, i1Ty, 0);
2344+
mlir::Value result =
2345+
mlir::arith::SelectOp::create(builder, loc, back, strEnd, oneIdx);
2346+
2347+
rewriter.replaceOp(op, builder.createConvert(loc, resultTy, result));
2348+
return mlir::success();
2349+
}
2350+
2351+
if (auto strLenCst = hlfir::getCharLengthIfConst(strEntity)) {
2352+
if (*strLenCst < *substrLenCst) {
2353+
rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 0));
2354+
return mlir::success();
2355+
}
2356+
if (*strLenCst == 0) {
2357+
// both strings have zero length
2358+
rewriter.replaceOp(op, builder.createIntegerConstant(loc, resultTy, 1));
2359+
return mlir::success();
2360+
}
2361+
}
2362+
if (*substrLenCst != 1) {
2363+
return rewriter.notifyMatchFailure(
2364+
op, "rely on runtime implementation if substring length > 1");
2365+
}
2366+
// For forward search and character kind=1 the runtime uses memchr
2367+
// which well optimized. But it looks like memchr idiom is not recognized
2368+
// in LLVM yet. On a micro-kernel test with strings of length 40 runtime
2369+
// had ~2x less execution time vs inlined code. For unknown search direction
2370+
// at compile time pessimistically assume "forward".
2371+
std::optional<bool> isBack;
2372+
if (back) {
2373+
if (auto backCst = fir::getIntIfConstant(back))
2374+
isBack = *backCst != 0;
2375+
} else {
2376+
isBack = false;
2377+
}
2378+
auto charTy = mlir::cast<fir::CharacterType>(
2379+
hlfir::getFortranElementType(op.getSubstr().getType()));
2380+
unsigned kind = charTy.getFKind();
2381+
if (kind == 1 && (!isBack || !*isBack)) {
2382+
return rewriter.notifyMatchFailure(
2383+
op, "rely on runtime implementation for character kind 1");
2384+
}
2385+
2386+
// All checks are passed here. Generate single character search loop.
2387+
auto [strV, strAssociate] = getVariable(builder, loc, op.getStr());
2388+
auto [substrV, substrAssociate] = getVariable(builder, loc, op.getSubstr());
2389+
hlfir::Entity str{strV};
2390+
hlfir::Entity substr{substrV};
2391+
mlir::Value oneIdx = builder.createIntegerConstant(loc, idxTy, 1);
2392+
2393+
auto genExtractAndConvertToInt = [&charTy, &idxTy, &oneIdx,
2394+
kind](mlir::Location loc,
2395+
fir::FirOpBuilder &builder,
2396+
hlfir::Entity &charStr,
2397+
mlir::Value index) {
2398+
auto bits = builder.getKindMap().getCharacterBitsize(kind);
2399+
auto intTy = builder.getIntegerType(bits);
2400+
auto charLen1Ty =
2401+
fir::CharacterType::getSingleton(builder.getContext(), kind);
2402+
mlir::Type designatorTy =
2403+
fir::ReferenceType::get(charLen1Ty, fir::isa_volatile_type(charTy));
2404+
auto idxAttr = builder.getIntegerAttr(idxTy, 0);
2405+
2406+
auto singleChr = hlfir::DesignateOp::create(
2407+
builder, loc, designatorTy, charStr, /*component=*/{},
2408+
/*compShape=*/mlir::Value{}, hlfir::DesignateOp::Subscripts{},
2409+
/*substring=*/mlir::ValueRange{index, index},
2410+
/*complexPart=*/std::nullopt,
2411+
/*shape=*/mlir::Value{}, /*typeParams=*/mlir::ValueRange{oneIdx},
2412+
fir::FortranVariableFlagsAttr{});
2413+
auto chrVal = fir::LoadOp::create(builder, loc, singleChr);
2414+
mlir::Value intVal = fir::ExtractValueOp::create(
2415+
builder, loc, intTy, chrVal, builder.getArrayAttr(idxAttr));
2416+
return intVal;
2417+
};
2418+
2419+
auto wantChar = genExtractAndConvertToInt(loc, builder, substr, oneIdx);
2420+
2421+
// Generate search loop body with the following C equivalent:
2422+
// idx_t result = 0;
2423+
// idx_t end = strlen + 1;
2424+
// char want = substr[0];
2425+
// for (idx_t idx = 1; idx < end; ++idx) {
2426+
// if (result == 0) {
2427+
// idx_t at = back ? end - idx: idx;
2428+
// result = str[at-1] == want ? at : result;
2429+
// }
2430+
// }
2431+
mlir::Value strLen = hlfir::genCharLength(loc, builder, strEntity);
2432+
if (!back)
2433+
back = builder.createIntegerConstant(loc, i1Ty, 0);
2434+
else
2435+
back = builder.createConvert(loc, i1Ty, back);
2436+
mlir::Value strEnd = mlir::arith::AddIOp::create(
2437+
builder, loc, builder.createConvert(loc, idxTy, strLen), oneIdx);
2438+
mlir::Value zeroIdx = builder.createIntegerConstant(loc, idxTy, 0);
2439+
auto genSearchBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
2440+
mlir::ValueRange index,
2441+
mlir::ValueRange reductionArgs)
2442+
-> llvm::SmallVector<mlir::Value, 1> {
2443+
assert(index.size() == 1 && "expected single loop");
2444+
assert(reductionArgs.size() == 1 && "expected single reduction value");
2445+
mlir::Value inRes = reductionArgs[0];
2446+
auto resEQzero = mlir::arith::CmpIOp::create(
2447+
builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroIdx);
2448+
2449+
mlir::Value res =
2450+
builder
2451+
.genIfOp(loc, {idxTy}, resEQzero,
2452+
/*withElseRegion=*/true)
2453+
.genThen([&]() {
2454+
mlir::Value idx = builder.createConvert(loc, idxTy, index[0]);
2455+
// offset = back ? end - idx : idx;
2456+
mlir::Value offset = mlir::arith::SelectOp::create(
2457+
builder, loc, back,
2458+
mlir::arith::SubIOp::create(builder, loc, strEnd, idx),
2459+
idx);
2460+
2461+
auto haveChar =
2462+
genExtractAndConvertToInt(loc, builder, str, offset);
2463+
auto charsEQ = mlir::arith::CmpIOp::create(
2464+
builder, loc, mlir::arith::CmpIPredicate::eq, haveChar,
2465+
wantChar);
2466+
mlir::Value newVal = mlir::arith::SelectOp::create(
2467+
builder, loc, charsEQ, offset, inRes);
2468+
2469+
fir::ResultOp::create(builder, loc, newVal);
2470+
})
2471+
.genElse([&]() { fir::ResultOp::create(builder, loc, inRes); })
2472+
.getResults()[0];
2473+
return {res};
2474+
};
2475+
2476+
llvm::SmallVector<mlir::Value, 1> loopOut =
2477+
hlfir::genLoopNestWithReductions(loc, builder, {strLen},
2478+
/*reductionInits=*/{zeroIdx},
2479+
genSearchBody,
2480+
/*isUnordered=*/false);
2481+
mlir::Value result = builder.createConvert(loc, resultTy, loopOut[0]);
2482+
2483+
if (strAssociate)
2484+
hlfir::EndAssociateOp::create(builder, loc, strAssociate);
2485+
if (substrAssociate)
2486+
hlfir::EndAssociateOp::create(builder, loc, substrAssociate);
2487+
2488+
rewriter.replaceOp(op, result);
2489+
return mlir::success();
2490+
}
2491+
};
2492+
22872493
template <typename Op>
22882494
class MatmulConversion : public mlir::OpRewritePattern<Op> {
22892495
public:
@@ -2955,6 +3161,7 @@ class SimplifyHLFIRIntrinsics
29553161
patterns.insert<ArrayShiftConversion<hlfir::CShiftOp>>(context);
29563162
patterns.insert<ArrayShiftConversion<hlfir::EOShiftOp>>(context);
29573163
patterns.insert<CmpCharOpConversion>(context);
3164+
patterns.insert<IndexOpConversion>(context);
29583165
patterns.insert<MatmulConversion<hlfir::MatmulTransposeOp>>(context);
29593166
patterns.insert<ReductionConversion<hlfir::CountOp>>(context);
29603167
patterns.insert<ReductionConversion<hlfir::AnyOp>>(context);

0 commit comments

Comments
 (0)