@@ -2284,6 +2284,214 @@ class CmpCharOpConversion : public mlir::OpRewritePattern<hlfir::CmpCharOp> {
22842284 }
22852285};
22862286
2287+ static std::pair<mlir::Value, hlfir::AssociateOp>
2288+ getVariable (fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value val) {
2289+ // If it is an expression - create a variable from it, or forward
2290+ // the value otherwise.
2291+ hlfir::AssociateOp associate;
2292+ if (!mlir::isa<hlfir::ExprType>(val.getType ()))
2293+ return {val, associate};
2294+ hlfir::Entity entity{val};
2295+ mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr (builder);
2296+ associate = hlfir::genAssociateExpr (loc, builder, entity, entity.getType (),
2297+ " " , byRefAttr);
2298+ return {associate.getBase (), associate};
2299+ }
2300+
2301+ class IndexOpConversion : public mlir ::OpRewritePattern<hlfir::IndexOp> {
2302+ public:
2303+ using mlir::OpRewritePattern<hlfir::IndexOp>::OpRewritePattern;
2304+
2305+ llvm::LogicalResult
2306+ matchAndRewrite (hlfir::IndexOp op,
2307+ mlir::PatternRewriter &rewriter) const override {
2308+ // We simplify only limited cases:
2309+ // 1) a substring length shall be known at compile time
2310+ // 2) if a substring length is 0 then replace with 1 for forward search,
2311+ // or otherwise with the string length + 1 (builder shall const-fold if
2312+ // lookup direction is known at compile time).
2313+ // 3) for known string length at compile time, if it is
2314+ // shorter than substring => replace with zero.
2315+ // 4) if a substring length is one => inline as simple search loop
2316+ // 5) for forward search with input strings of kind=1 runtime is faster.
2317+ // Do not simplify in all the other cases relying on a runtime call.
2318+
2319+ fir::FirOpBuilder builder{rewriter, op.getOperation ()};
2320+ const mlir::Location &loc = op->getLoc ();
2321+
2322+ auto resultTy = op.getType ();
2323+ mlir::Value back = op.getBack ();
2324+ mlir::Value substrLen =
2325+ hlfir::genCharLength (loc, builder, hlfir::Entity{op.getSubstr ()});
2326+
2327+ auto substrLenCst = fir::getIntIfConstant (substrLen);
2328+ if (!substrLenCst) {
2329+ return rewriter.notifyMatchFailure (
2330+ op, " substring length unknown at compile time" );
2331+ }
2332+ mlir::Value strLen =
2333+ hlfir::genCharLength (loc, builder, hlfir::Entity{op.getStr ()});
2334+ auto i1Ty = builder.getI1Type ();
2335+ auto idxTy = builder.getIndexType ();
2336+ if (*substrLenCst == 0 ) {
2337+ mlir::Value oneIdx = builder.createIntegerConstant (loc, idxTy, 1 );
2338+ // zero length substring. For back search replace with
2339+ // strLen+1, or otherwise with 1.
2340+ mlir::Value strEnd = mlir::arith::AddIOp::create (
2341+ builder, loc, builder.createConvert (loc, idxTy, strLen), oneIdx);
2342+ if (back)
2343+ back = builder.createConvert (loc, i1Ty, back);
2344+ else
2345+ back = builder.createIntegerConstant (loc, i1Ty, 0 );
2346+ mlir::Value result =
2347+ mlir::arith::SelectOp::create (builder, loc, back, strEnd, oneIdx);
2348+
2349+ rewriter.replaceOp (op, builder.createConvert (loc, resultTy, result));
2350+ return mlir::success ();
2351+ }
2352+
2353+ if (auto strLenCst = fir::getIntIfConstant (strLen)) {
2354+ if (*strLenCst < *substrLenCst) {
2355+ rewriter.replaceOp (op, builder.createIntegerConstant (loc, resultTy, 0 ));
2356+ return mlir::success ();
2357+ }
2358+ if (*strLenCst == 0 ) {
2359+ // both strings have zero length
2360+ rewriter.replaceOp (op, builder.createIntegerConstant (loc, resultTy, 1 ));
2361+ return mlir::success ();
2362+ }
2363+ }
2364+ if (*substrLenCst != 1 ) {
2365+ return rewriter.notifyMatchFailure (
2366+ op, " rely on runtime implementation if substring length > 1" );
2367+ }
2368+ // For forward search and character kind=1 the runtime uses memchr
2369+ // which well optimized. But it looks like memchr idiom is not recognized
2370+ // in LLVM yet. On a micro-kernel test with strings of length 40 runtime
2371+ // had ~2x less execution time vs inlined code. For unknown search direction
2372+ // at compile time pessimistically assume "forward".
2373+ std::optional<bool > isBack;
2374+ if (back) {
2375+ if (auto backCst = fir::getIntIfConstant (back))
2376+ isBack = *backCst != 0 ;
2377+ } else {
2378+ isBack = false ;
2379+ }
2380+ auto charTy = mlir::cast<fir::CharacterType>(
2381+ hlfir::getFortranElementType (op.getSubstr ().getType ()));
2382+ unsigned kind = charTy.getFKind ();
2383+ if (kind == 1 && (!isBack || !*isBack)) {
2384+ return rewriter.notifyMatchFailure (
2385+ op, " rely on runtime implementation for character kind 1" );
2386+ }
2387+
2388+ // All checks are passed here. Generate single character search loop.
2389+ auto [strV, strAssociate] = getVariable (builder, loc, op.getStr ());
2390+ auto [substrV, substrAssociate] =
2391+ getVariable (builder, loc, op.getSubstr ());
2392+ hlfir::Entity str{strV};
2393+ hlfir::Entity substr{substrV};
2394+ mlir::Value oneIdx = builder.createIntegerConstant (loc, idxTy, 1 );
2395+
2396+ auto genExtractAndConvertToInt = [&charTy, &idxTy, &oneIdx,
2397+ kind](mlir::Location loc,
2398+ fir::FirOpBuilder &builder,
2399+ hlfir::Entity &charStr,
2400+ mlir::Value index) {
2401+ auto bits = builder.getKindMap ().getCharacterBitsize (kind);
2402+ auto intTy = builder.getIntegerType (bits);
2403+ auto charLen1Ty =
2404+ fir::CharacterType::getSingleton (builder.getContext (), kind);
2405+ mlir::Type designatorTy =
2406+ fir::ReferenceType::get (charLen1Ty, fir::isa_volatile_type (charTy));
2407+ auto idxAttr = builder.getIntegerAttr (idxTy, 0 );
2408+
2409+ auto singleChr = hlfir::DesignateOp::create (
2410+ builder, loc, designatorTy, charStr, /* component=*/ {},
2411+ /* compShape=*/ mlir::Value{}, hlfir::DesignateOp::Subscripts{},
2412+ /* substring=*/ mlir::ValueRange{index, index},
2413+ /* complexPart=*/ std::nullopt ,
2414+ /* shape=*/ mlir::Value{}, /* typeParams=*/ mlir::ValueRange{oneIdx},
2415+ fir::FortranVariableFlagsAttr{});
2416+ auto chrVal = fir::LoadOp::create (builder, loc, singleChr);
2417+ mlir::Value intVal = fir::ExtractValueOp::create (
2418+ builder, loc, intTy, chrVal, builder.getArrayAttr (idxAttr));
2419+ return intVal;
2420+ };
2421+
2422+ auto wantChar = genExtractAndConvertToInt (loc, builder, substr, oneIdx);
2423+
2424+ // Generate search loop body with the following C equivalent:
2425+ // idx_t result = 0;
2426+ // idx_t end = strlen + 1;
2427+ // char want = substr[0];
2428+ // for (idx_t idx = 1; idx < end; ++idx) {
2429+ // if (result == 0) {
2430+ // idx_t at = back ? end - idx: idx;
2431+ // result = str[at-1] == want ? at : result;
2432+ // }
2433+ // }
2434+ if (!back)
2435+ back = builder.createIntegerConstant (loc, i1Ty, 0 );
2436+ else
2437+ back = builder.createConvert (loc, i1Ty, back);
2438+ mlir::Value strEnd = mlir::arith::AddIOp::create (
2439+ builder, loc, builder.createConvert (loc, idxTy, strLen), oneIdx);
2440+ mlir::Value zeroIdx = builder.createIntegerConstant (loc, idxTy, 0 );
2441+ auto genSearchBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
2442+ mlir::ValueRange index,
2443+ mlir::ValueRange reductionArgs)
2444+ -> llvm::SmallVector<mlir::Value, 1 > {
2445+ assert (index.size () == 1 && " expected single loop" );
2446+ assert (reductionArgs.size () == 1 && " expected single reduction value" );
2447+ mlir::Value inRes = reductionArgs[0 ];
2448+ auto resEQzero = mlir::arith::CmpIOp::create (
2449+ builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroIdx);
2450+
2451+ mlir::Value res =
2452+ builder
2453+ .genIfOp (loc, {idxTy}, resEQzero,
2454+ /* withElseRegion=*/ true )
2455+ .genThen ([&]() {
2456+ mlir::Value idx = builder.createConvert (loc, idxTy, index[0 ]);
2457+ // offset = back ? end - idx : idx;
2458+ mlir::Value offset = mlir::arith::SelectOp::create (
2459+ builder, loc, back,
2460+ mlir::arith::SubIOp::create (builder, loc, strEnd, idx),
2461+ idx);
2462+
2463+ auto haveChar =
2464+ genExtractAndConvertToInt (loc, builder, str, offset);
2465+ auto charsEQ = mlir::arith::CmpIOp::create (
2466+ builder, loc, mlir::arith::CmpIPredicate::eq, haveChar,
2467+ wantChar);
2468+ mlir::Value newVal = mlir::arith::SelectOp::create (
2469+ builder, loc, charsEQ, offset, inRes);
2470+
2471+ fir::ResultOp::create (builder, loc, newVal);
2472+ })
2473+ .genElse ([&]() { fir::ResultOp::create (builder, loc, inRes); })
2474+ .getResults ()[0 ];
2475+ return {res};
2476+ };
2477+
2478+ llvm::SmallVector<mlir::Value, 1 > loopOut =
2479+ hlfir::genLoopNestWithReductions (loc, builder, {strLen},
2480+ /* reductionInits=*/ {zeroIdx},
2481+ genSearchBody,
2482+ /* isUnordered=*/ false );
2483+ mlir::Value result = builder.createConvert (loc, resultTy, loopOut[0 ]);
2484+
2485+ if (strAssociate)
2486+ hlfir::EndAssociateOp::create (builder, loc, strAssociate);
2487+ if (substrAssociate)
2488+ hlfir::EndAssociateOp::create (builder, loc, substrAssociate);
2489+
2490+ rewriter.replaceOp (op, result);
2491+ return mlir::success ();
2492+ }
2493+ };
2494+
22872495template <typename Op>
22882496class MatmulConversion : public mlir ::OpRewritePattern<Op> {
22892497public:
@@ -2955,6 +3163,7 @@ class SimplifyHLFIRIntrinsics
29553163 patterns.insert <ArrayShiftConversion<hlfir::CShiftOp>>(context);
29563164 patterns.insert <ArrayShiftConversion<hlfir::EOShiftOp>>(context);
29573165 patterns.insert <CmpCharOpConversion>(context);
3166+ patterns.insert <IndexOpConversion>(context);
29583167 patterns.insert <MatmulConversion<hlfir::MatmulTransposeOp>>(context);
29593168 patterns.insert <ReductionConversion<hlfir::CountOp>>(context);
29603169 patterns.insert <ReductionConversion<hlfir::AnyOp>>(context);
0 commit comments