1212
1313#include " mlir/Dialect/SCF/Utils/Utils.h"
1414#include " mlir/Analysis/SliceAnalysis.h"
15+ #include " mlir/Dialect/Affine/IR/AffineOps.h"
1516#include " mlir/Dialect/Arith/IR/Arith.h"
1617#include " mlir/Dialect/Arith/Utils/Utils.h"
1718#include " mlir/Dialect/Func/IR/FuncOps.h"
@@ -671,9 +672,26 @@ LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
671672 return success ();
672673}
673674
675+ Range emitNormalizedLoopBoundsForIndexType (RewriterBase &rewriter, Location loc,
676+ OpFoldResult lb, OpFoldResult ub,
677+ OpFoldResult step) {
678+ Range normalizedLoopBounds;
679+ normalizedLoopBounds.offset = rewriter.getIndexAttr (0 );
680+ normalizedLoopBounds.stride = rewriter.getIndexAttr (1 );
681+ AffineExpr s0, s1, s2;
682+ bindSymbols (rewriter.getContext (), s0, s1, s2);
683+ AffineExpr e = (s1 - s0).ceilDiv (s2);
684+ normalizedLoopBounds.size =
685+ affine::makeComposedFoldedAffineApply (rewriter, loc, e, {lb, ub, step});
686+ return normalizedLoopBounds;
687+ }
688+
674689Range mlir::emitNormalizedLoopBounds (RewriterBase &rewriter, Location loc,
675690 OpFoldResult lb, OpFoldResult ub,
676691 OpFoldResult step) {
692+ if (getType (lb).isIndex ()) {
693+ return emitNormalizedLoopBoundsForIndexType (rewriter, loc, lb, ub, step);
694+ }
677695 // For non-index types, generate `arith` instructions
678696 // Check if the loop is already known to have a constant zero lower bound or
679697 // a constant one step.
@@ -714,9 +732,38 @@ Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
714732 return {newLowerBound, newUpperBound, newStep};
715733}
716734
735+ static void denormalizeInductionVariableForIndexType (RewriterBase &rewriter,
736+ Location loc,
737+ Value normalizedIv,
738+ OpFoldResult origLb,
739+ OpFoldResult origStep) {
740+ AffineExpr d0, s0, s1;
741+ bindSymbols (rewriter.getContext (), s0, s1);
742+ bindDims (rewriter.getContext (), d0);
743+ AffineExpr e = d0 * s1 + s0;
744+ OpFoldResult denormalizedIv = affine::makeComposedFoldedAffineApply (
745+ rewriter, loc, e, ArrayRef<OpFoldResult>{normalizedIv, origLb, origStep});
746+ Value denormalizedIvVal =
747+ getValueOrCreateConstantIndexOp (rewriter, loc, denormalizedIv);
748+ SmallPtrSet<Operation *, 1 > preservedUses;
749+ // If an `affine.apply` operation is generated for denormalization, the use
750+ // of `origLb` in those ops must not be replaced. These arent not generated
751+ // when `origLb == 0` and `origStep == 1`.
752+ if (!isConstantIntValue (origLb, 0 ) || !isConstantIntValue (origStep, 1 )) {
753+ if (Operation *preservedUse = denormalizedIvVal.getDefiningOp ()) {
754+ preservedUses.insert (preservedUse);
755+ }
756+ }
757+ rewriter.replaceAllUsesExcept (normalizedIv, denormalizedIvVal, preservedUses);
758+ }
759+
717760void mlir::denormalizeInductionVariable (RewriterBase &rewriter, Location loc,
718761 Value normalizedIv, OpFoldResult origLb,
719762 OpFoldResult origStep) {
763+ if (getType (origLb).isIndex ()) {
764+ return denormalizeInductionVariableForIndexType (rewriter, loc, normalizedIv,
765+ origLb, origStep);
766+ }
720767 Value denormalizedIv;
721768 SmallPtrSet<Operation *, 2 > preserve;
722769 bool isStepOne = isConstantIntValue (origStep, 1 );
@@ -739,10 +786,29 @@ void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
739786 rewriter.replaceAllUsesExcept (normalizedIv, denormalizedIv, preserve);
740787}
741788
789+ static OpFoldResult getProductOfIndexes (RewriterBase &rewriter, Location loc,
790+ ArrayRef<OpFoldResult> values) {
791+ assert (!values.empty () && " unexecpted empty array" );
792+ AffineExpr s0, s1;
793+ bindSymbols (rewriter.getContext (), s0, s1);
794+ AffineExpr mul = s0 * s1;
795+ OpFoldResult products = rewriter.getIndexAttr (1 );
796+ for (auto v : values) {
797+ products = affine::makeComposedFoldedAffineApply (
798+ rewriter, loc, mul, ArrayRef<OpFoldResult>{products, v});
799+ }
800+ return products;
801+ }
802+
742803// / Helper function to multiply a sequence of values.
743804static Value getProductOfIntsOrIndexes (RewriterBase &rewriter, Location loc,
744805 ArrayRef<Value> values) {
745806 assert (!values.empty () && " unexpected empty list" );
807+ if (getType (values.front ()).isIndex ()) {
808+ SmallVector<OpFoldResult> ofrs = getAsOpFoldResult (values);
809+ OpFoldResult product = getProductOfIndexes (rewriter, loc, ofrs);
810+ return getValueOrCreateConstantIndexOp (rewriter, loc, product);
811+ }
746812 std::optional<Value> productOf;
747813 for (auto v : values) {
748814 auto vOne = getConstantIntValue (v);
@@ -757,7 +823,7 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
757823 if (!productOf) {
758824 productOf = rewriter
759825 .create <arith::ConstantOp>(
760- loc, rewriter.getOneAttr (values.front (). getType ( )))
826+ loc, rewriter.getOneAttr (getType ( values.front ())))
761827 .getResult ();
762828 }
763829 return productOf.value ();
@@ -774,6 +840,16 @@ static Value getProductOfIntsOrIndexes(RewriterBase &rewriter, Location loc,
774840static std::pair<SmallVector<Value>, SmallPtrSet<Operation *, 2 >>
775841delinearizeInductionVariable (RewriterBase &rewriter, Location loc,
776842 Value linearizedIv, ArrayRef<Value> ubs) {
843+
844+ if (linearizedIv.getType ().isIndex ()) {
845+ Operation *delinearizedOp =
846+ rewriter.create <affine::AffineDelinearizeIndexOp>(loc, linearizedIv,
847+ ubs);
848+ auto resultVals = llvm::map_to_vector (
849+ delinearizedOp->getResults (), [](OpResult r) -> Value { return r; });
850+ return {resultVals, SmallPtrSet<Operation *, 2 >{delinearizedOp}};
851+ }
852+
777853 SmallVector<Value> delinearizedIvs (ubs.size ());
778854 SmallPtrSet<Operation *, 2 > preservedUsers;
779855
0 commit comments