Skip to content

Commit 21b607a

Browse files
[mlir][SCF] scf.for: Add support for unsigned integer comparison (#153379)
Add a new unit attribute to allow for unsigned integer comparison. Example: ```mlir scf.for unsigned %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 { // body } ``` Discussion: https://discourse.llvm.org/t/scf-should-scf-for-support-unsigned-comparison/84655
1 parent 6bb8f6f commit 21b607a

File tree

19 files changed

+133
-41
lines changed

19 files changed

+133
-41
lines changed

mlir/include/mlir/Dialect/SCF/IR/SCFOps.td

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,13 @@ def ForOp : SCF_Op<"for",
169169
region capturing the loop body. The induction variable is represented as an
170170
argument of this region. This SSA value is a signless integer or index.
171171
The step is a value of same type but required to be positive, the lower and
172-
upper bounds can be also negative or zero. The lower and upper bounds specify
173-
a half-open range: the iteration is executed iff the signed comparison of induction
174-
variable value is less than the upper bound and bigger or equal to the lower bound.
172+
upper bounds can be also negative or zero. The lower and upper bounds
173+
specify a half-open range: the iteration is executed iff the comparison of
174+
induction variable value is less than the upper bound and bigger or equal
175+
to the lower bound.
176+
177+
By default, the integer comparison is signed. If the `unsignedCmp` unit
178+
attribute is specified, the integer comparison is unsigned.
175179

176180
The body region must contain exactly one block that terminates with
177181
`scf.yield`. Calling ForOp::build will create such a region and insert
@@ -184,8 +188,8 @@ def ForOp : SCF_Op<"for",
184188
... // body
185189
}
186190
...
187-
// Integer case.
188-
scf.for %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
191+
// Unsigned integer case.
192+
scf.for unsigned %iv_32 = %lb_32 to %ub_32 step %step_32 : i32 {
189193
... // body
190194
}
191195
```
@@ -258,15 +262,17 @@ def ForOp : SCF_Op<"for",
258262
let arguments = (ins AnySignlessIntegerOrIndex:$lowerBound,
259263
AnySignlessIntegerOrIndex:$upperBound,
260264
AnySignlessIntegerOrIndex:$step,
261-
Variadic<AnyType>:$initArgs);
265+
Variadic<AnyType>:$initArgs,
266+
UnitAttr:$unsignedCmp);
262267
let results = (outs Variadic<AnyType>:$results);
263268
let regions = (region SizedRegion<1>:$region);
264269

265270
let skipDefaultBuilders = 1;
266271
let builders = [OpBuilder<(ins "Value":$lowerBound, "Value":$upperBound,
267272
"Value":$step, CArg<"ValueRange", "{}">:$initArgs,
268273
CArg<"function_ref<void(OpBuilder &, Location, Value, ValueRange)>",
269-
"nullptr">)>];
274+
"nullptr">,
275+
CArg<"bool", "false">:$unsignedCmp)>];
270276

271277
let extraClassDeclaration = [{
272278
using BodyBuilderFn =

mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -382,8 +382,11 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
382382

383383
// With the body block done, we can fill in the condition block.
384384
rewriter.setInsertionPointToEnd(conditionBlock);
385-
auto comparison = arith::CmpIOp::create(
386-
rewriter, loc, arith::CmpIPredicate::slt, iv, upperBound);
385+
arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
386+
? arith::CmpIPredicate::ult
387+
: arith::CmpIPredicate::slt;
388+
auto comparison =
389+
arith::CmpIOp::create(rewriter, loc, predicate, iv, upperBound);
387390

388391
cf::CondBranchOp::create(rewriter, loc, comparison, firstBodyBlock,
389392
ArrayRef<Value>(), endBlock, ArrayRef<Value>());

mlir/lib/Conversion/SCFToEmitC/SCFToEmitC.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,10 @@ ForLowering::matchAndRewrite(ForOp forOp, OpAdaptor adaptor,
154154
ConversionPatternRewriter &rewriter) const {
155155
Location loc = forOp.getLoc();
156156

157+
if (forOp.getUnsignedCmp())
158+
return rewriter.notifyMatchFailure(forOp,
159+
"unsigned loops are not supported");
160+
157161
// Create an emitc::variable op for each result. These variables will be
158162
// assigned to by emitc::assign ops within the loop body.
159163
SmallVector<Value> resultVariables;

mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,14 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
178178
// Generate the rest of the loop header.
179179
rewriter.setInsertionPointToEnd(header);
180180
auto *mergeBlock = loopOp.getMergeBlock();
181-
auto cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
182-
newIndVar, adaptor.getUpperBound());
181+
Value cmpOp;
182+
if (forOp.getUnsignedCmp()) {
183+
cmpOp = spirv::ULessThanOp::create(rewriter, loc, rewriter.getI1Type(),
184+
newIndVar, adaptor.getUpperBound());
185+
} else {
186+
cmpOp = spirv::SLessThanOp::create(rewriter, loc, rewriter.getI1Type(),
187+
newIndVar, adaptor.getUpperBound());
188+
}
183189

184190
spirv::BranchConditionalOp::create(rewriter, loc, cmpOp, body,
185191
ArrayRef<Value>(), mergeBlock,

mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -594,7 +594,8 @@ static FailureOr<PackingResult> buildPackingLoopNestImpl(
594594
auto clonedForOp = scf::ForOp::create(
595595
rewriter, loc, bvm.lookupOrDefault(forOp.getLowerBound()),
596596
bvm.lookupOrDefault(forOp.getUpperBound()),
597-
bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor);
597+
bvm.lookupOrDefault(forOp.getStep()), hoistedPackedTensor,
598+
/*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
598599

599600
// Map the induction var, region args and results to the `clonedForOp`.
600601
bvm.map(forOp.getInductionVar(), clonedForOp.getInductionVar());

mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ static scf::ForOp replaceWithDifferentYield(RewriterBase &rewriter,
5555

5656
scf::ForOp newLoop = scf::ForOp::create(
5757
rewriter, loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(),
58-
loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {});
58+
loop.getStep(), inits, [](OpBuilder &, Location, Value, ValueRange) {},
59+
loop.getUnsignedCmp());
5960

6061
// Generate the new yield with the replaced operand.
6162
auto yieldOp = cast<scf::YieldOp>(loop.getBody()->getTerminator());

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -318,9 +318,12 @@ void ConditionOp::getSuccessorRegions(
318318

319319
void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
320320
Value ub, Value step, ValueRange initArgs,
321-
BodyBuilderFn bodyBuilder) {
321+
BodyBuilderFn bodyBuilder, bool unsignedCmp) {
322322
OpBuilder::InsertionGuard guard(builder);
323323

324+
if (unsignedCmp)
325+
result.addAttribute(getUnsignedCmpAttrName(result.name),
326+
builder.getUnitAttr());
324327
result.addOperands({lb, ub, step});
325328
result.addOperands(initArgs);
326329
for (Value v : initArgs)
@@ -450,6 +453,9 @@ static void printInitializationList(OpAsmPrinter &p,
450453
}
451454

452455
void ForOp::print(OpAsmPrinter &p) {
456+
if (getUnsignedCmp())
457+
p << " unsigned";
458+
453459
p << " " << getInductionVar() << " = " << getLowerBound() << " to "
454460
<< getUpperBound() << " step " << getStep();
455461

@@ -462,7 +468,8 @@ void ForOp::print(OpAsmPrinter &p) {
462468
p.printRegion(getRegion(),
463469
/*printEntryBlockArgs=*/false,
464470
/*printBlockTerminators=*/!getInitArgs().empty());
465-
p.printOptionalAttrDict((*this)->getAttrs());
471+
p.printOptionalAttrDict((*this)->getAttrs(),
472+
/*elidedAttrs=*/getUnsignedCmpAttrName().strref());
466473
}
467474

468475
ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -472,6 +479,10 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
472479
OpAsmParser::Argument inductionVariable;
473480
OpAsmParser::UnresolvedOperand lb, ub, step;
474481

482+
if (succeeded(parser.parseOptionalKeyword("unsigned")))
483+
result.addAttribute(getUnsignedCmpAttrName(result.name),
484+
builder.getUnitAttr());
485+
475486
// Parse the induction variable followed by '='.
476487
if (parser.parseOperand(inductionVariable.ssaName) || parser.parseEqual() ||
477488
// Parse loop bounds.
@@ -562,7 +573,7 @@ ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
562573
inits.append(newInitOperands.begin(), newInitOperands.end());
563574
scf::ForOp newLoop = scf::ForOp::create(
564575
rewriter, getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
565-
[](OpBuilder &, Location, Value, ValueRange) {});
576+
[](OpBuilder &, Location, Value, ValueRange) {}, getUnsignedCmp());
566577
newLoop->setAttrs(getPrunedAttributeList(getOperation(), {}));
567578

568579
// Generate the new yield values and append them to the scf.yield operation.
@@ -806,7 +817,8 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
806817
// 2. Create the new forOp shell.
807818
scf::ForOp newForOp = scf::ForOp::create(
808819
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
809-
forOp.getStep(), newIterOperands);
820+
forOp.getStep(), newIterOperands, /*bodyBuilder=*/nullptr,
821+
forOp.getUnsignedCmp());
810822
newForOp->setAttrs(forOp->getAttrs());
811823
Block &newBlock = newForOp.getRegion().front();
812824
SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
@@ -931,7 +943,8 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
931943

932944
scf::ForOp newForOp =
933945
scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(),
934-
forOp.getUpperBound(), forOp.getStep(), newIterArgs);
946+
forOp.getUpperBound(), forOp.getStep(), newIterArgs,
947+
/*bodyBuilder=*/nullptr, forOp.getUnsignedCmp());
935948
newForOp->setAttrs(forOp->getAttrs());
936949
Block &newBlock = newForOp.getRegion().front();
937950

@@ -989,12 +1002,12 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
9891002
/// Util function that tries to compute a constant diff between u and l.
9901003
/// Returns std::nullopt when the difference between two AffineValueMap is
9911004
/// dynamic.
992-
static std::optional<int64_t> computeConstDiff(Value l, Value u) {
1005+
static std::optional<APInt> computeConstDiff(Value l, Value u) {
9931006
IntegerAttr clb, cub;
9941007
if (matchPattern(l, m_Constant(&clb)) && matchPattern(u, m_Constant(&cub))) {
9951008
llvm::APInt lbValue = clb.getValue();
9961009
llvm::APInt ubValue = cub.getValue();
997-
return (ubValue - lbValue).getSExtValue();
1010+
return ubValue - lbValue;
9981011
}
9991012

10001013
// Else a simple pattern match for x + c or c + x
@@ -1003,7 +1016,7 @@ static std::optional<int64_t> computeConstDiff(Value l, Value u) {
10031016
u, m_Op<arith::AddIOp>(matchers::m_Val(l), m_ConstantInt(&diff))) ||
10041017
matchPattern(
10051018
u, m_Op<arith::AddIOp>(m_ConstantInt(&diff), matchers::m_Val(l))))
1006-
return diff.getSExtValue();
1019+
return diff;
10071020
return std::nullopt;
10081021
}
10091022

@@ -1022,13 +1035,15 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
10221035
return success();
10231036
}
10241037

1025-
std::optional<int64_t> diff =
1038+
std::optional<APInt> diff =
10261039
computeConstDiff(op.getLowerBound(), op.getUpperBound());
10271040
if (!diff)
10281041
return failure();
10291042

10301043
// If the loop is known to have 0 iterations, remove it.
1031-
if (*diff <= 0) {
1044+
bool zeroOrLessIterations =
1045+
diff->isZero() || (!op.getUnsignedCmp() && diff->isNegative());
1046+
if (zeroOrLessIterations) {
10321047
rewriter.replaceOp(op, op.getInitArgs());
10331048
return success();
10341049
}
@@ -3384,9 +3399,8 @@ ParseResult scf::WhileOp::parse(OpAsmParser &parser, OperationState &result) {
33843399

33853400
if (functionType.getNumInputs() != operands.size()) {
33863401
return parser.emitError(typeLoc)
3387-
<< "expected as many input types as operands "
3388-
<< "(expected " << operands.size() << " got "
3389-
<< functionType.getNumInputs() << ")";
3402+
<< "expected as many input types as operands " << "(expected "
3403+
<< operands.size() << " got " << functionType.getNumInputs() << ")";
33903404
}
33913405

33923406
// Resolve input operands.

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,8 @@ struct ForOpInterface
769769
// Construct a new scf.for op with memref instead of tensor values.
770770
auto newForOp = scf::ForOp::create(
771771
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
772-
forOp.getStep(), castedInitArgs);
772+
forOp.getStep(), castedInitArgs, /*bodyBuilder=*/nullptr,
773+
forOp.getUnsignedCmp());
773774
newForOp->setAttrs(forOp->getAttrs());
774775
Block *loopBody = newForOp.getBody();
775776

mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,12 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
5858
auto *beforeBlock = rewriter.createBlock(
5959
&whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
6060
rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
61-
auto cmpOp = arith::CmpIOp::create(
62-
rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt,
63-
beforeBlock->getArgument(0), forOp.getUpperBound());
61+
arith::CmpIPredicate predicate = forOp.getUnsignedCmp()
62+
? arith::CmpIPredicate::ult
63+
: arith::CmpIPredicate::slt;
64+
auto cmpOp = arith::CmpIOp::create(rewriter, whileOp.getLoc(), predicate,
65+
beforeBlock->getArgument(0),
66+
forOp.getUpperBound());
6467
scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(),
6568
beforeBlock->getArguments());
6669

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,11 @@ FailureOr<ForOp> mlir::scf::pipelineForLoop(RewriterBase &rewriter, ForOp forOp,
791791
bool *modifiedIR) {
792792
if (modifiedIR)
793793
*modifiedIR = false;
794+
795+
// TODO: Add support for unsigned loops.
796+
if (forOp.getUnsignedCmp())
797+
return failure();
798+
794799
LoopPipelinerInternal pipeliner;
795800
if (!pipeliner.initializeLoopInfo(forOp, options))
796801
return failure();

0 commit comments

Comments
 (0)