@@ -318,9 +318,12 @@ void ConditionOp::getSuccessorRegions(
318318
319319void ForOp::build (OpBuilder &builder, OperationState &result, Value lb,
320320 Value ub, Value step, ValueRange initArgs,
321- BodyBuilderFn bodyBuilder) {
321+ BodyBuilderFn bodyBuilder, bool unsignedCmp ) {
322322 OpBuilder::InsertionGuard guard (builder);
323323
324+ if (unsignedCmp)
325+ result.addAttribute (getUnsignedCmpAttrName (result.name ),
326+ builder.getUnitAttr ());
324327 result.addOperands ({lb, ub, step});
325328 result.addOperands (initArgs);
326329 for (Value v : initArgs)
@@ -450,6 +453,9 @@ static void printInitializationList(OpAsmPrinter &p,
450453}
451454
452455void ForOp::print (OpAsmPrinter &p) {
456+ if (getUnsignedCmp ())
457+ p << " unsigned" ;
458+
453459 p << " " << getInductionVar () << " = " << getLowerBound () << " to "
454460 << getUpperBound () << " step " << getStep ();
455461
@@ -462,7 +468,8 @@ void ForOp::print(OpAsmPrinter &p) {
462468 p.printRegion (getRegion (),
463469 /* printEntryBlockArgs=*/ false ,
464470 /* printBlockTerminators=*/ !getInitArgs ().empty ());
465- p.printOptionalAttrDict ((*this )->getAttrs ());
471+ p.printOptionalAttrDict ((*this )->getAttrs (),
472+ /* elidedAttrs=*/ getUnsignedCmpAttrName ().strref ());
466473}
467474
468475ParseResult ForOp::parse (OpAsmParser &parser, OperationState &result) {
@@ -472,6 +479,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
472479 OpAsmParser::Argument inductionVariable;
473480 OpAsmParser::UnresolvedOperand lb, ub, step;
474481
482+ if (succeeded (parser.parseOptionalKeyword (" unsigned" )))
483+ result.addAttribute (getUnsignedCmpAttrName (result.name ),
484+ builder.getUnitAttr ());
485+
475486 // Parse the induction variable followed by '='.
476487 if (parser.parseOperand (inductionVariable.ssaName ) || parser.parseEqual () ||
477488 // Parse loop bounds.
@@ -562,7 +573,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
562573 inits.append (newInitOperands.begin (), newInitOperands.end ());
563574 scf::ForOp newLoop = scf::ForOp::create (
564575 rewriter, getLoc (), getLowerBound (), getUpperBound (), getStep (), inits,
565- [](OpBuilder &, Location, Value, ValueRange) {});
576+ [](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp () );
566577 newLoop->setAttrs (getPrunedAttributeList (getOperation (), {}));
567578
568579 // Generate the new yield values and append them to the scf.yield operation.
@@ -806,7 +817,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
806817 // 2. Create the new forOp shell.
807818 scf::ForOp newForOp = scf::ForOp::create (
808819 rewriter, forOp.getLoc (), forOp.getLowerBound (), forOp.getUpperBound (),
809- forOp.getStep (), newIterOperands);
820+ forOp.getStep (), newIterOperands, /* bodyBuilder=*/ nullptr ,
821+ forOp.getUnsignedCmp ());
810822 newForOp->setAttrs (forOp->getAttrs ());
811823 Block &newBlock = newForOp.getRegion ().front ();
812824 SmallVector<Value, 4 > newBlockTransferArgs (newBlock.getArguments ().begin (),
@@ -931,7 +943,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
931943
932944 scf::ForOp newForOp =
933945 scf::ForOp::create (rewriter, forOp.getLoc (), forOp.getLowerBound (),
934- forOp.getUpperBound (), forOp.getStep (), newIterArgs);
946+ forOp.getUpperBound (), forOp.getStep (), newIterArgs,
947+ /* bodyBuilder=*/ nullptr , forOp.getUnsignedCmp ());
935948 newForOp->setAttrs (forOp->getAttrs ());
936949 Block &newBlock = newForOp.getRegion ().front ();
937950
@@ -989,12 +1002,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
9891002// / Util function that tries to compute a constant diff between u and l.
9901003// / Returns std::nullopt when the difference between two AffineValueMap is
9911004// / dynamic.
992- static std::optional<int64_t > computeConstDiff (Value l, Value u) {
1005+ static std::optional<APInt > computeConstDiff (Value l, Value u) {
9931006 IntegerAttr clb, cub;
9941007 if (matchPattern (l, m_Constant (&clb)) && matchPattern (u, m_Constant (&cub))) {
9951008 llvm::APInt lbValue = clb.getValue ();
9961009 llvm::APInt ubValue = cub.getValue ();
997- return ( ubValue - lbValue). getSExtValue () ;
1010+ return ubValue - lbValue;
9981011 }
9991012
10001013 // Else a simple pattern match for x + c or c + x
@@ -1003,7 +1016,7 @@ static std::optional<int64_t> computeConstDiff(Value l, Value u) {
10031016 u, m_Op<arith::AddIOp>(matchers::m_Val (l), m_ConstantInt (&diff))) ||
10041017 matchPattern (
10051018 u, m_Op<arith::AddIOp>(m_ConstantInt (&diff), matchers::m_Val (l))))
1006- return diff. getSExtValue () ;
1019+ return diff;
10071020 return std::nullopt ;
10081021}
10091022
@@ -1022,13 +1035,15 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
10221035 return success ();
10231036 }
10241037
1025- std::optional<int64_t > diff =
1038+ std::optional<APInt > diff =
10261039 computeConstDiff (op.getLowerBound (), op.getUpperBound ());
10271040 if (!diff)
10281041 return failure ();
10291042
10301043 // If the loop is known to have 0 iterations, remove it.
1031- if (*diff <= 0 ) {
1044+ bool zeroOrLessIterations =
1045+ diff->isZero () || (!op.getUnsignedCmp () && diff->isNegative ());
1046+ if (zeroOrLessIterations) {
10321047 rewriter.replaceOp (op, op.getInitArgs ());
10331048 return success ();
10341049 }
@@ -3384,9 +3399,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
33843399
33853400 if (functionType.getNumInputs () != operands.size ()) {
33863401 return parser.emitError (typeLoc)
3387- << " expected as many input types as operands "
3388- << " (expected " << operands.size () << " got "
3389- << functionType.getNumInputs () << " )" ;
3402+ << " expected as many input types as operands " << " (expected "
3403+ << operands.size () << " got " << functionType.getNumInputs () << " )" ;
33903404 }
33913405
33923406 // Resolve input operands.
0 commit comments