1818#include " mlir/Dialect/SCF/IR/SCF.h"
1919#include " mlir/IR/BuiltinOps.h"
2020#include " mlir/IR/IRMapping.h"
21+ #include " mlir/IR/OpDefinition.h"
2122#include " mlir/IR/PatternMatch.h"
2223#include " mlir/Interfaces/SideEffectInterfaces.h"
2324#include " mlir/Transforms/RegionUtils.h"
2930
3031using namespace mlir ;
3132
32- namespace {
33- // This structure is to pass and return sets of loop parameters without
34- // confusing the order.
35- struct LoopParams {
36- Value lowerBound;
37- Value upperBound;
38- Value step;
39- };
40- } // namespace
41-
4233SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields (
4334 RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
4435 ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
@@ -473,17 +464,9 @@ LogicalResult mlir::loopUnrollByFactor(
473464 return success ();
474465}
475466
476- // / Transform a loop with a strictly positive step
477- // / for %i = %lb to %ub step %s
478- // / into a 0-based loop with step 1
479- // / for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
480- // / %i = %ii * %s + %lb
481- // / Insert the induction variable remapping in the body of `inner`, which is
482- // / expected to be either `loop` or another loop perfectly nested under `loop`.
483- // / Insert the definition of new bounds immediate before `outer`, which is
484- // / expected to be either `loop` or its parent in the loop nest.
485- static LoopParams emitNormalizedLoopBounds (RewriterBase &rewriter, Location loc,
486- Value lb, Value ub, Value step) {
467+ LoopParams mlir::emitNormalizedLoopBounds (RewriterBase &rewriter, Location loc,
468+ OpFoldResult lb, OpFoldResult ub,
469+ OpFoldResult step) {
487470 // For non-index types, generate `arith` instructions
488471 // Check if the loop is already known to have a constant zero lower bound or
489472 // a constant one step.
@@ -495,45 +478,54 @@ static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
495478 if (auto stepCst = getConstantIntValue (step))
496479 isStepOne = stepCst.value () == 1 ;
497480
481+ Type loopParamsType = getType (lb);
482+ assert (loopParamsType == getType (ub) && loopParamsType == getType (step) &&
483+ " expected matching types" );
484+
498485 // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
499486 // assuming the step is strictly positive. Update the bounds and the step
500487 // of the loop to go from 0 to the number of iterations, if necessary.
501488 if (isZeroBased && isStepOne)
502489 return {lb, ub, step};
503490
504- Value diff = isZeroBased ? ub : rewriter.create <arith::SubIOp>(loc, ub, lb);
505- Value newUpperBound =
506- isStepOne ? diff : rewriter.create <arith::CeilDivSIOp>(loc, diff, step);
491+ OpFoldResult diff = ub;
492+ if (!isZeroBased) {
493+ diff = rewriter.createOrFold <arith::SubIOp>(
494+ loc, getValueOrCreateConstantIntOp (rewriter, loc, ub),
495+ getValueOrCreateConstantIntOp (rewriter, loc, lb));
496+ }
497+ OpFoldResult newUpperBound = diff;
498+ if (!isStepOne) {
499+ newUpperBound = rewriter.createOrFold <arith::CeilDivSIOp>(
500+ loc, getValueOrCreateConstantIntOp (rewriter, loc, diff),
501+ getValueOrCreateConstantIntOp (rewriter, loc, step));
502+ }
507503
508- Value newLowerBound = isZeroBased
509- ? lb
510- : rewriter.create <arith::ConstantOp>(
511- loc, rewriter.getZeroAttr (lb.getType ()));
512- Value newStep = isStepOne
513- ? step
514- : rewriter.create <arith::ConstantOp>(
515- loc, rewriter.getIntegerAttr (step.getType (), 1 ));
504+ OpFoldResult newLowerBound = rewriter.getZeroAttr (loopParamsType);
505+ OpFoldResult newStep = rewriter.getOneAttr (loopParamsType);
516506
517507 return {newLowerBound, newUpperBound, newStep};
518508}
519509
520- // / Get back the original induction variable values after loop normalization
521- static void denormalizeInductionVariable (RewriterBase &rewriter, Location loc,
522- Value normalizedIv, Value origLb,
523- Value origStep) {
510+ void mlir::denormalizeInductionVariable (RewriterBase &rewriter, Location loc,
511+ Value normalizedIv, OpFoldResult origLb,
512+ OpFoldResult origStep) {
524513 Value denormalizedIv;
525514 SmallPtrSet<Operation *, 2 > preserve;
526515 bool isStepOne = isConstantIntValue (origStep, 1 );
527516 bool isZeroBased = isConstantIntValue (origLb, 0 );
528517
529518 Value scaled = normalizedIv;
530519 if (!isStepOne) {
531- scaled = rewriter.create <arith::MulIOp>(loc, normalizedIv, origStep);
520+ Value origStepValue =
521+ getValueOrCreateConstantIntOp (rewriter, loc, origStep);
522+ scaled = rewriter.create <arith::MulIOp>(loc, normalizedIv, origStepValue);
532523 preserve.insert (scaled.getDefiningOp ());
533524 }
534525 denormalizedIv = scaled;
535526 if (!isZeroBased) {
536- denormalizedIv = rewriter.create <arith::AddIOp>(loc, scaled, origLb);
527+ Value origLbValue = getValueOrCreateConstantIntOp (rewriter, loc, origLb);
528+ denormalizedIv = rewriter.create <arith::AddIOp>(loc, scaled, origLbValue);
537529 preserve.insert (denormalizedIv.getDefiningOp ());
538530 }
539531
@@ -638,9 +630,12 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
638630 emitNormalizedLoopBounds (rewriter, loop.getLoc (), lb, ub, step);
639631
640632 rewriter.modifyOpInPlace (loop, [&]() {
641- loop.setLowerBound (newLoopParams.lowerBound );
642- loop.setUpperBound (newLoopParams.upperBound );
643- loop.setStep (newLoopParams.step );
633+ loop.setLowerBound (getValueOrCreateConstantIntOp (
634+ rewriter, loop.getLoc (), newLoopParams.lowerBound ));
635+ loop.setUpperBound (getValueOrCreateConstantIntOp (
636+ rewriter, loop.getLoc (), newLoopParams.upperBound ));
637+ loop.setStep (getValueOrCreateConstantIntOp (rewriter, loop.getLoc (),
638+ newLoopParams.step ));
644639 });
645640
646641 rewriter.setInsertionPointToStart (innermost.getBody ());
@@ -778,18 +773,16 @@ void mlir::collapseParallelLoops(
778773 llvm::sort (dims);
779774
780775 // Normalize ParallelOp's iteration pattern.
781- SmallVector<Value, 3 > normalizedLowerBounds, normalizedSteps,
782- normalizedUpperBounds;
776+ SmallVector<Value, 3 > normalizedUpperBounds;
783777 for (unsigned i = 0 , e = loops.getNumLoops (); i < e; ++i) {
784778 OpBuilder::InsertionGuard g2 (rewriter);
785779 rewriter.setInsertionPoint (loops);
786780 Value lb = loops.getLowerBound ()[i];
787781 Value ub = loops.getUpperBound ()[i];
788782 Value step = loops.getStep ()[i];
789783 auto newLoopParams = emitNormalizedLoopBounds (rewriter, loc, lb, ub, step);
790- normalizedLowerBounds.push_back (newLoopParams.lowerBound );
791- normalizedUpperBounds.push_back (newLoopParams.upperBound );
792- normalizedSteps.push_back (newLoopParams.step );
784+ normalizedUpperBounds.push_back (getValueOrCreateConstantIntOp (
785+ rewriter, loops.getLoc (), newLoopParams.upperBound ));
793786
794787 rewriter.setInsertionPointToStart (loops.getBody ());
795788 denormalizeInductionVariable (rewriter, loc, loops.getInductionVars ()[i], lb,
0 commit comments