Skip to content

Commit eb91b30

Browse files
committed
use list of bitwidths instead of 1
1 parent ed4920c commit eb91b30

File tree

4 files changed

+79
-70
lines changed

4 files changed

+79
-70
lines changed

mlir/include/mlir/Dialect/Arith/Transforms/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ std::unique_ptr<Pass> createIntRangeOptimizationsPass();
8080
/// Add patterns for int range based narrowing.
8181
void populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
8282
DataFlowSolver &solver,
83-
unsigned targetBitwidth);
83+
ArrayRef<unsigned> bitwidthsSupported);
8484

8585
// TODO: merge these two narrowing passes.
8686
/// Add patterns for integer bitwidth narrowing.

mlir/include/mlir/Dialect/Arith/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ def ArithIntRangeNarrowing : Pass<"arith-int-range-narrowing"> {
5858
}];
5959

6060
let options = [
61-
Option<"targetBitwidth", "target-bitwidth", "unsigned",
62-
/*default=*/"32", "Target bitwidth this pass will try to narrow to">,
61+
ListOption<"bitwidthsSupported", "int-bitwidths-supported", "unsigned",
62+
"Integer bitwidths supported">,
6363
];
6464

6565
// Explicitly depend on "arith" because this pass could create operations in

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 75 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -308,108 +308,117 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
308308

309309
struct NarrowElementwise final
310310
: public OpTraitRewritePattern<OpTrait::Elementwise> {
311-
NarrowElementwise(MLIRContext *context, DataFlowSolver &s, unsigned target)
311+
NarrowElementwise(MLIRContext *context, DataFlowSolver &s,
312+
ArrayRef<unsigned> target)
312313
: OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
313-
targetBitwidth(target) {}
314+
targetBitwidths(target) {}
314315

315316
using OpTraitRewritePattern::OpTraitRewritePattern;
316317
LogicalResult matchAndRewrite(Operation *op,
317318
PatternRewriter &rewriter) const override {
318-
Type srcType = checkElementwiseOpType(op, targetBitwidth);
319-
if (!srcType)
320-
return failure();
321319

322320
std::optional<ConstantIntRanges> range =
323321
getOperandsRange(solver, op->getResults());
324322
if (!range)
325323
return failure();
326324

327-
// We are truncating op args to the desired bitwidth before the op and then
328-
// extending op results back to the original width after.
329-
// extui and exti will produce different results for negative values, so
330-
// limit signed range to non-negative values.
331-
auto smin = APInt::getZero(targetBitwidth);
332-
auto smax = APInt::getSignedMaxValue(targetBitwidth);
333-
auto umin = APInt::getMinValue(targetBitwidth);
334-
auto umax = APInt::getMaxValue(targetBitwidth);
335-
if (!checkRange(*range, smin, smax, umin, umax))
336-
return failure();
325+
for (unsigned targetBitwidth : targetBitwidths) {
326+
Type srcType = checkElementwiseOpType(op, targetBitwidth);
327+
if (!srcType)
328+
continue;
337329

338-
Type targetType = getTargetType(srcType, targetBitwidth);
339-
if (targetType == srcType)
340-
return failure();
330+
// We are truncating op args to the desired bitwidth before the op and
331+
// then extending op results back to the original width after. extui and
332+
// exti will produce different results for negative values, so limit
333+
// signed range to non-negative values.
334+
auto smin = APInt::getZero(targetBitwidth);
335+
auto smax = APInt::getSignedMaxValue(targetBitwidth);
336+
auto umin = APInt::getMinValue(targetBitwidth);
337+
auto umax = APInt::getMaxValue(targetBitwidth);
338+
if (!checkRange(*range, smin, smax, umin, umax))
339+
continue;
341340

342-
Location loc = op->getLoc();
343-
IRMapping mapping;
344-
for (Value arg : op->getOperands()) {
345-
Value newArg = doCast(rewriter, loc, arg, targetType);
346-
mapping.map(arg, newArg);
347-
}
341+
Type targetType = getTargetType(srcType, targetBitwidth);
342+
if (targetType == srcType)
343+
continue;
348344

349-
Operation *newOp = rewriter.clone(*op, mapping);
350-
rewriter.modifyOpInPlace(newOp, [&]() {
351-
for (OpResult res : newOp->getResults()) {
352-
res.setType(targetType);
345+
Location loc = op->getLoc();
346+
IRMapping mapping;
347+
for (Value arg : op->getOperands()) {
348+
Value newArg = doCast(rewriter, loc, arg, targetType);
349+
mapping.map(arg, newArg);
353350
}
354-
});
355-
SmallVector<Value> newResults;
356-
for (Value res : newOp->getResults())
357-
newResults.emplace_back(doCast(rewriter, loc, res, srcType));
358351

359-
rewriter.replaceOp(op, newResults);
360-
return success();
352+
Operation *newOp = rewriter.clone(*op, mapping);
353+
rewriter.modifyOpInPlace(newOp, [&]() {
354+
for (OpResult res : newOp->getResults()) {
355+
res.setType(targetType);
356+
}
357+
});
358+
SmallVector<Value> newResults;
359+
for (Value res : newOp->getResults())
360+
newResults.emplace_back(doCast(rewriter, loc, res, srcType));
361+
362+
rewriter.replaceOp(op, newResults);
363+
return success();
364+
}
365+
return failure();
361366
}
362367

363368
private:
364369
DataFlowSolver &solver;
365-
unsigned targetBitwidth;
370+
SmallVector<unsigned, 4> targetBitwidths;
366371
};
367372

368373
struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
369374
NarrowCmpi(MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
370-
unsigned target)
371-
: OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {}
375+
ArrayRef<unsigned> target)
376+
: OpRewritePattern(context, benefit), solver(s), targetBitwidths(target) {
377+
}
372378

373379
LogicalResult matchAndRewrite(arith::CmpIOp op,
374380
PatternRewriter &rewriter) const override {
375381
Value lhs = op.getLhs();
376382
Value rhs = op.getRhs();
377383

378-
Type srcType = checkArithType(lhs.getType(), targetBitwidth);
379-
if (!srcType)
380-
return failure();
381-
382384
std::optional<ConstantIntRanges> range =
383385
getOperandsRange(solver, {lhs, rhs});
384386
if (!range)
385387
return failure();
386388

387-
auto smin = APInt::getSignedMinValue(targetBitwidth);
388-
auto smax = APInt::getSignedMaxValue(targetBitwidth);
389-
auto umin = APInt::getMinValue(targetBitwidth);
390-
auto umax = APInt::getMaxValue(targetBitwidth);
391-
if (!checkRange(*range, smin, smax, umin, umax))
392-
return failure();
389+
for (unsigned targetBitwidth : targetBitwidths) {
390+
Type srcType = checkArithType(lhs.getType(), targetBitwidth);
391+
if (!srcType)
392+
continue;
393393

394-
Type targetType = getTargetType(srcType, targetBitwidth);
395-
if (targetType == srcType)
396-
return failure();
394+
auto smin = APInt::getSignedMinValue(targetBitwidth);
395+
auto smax = APInt::getSignedMaxValue(targetBitwidth);
396+
auto umin = APInt::getMinValue(targetBitwidth);
397+
auto umax = APInt::getMaxValue(targetBitwidth);
398+
if (!checkRange(*range, smin, smax, umin, umax))
399+
continue;
397400

398-
Location loc = op->getLoc();
399-
IRMapping mapping;
400-
for (Value arg : op->getOperands()) {
401-
Value newArg = doCast(rewriter, loc, arg, targetType);
402-
mapping.map(arg, newArg);
403-
}
401+
Type targetType = getTargetType(srcType, targetBitwidth);
402+
if (targetType == srcType)
403+
continue;
404404

405-
Operation *newOp = rewriter.clone(*op, mapping);
406-
rewriter.replaceOp(op, newOp->getResults());
407-
return success();
405+
Location loc = op->getLoc();
406+
IRMapping mapping;
407+
for (Value arg : op->getOperands()) {
408+
Value newArg = doCast(rewriter, loc, arg, targetType);
409+
mapping.map(arg, newArg);
410+
}
411+
412+
Operation *newOp = rewriter.clone(*op, mapping);
413+
rewriter.replaceOp(op, newOp->getResults());
414+
return success();
415+
}
416+
return failure();
408417
}
409418

410419
private:
411420
DataFlowSolver &solver;
412-
unsigned targetBitwidth;
421+
SmallVector<unsigned, 4> targetBitwidths;
413422
};
414423

415424
struct IntRangeOptimizationsPass
@@ -453,7 +462,7 @@ struct IntRangeNarrowingPass
453462
DataFlowListener listener(solver);
454463

455464
RewritePatternSet patterns(ctx);
456-
populateIntRangeNarrowingPatterns(patterns, solver, this->targetBitwidth);
465+
populateIntRangeNarrowingPatterns(patterns, solver, bitwidthsSupported);
457466

458467
GreedyRewriteConfig config;
459468
config.listener = &listener;
@@ -470,16 +479,16 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
470479
DeleteTrivialRem<RemUIOp>>(patterns.getContext(), solver);
471480
}
472481

473-
void mlir::arith::populateIntRangeNarrowingPatterns(RewritePatternSet &patterns,
474-
DataFlowSolver &solver,
475-
unsigned targetBitwidth) {
482+
void mlir::arith::populateIntRangeNarrowingPatterns(
483+
RewritePatternSet &patterns, DataFlowSolver &solver,
484+
ArrayRef<unsigned> bitwidthsSupported) {
476485
// Cmpi uses args ranges instead of results, run it with higher benefit,
477486
// as its argumens can be potentially replaced.
478487
patterns.add<NarrowCmpi>(patterns.getContext(), /*benefit*/ 10, solver,
479-
targetBitwidth);
488+
bitwidthsSupported);
480489

481490
patterns.add<NarrowElementwise>(patterns.getContext(), solver,
482-
targetBitwidth);
491+
bitwidthsSupported);
483492
}
484493

485494
std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass() {

mlir/test/Dialect/Arith/int-range-narrowing.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt --arith-int-range-narrowing="target-bitwidth=32" %s | FileCheck %s
1+
// RUN: mlir-opt --arith-int-range-narrowing="int-bitwidths-supported=32" %s | FileCheck %s
22

33
// Do not truncate negative values
44
// CHECK-LABEL: func @test_addi_neg

0 commit comments

Comments
 (0)