@@ -2284,6 +2284,212 @@ class CmpCharOpConversion : public mlir::OpRewritePattern<hlfir::CmpCharOp> {
2284
2284
}
2285
2285
};
2286
2286
2287
+ static std::pair<mlir::Value, hlfir::AssociateOp>
2288
+ getVariable (fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value val) {
2289
+ // If it is an expression - create a variable from it, or forward
2290
+ // the value otherwise.
2291
+ hlfir::AssociateOp associate;
2292
+ if (!mlir::isa<hlfir::ExprType>(val.getType ()))
2293
+ return {val, associate};
2294
+ hlfir::Entity entity{val};
2295
+ mlir::NamedAttribute byRefAttr = fir::getAdaptToByRefAttr (builder);
2296
+ associate = hlfir::genAssociateExpr (loc, builder, entity, entity.getType (),
2297
+ " " , byRefAttr);
2298
+ return {associate.getBase (), associate};
2299
+ }
2300
+
2301
+ class IndexOpConversion : public mlir ::OpRewritePattern<hlfir::IndexOp> {
2302
+ public:
2303
+ using mlir::OpRewritePattern<hlfir::IndexOp>::OpRewritePattern;
2304
+
2305
+ llvm::LogicalResult
2306
+ matchAndRewrite (hlfir::IndexOp op,
2307
+ mlir::PatternRewriter &rewriter) const override {
2308
+ // We simplify only limited cases:
2309
+ // 1) a substring length shall be known at compile time
2310
+ // 2) if a substring length is 0 then replace with 1 for forward search,
2311
+ // or otherwise with the string length + 1 (builder shall const-fold if
2312
+ // lookup direction is known at compile time).
2313
+ // 3) for known string length at compile time, if it is
2314
+ // shorter than substring => replace with zero.
2315
+ // 4) if a substring length is one => inline as simple search loop
2316
+ // 5) for forward search with input strings of kind=1 runtime is faster.
2317
+ // Do not simplify in all the other cases relying on a runtime call.
2318
+
2319
+ fir::FirOpBuilder builder{rewriter, op.getOperation ()};
2320
+ const mlir::Location &loc = op->getLoc ();
2321
+
2322
+ auto resultTy = op.getType ();
2323
+ mlir::Value back = op.getBack ();
2324
+ auto substrLenCst =
2325
+ hlfir::getCharLengthIfConst (hlfir::Entity{op.getSubstr ()});
2326
+ if (!substrLenCst) {
2327
+ return rewriter.notifyMatchFailure (
2328
+ op, " substring length unknown at compile time" );
2329
+ }
2330
+ hlfir::Entity strEntity{op.getStr ()};
2331
+ auto i1Ty = builder.getI1Type ();
2332
+ auto idxTy = builder.getIndexType ();
2333
+ if (*substrLenCst == 0 ) {
2334
+ mlir::Value oneIdx = builder.createIntegerConstant (loc, idxTy, 1 );
2335
+ // zero length substring. For back search replace with
2336
+ // strLen+1, or otherwise with 1.
2337
+ mlir::Value strLen = hlfir::genCharLength (loc, builder, strEntity);
2338
+ mlir::Value strEnd = mlir::arith::AddIOp::create (
2339
+ builder, loc, builder.createConvert (loc, idxTy, strLen), oneIdx);
2340
+ if (back)
2341
+ back = builder.createConvert (loc, i1Ty, back);
2342
+ else
2343
+ back = builder.createIntegerConstant (loc, i1Ty, 0 );
2344
+ mlir::Value result =
2345
+ mlir::arith::SelectOp::create (builder, loc, back, strEnd, oneIdx);
2346
+
2347
+ rewriter.replaceOp (op, builder.createConvert (loc, resultTy, result));
2348
+ return mlir::success ();
2349
+ }
2350
+
2351
+ if (auto strLenCst = hlfir::getCharLengthIfConst (strEntity)) {
2352
+ if (*strLenCst < *substrLenCst) {
2353
+ rewriter.replaceOp (op, builder.createIntegerConstant (loc, resultTy, 0 ));
2354
+ return mlir::success ();
2355
+ }
2356
+ if (*strLenCst == 0 ) {
2357
+ // both strings have zero length
2358
+ rewriter.replaceOp (op, builder.createIntegerConstant (loc, resultTy, 1 ));
2359
+ return mlir::success ();
2360
+ }
2361
+ }
2362
+ if (*substrLenCst != 1 ) {
2363
+ return rewriter.notifyMatchFailure (
2364
+ op, " rely on runtime implementation if substring length > 1" );
2365
+ }
2366
+ // For forward search and character kind=1 the runtime uses memchr
2367
+ // which well optimized. But it looks like memchr idiom is not recognized
2368
+ // in LLVM yet. On a micro-kernel test with strings of length 40 runtime
2369
+ // had ~2x less execution time vs inlined code. For unknown search direction
2370
+ // at compile time pessimistically assume "forward".
2371
+ std::optional<bool > isBack;
2372
+ if (back) {
2373
+ if (auto backCst = fir::getIntIfConstant (back))
2374
+ isBack = *backCst != 0 ;
2375
+ } else {
2376
+ isBack = false ;
2377
+ }
2378
+ auto charTy = mlir::cast<fir::CharacterType>(
2379
+ hlfir::getFortranElementType (op.getSubstr ().getType ()));
2380
+ unsigned kind = charTy.getFKind ();
2381
+ if (kind == 1 && (!isBack || !*isBack)) {
2382
+ return rewriter.notifyMatchFailure (
2383
+ op, " rely on runtime implementation for character kind 1" );
2384
+ }
2385
+
2386
+ // All checks are passed here. Generate single character search loop.
2387
+ auto [strV, strAssociate] = getVariable (builder, loc, op.getStr ());
2388
+ auto [substrV, substrAssociate] = getVariable (builder, loc, op.getSubstr ());
2389
+ hlfir::Entity str{strV};
2390
+ hlfir::Entity substr{substrV};
2391
+ mlir::Value oneIdx = builder.createIntegerConstant (loc, idxTy, 1 );
2392
+
2393
+ auto genExtractAndConvertToInt = [&charTy, &idxTy, &oneIdx,
2394
+ kind](mlir::Location loc,
2395
+ fir::FirOpBuilder &builder,
2396
+ hlfir::Entity &charStr,
2397
+ mlir::Value index) {
2398
+ auto bits = builder.getKindMap ().getCharacterBitsize (kind);
2399
+ auto intTy = builder.getIntegerType (bits);
2400
+ auto charLen1Ty =
2401
+ fir::CharacterType::getSingleton (builder.getContext (), kind);
2402
+ mlir::Type designatorTy =
2403
+ fir::ReferenceType::get (charLen1Ty, fir::isa_volatile_type (charTy));
2404
+ auto idxAttr = builder.getIntegerAttr (idxTy, 0 );
2405
+
2406
+ auto singleChr = hlfir::DesignateOp::create (
2407
+ builder, loc, designatorTy, charStr, /* component=*/ {},
2408
+ /* compShape=*/ mlir::Value{}, hlfir::DesignateOp::Subscripts{},
2409
+ /* substring=*/ mlir::ValueRange{index, index},
2410
+ /* complexPart=*/ std::nullopt ,
2411
+ /* shape=*/ mlir::Value{}, /* typeParams=*/ mlir::ValueRange{oneIdx},
2412
+ fir::FortranVariableFlagsAttr{});
2413
+ auto chrVal = fir::LoadOp::create (builder, loc, singleChr);
2414
+ mlir::Value intVal = fir::ExtractValueOp::create (
2415
+ builder, loc, intTy, chrVal, builder.getArrayAttr (idxAttr));
2416
+ return intVal;
2417
+ };
2418
+
2419
+ auto wantChar = genExtractAndConvertToInt (loc, builder, substr, oneIdx);
2420
+
2421
+ // Generate search loop body with the following C equivalent:
2422
+ // idx_t result = 0;
2423
+ // idx_t end = strlen + 1;
2424
+ // char want = substr[0];
2425
+ // for (idx_t idx = 1; idx < end; ++idx) {
2426
+ // if (result == 0) {
2427
+ // idx_t at = back ? end - idx: idx;
2428
+ // result = str[at-1] == want ? at : result;
2429
+ // }
2430
+ // }
2431
+ mlir::Value strLen = hlfir::genCharLength (loc, builder, strEntity);
2432
+ if (!back)
2433
+ back = builder.createIntegerConstant (loc, i1Ty, 0 );
2434
+ else
2435
+ back = builder.createConvert (loc, i1Ty, back);
2436
+ mlir::Value strEnd = mlir::arith::AddIOp::create (
2437
+ builder, loc, builder.createConvert (loc, idxTy, strLen), oneIdx);
2438
+ mlir::Value zeroIdx = builder.createIntegerConstant (loc, idxTy, 0 );
2439
+ auto genSearchBody = [&](mlir::Location loc, fir::FirOpBuilder &builder,
2440
+ mlir::ValueRange index,
2441
+ mlir::ValueRange reductionArgs)
2442
+ -> llvm::SmallVector<mlir::Value, 1 > {
2443
+ assert (index.size () == 1 && " expected single loop" );
2444
+ assert (reductionArgs.size () == 1 && " expected single reduction value" );
2445
+ mlir::Value inRes = reductionArgs[0 ];
2446
+ auto resEQzero = mlir::arith::CmpIOp::create (
2447
+ builder, loc, mlir::arith::CmpIPredicate::eq, inRes, zeroIdx);
2448
+
2449
+ mlir::Value res =
2450
+ builder
2451
+ .genIfOp (loc, {idxTy}, resEQzero,
2452
+ /* withElseRegion=*/ true )
2453
+ .genThen ([&]() {
2454
+ mlir::Value idx = builder.createConvert (loc, idxTy, index[0 ]);
2455
+ // offset = back ? end - idx : idx;
2456
+ mlir::Value offset = mlir::arith::SelectOp::create (
2457
+ builder, loc, back,
2458
+ mlir::arith::SubIOp::create (builder, loc, strEnd, idx),
2459
+ idx);
2460
+
2461
+ auto haveChar =
2462
+ genExtractAndConvertToInt (loc, builder, str, offset);
2463
+ auto charsEQ = mlir::arith::CmpIOp::create (
2464
+ builder, loc, mlir::arith::CmpIPredicate::eq, haveChar,
2465
+ wantChar);
2466
+ mlir::Value newVal = mlir::arith::SelectOp::create (
2467
+ builder, loc, charsEQ, offset, inRes);
2468
+
2469
+ fir::ResultOp::create (builder, loc, newVal);
2470
+ })
2471
+ .genElse ([&]() { fir::ResultOp::create (builder, loc, inRes); })
2472
+ .getResults ()[0 ];
2473
+ return {res};
2474
+ };
2475
+
2476
+ llvm::SmallVector<mlir::Value, 1 > loopOut =
2477
+ hlfir::genLoopNestWithReductions (loc, builder, {strLen},
2478
+ /* reductionInits=*/ {zeroIdx},
2479
+ genSearchBody,
2480
+ /* isUnordered=*/ false );
2481
+ mlir::Value result = builder.createConvert (loc, resultTy, loopOut[0 ]);
2482
+
2483
+ if (strAssociate)
2484
+ hlfir::EndAssociateOp::create (builder, loc, strAssociate);
2485
+ if (substrAssociate)
2486
+ hlfir::EndAssociateOp::create (builder, loc, substrAssociate);
2487
+
2488
+ rewriter.replaceOp (op, result);
2489
+ return mlir::success ();
2490
+ }
2491
+ };
2492
+
2287
2493
template <typename Op>
2288
2494
class MatmulConversion : public mlir ::OpRewritePattern<Op> {
2289
2495
public:
@@ -2955,6 +3161,7 @@ class SimplifyHLFIRIntrinsics
2955
3161
patterns.insert <ArrayShiftConversion<hlfir::CShiftOp>>(context);
2956
3162
patterns.insert <ArrayShiftConversion<hlfir::EOShiftOp>>(context);
2957
3163
patterns.insert <CmpCharOpConversion>(context);
3164
+ patterns.insert <IndexOpConversion>(context);
2958
3165
patterns.insert <MatmulConversion<hlfir::MatmulTransposeOp>>(context);
2959
3166
patterns.insert <ReductionConversion<hlfir::CountOp>>(context);
2960
3167
patterns.insert <ReductionConversion<hlfir::AnyOp>>(context);
0 commit comments