@@ -308,108 +308,117 @@ static Value doCast(OpBuilder &builder, Location loc, Value src, Type dstType) {
308308
309309struct NarrowElementwise final
310310 : public OpTraitRewritePattern<OpTrait::Elementwise> {
311- NarrowElementwise (MLIRContext *context, DataFlowSolver &s, unsigned target)
311+ NarrowElementwise (MLIRContext *context, DataFlowSolver &s,
312+ ArrayRef<unsigned > target)
312313 : OpTraitRewritePattern<OpTrait::Elementwise>(context), solver(s),
313- targetBitwidth (target) {}
314+ targetBitwidths (target) {}
314315
315316 using OpTraitRewritePattern::OpTraitRewritePattern;
316317 LogicalResult matchAndRewrite (Operation *op,
317318 PatternRewriter &rewriter) const override {
318- Type srcType = checkElementwiseOpType (op, targetBitwidth);
319- if (!srcType)
320- return failure ();
321319
322320 std::optional<ConstantIntRanges> range =
323321 getOperandsRange (solver, op->getResults ());
324322 if (!range)
325323 return failure ();
326324
327- // We are truncating op args to the desired bitwidth before the op and then
328- // extending op results back to the original width after.
329- // extui and exti will produce different results for negative values, so
330- // limit signed range to non-negative values.
331- auto smin = APInt::getZero (targetBitwidth);
332- auto smax = APInt::getSignedMaxValue (targetBitwidth);
333- auto umin = APInt::getMinValue (targetBitwidth);
334- auto umax = APInt::getMaxValue (targetBitwidth);
335- if (!checkRange (*range, smin, smax, umin, umax))
336- return failure ();
325+ for (unsigned targetBitwidth : targetBitwidths) {
326+ Type srcType = checkElementwiseOpType (op, targetBitwidth);
327+ if (!srcType)
328+ continue ;
337329
338- Type targetType = getTargetType (srcType, targetBitwidth);
339- if (targetType == srcType)
340- return failure ();
330+ // We are truncating op args to the desired bitwidth before the op and
331+ // then extending op results back to the original width after. extui and
332+ // exti will produce different results for negative values, so limit
333+ // signed range to non-negative values.
334+ auto smin = APInt::getZero (targetBitwidth);
335+ auto smax = APInt::getSignedMaxValue (targetBitwidth);
336+ auto umin = APInt::getMinValue (targetBitwidth);
337+ auto umax = APInt::getMaxValue (targetBitwidth);
338+ if (!checkRange (*range, smin, smax, umin, umax))
339+ continue ;
341340
342- Location loc = op->getLoc ();
343- IRMapping mapping;
344- for (Value arg : op->getOperands ()) {
345- Value newArg = doCast (rewriter, loc, arg, targetType);
346- mapping.map (arg, newArg);
347- }
341+ Type targetType = getTargetType (srcType, targetBitwidth);
342+ if (targetType == srcType)
343+ continue ;
348344
349- Operation *newOp = rewriter.clone (*op, mapping);
350- rewriter.modifyOpInPlace (newOp, [&]() {
351- for (OpResult res : newOp->getResults ()) {
352- res.setType (targetType);
345+ Location loc = op->getLoc ();
346+ IRMapping mapping;
347+ for (Value arg : op->getOperands ()) {
348+ Value newArg = doCast (rewriter, loc, arg, targetType);
349+ mapping.map (arg, newArg);
353350 }
354- });
355- SmallVector<Value> newResults;
356- for (Value res : newOp->getResults ())
357- newResults.emplace_back (doCast (rewriter, loc, res, srcType));
358351
359- rewriter.replaceOp (op, newResults);
360- return success ();
352+ Operation *newOp = rewriter.clone (*op, mapping);
353+ rewriter.modifyOpInPlace (newOp, [&]() {
354+ for (OpResult res : newOp->getResults ()) {
355+ res.setType (targetType);
356+ }
357+ });
358+ SmallVector<Value> newResults;
359+ for (Value res : newOp->getResults ())
360+ newResults.emplace_back (doCast (rewriter, loc, res, srcType));
361+
362+ rewriter.replaceOp (op, newResults);
363+ return success ();
364+ }
365+ return failure ();
361366 }
362367
363368private:
364369 DataFlowSolver &solver;
365- unsigned targetBitwidth ;
370+ SmallVector< unsigned , 4 > targetBitwidths ;
366371};
367372
368373struct NarrowCmpi final : public OpRewritePattern<arith::CmpIOp> {
369374 NarrowCmpi (MLIRContext *context, PatternBenefit benefit, DataFlowSolver &s,
370- unsigned target)
371- : OpRewritePattern(context, benefit), solver(s), targetBitwidth(target) {}
375+ ArrayRef<unsigned > target)
376+ : OpRewritePattern(context, benefit), solver(s), targetBitwidths(target) {
377+ }
372378
373379 LogicalResult matchAndRewrite (arith::CmpIOp op,
374380 PatternRewriter &rewriter) const override {
375381 Value lhs = op.getLhs ();
376382 Value rhs = op.getRhs ();
377383
378- Type srcType = checkArithType (lhs.getType (), targetBitwidth);
379- if (!srcType)
380- return failure ();
381-
382384 std::optional<ConstantIntRanges> range =
383385 getOperandsRange (solver, {lhs, rhs});
384386 if (!range)
385387 return failure ();
386388
387- auto smin = APInt::getSignedMinValue (targetBitwidth);
388- auto smax = APInt::getSignedMaxValue (targetBitwidth);
389- auto umin = APInt::getMinValue (targetBitwidth);
390- auto umax = APInt::getMaxValue (targetBitwidth);
391- if (!checkRange (*range, smin, smax, umin, umax))
392- return failure ();
389+ for (unsigned targetBitwidth : targetBitwidths) {
390+ Type srcType = checkArithType (lhs.getType (), targetBitwidth);
391+ if (!srcType)
392+ continue ;
393393
394- Type targetType = getTargetType (srcType, targetBitwidth);
395- if (targetType == srcType)
396- return failure ();
394+ auto smin = APInt::getSignedMinValue (targetBitwidth);
395+ auto smax = APInt::getSignedMaxValue (targetBitwidth);
396+ auto umin = APInt::getMinValue (targetBitwidth);
397+ auto umax = APInt::getMaxValue (targetBitwidth);
398+ if (!checkRange (*range, smin, smax, umin, umax))
399+ continue ;
397400
398- Location loc = op->getLoc ();
399- IRMapping mapping;
400- for (Value arg : op->getOperands ()) {
401- Value newArg = doCast (rewriter, loc, arg, targetType);
402- mapping.map (arg, newArg);
403- }
401+ Type targetType = getTargetType (srcType, targetBitwidth);
402+ if (targetType == srcType)
403+ continue ;
404404
405- Operation *newOp = rewriter.clone (*op, mapping);
406- rewriter.replaceOp (op, newOp->getResults ());
407- return success ();
405+ Location loc = op->getLoc ();
406+ IRMapping mapping;
407+ for (Value arg : op->getOperands ()) {
408+ Value newArg = doCast (rewriter, loc, arg, targetType);
409+ mapping.map (arg, newArg);
410+ }
411+
412+ Operation *newOp = rewriter.clone (*op, mapping);
413+ rewriter.replaceOp (op, newOp->getResults ());
414+ return success ();
415+ }
416+ return failure ();
408417 }
409418
410419private:
411420 DataFlowSolver &solver;
412- unsigned targetBitwidth ;
421+ SmallVector< unsigned , 4 > targetBitwidths ;
413422};
414423
415424struct IntRangeOptimizationsPass
@@ -453,7 +462,7 @@ struct IntRangeNarrowingPass
453462 DataFlowListener listener (solver);
454463
455464 RewritePatternSet patterns (ctx);
456- populateIntRangeNarrowingPatterns (patterns, solver, this -> targetBitwidth );
465+ populateIntRangeNarrowingPatterns (patterns, solver, bitwidthsSupported );
457466
458467 GreedyRewriteConfig config;
459468 config.listener = &listener;
@@ -470,16 +479,16 @@ void mlir::arith::populateIntRangeOptimizationsPatterns(
470479 DeleteTrivialRem<RemUIOp>>(patterns.getContext (), solver);
471480}
472481
473- void mlir::arith::populateIntRangeNarrowingPatterns (RewritePatternSet &patterns,
474- DataFlowSolver &solver,
475- unsigned targetBitwidth ) {
482+ void mlir::arith::populateIntRangeNarrowingPatterns (
483+ RewritePatternSet &patterns, DataFlowSolver &solver,
484+ ArrayRef< unsigned > bitwidthsSupported ) {
476485 // Cmpi uses args ranges instead of results, run it with higher benefit,
477486 // as its argumens can be potentially replaced.
478487 patterns.add <NarrowCmpi>(patterns.getContext (), /* benefit*/ 10 , solver,
479- targetBitwidth );
488+ bitwidthsSupported );
480489
481490 patterns.add <NarrowElementwise>(patterns.getContext (), solver,
482- targetBitwidth );
491+ bitwidthsSupported );
483492}
484493
485494std::unique_ptr<Pass> mlir::arith::createIntRangeOptimizationsPass () {
0 commit comments