@@ -1378,7 +1378,10 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
13781378 return success ();
13791379 }
13801380
1381- ArrayRef<int64_t > scale = op.getScale ();
1381+ SmallVector<int64_t > scale;
1382+ if (!tosa::getConstShapeValue (op.getScale ().getDefiningOp (), scale)) {
1383+ return failure ();
1384+ }
13821385
13831386 // Collapse the unit width and height away.
13841387 SmallVector<ReassociationExprs, 4 > reassociationMap (2 );
@@ -1440,105 +1443,6 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
14401443 }
14411444};
14421445
1443- // TOSA resize with width or height of 1 may be broadcasted to a wider
1444- // dimension. This is done by materializing a new tosa.resize without
1445- // the broadcasting behavior, and an explicit broadcast afterwards.
1446- class MaterializeResizeBroadcast : public OpRewritePattern <tosa::ResizeOp> {
1447- public:
1448- using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1449-
1450- LogicalResult matchAndRewrite (tosa::ResizeOp op,
1451- PatternRewriter &rewriter) const final {
1452- Location loc = op.getLoc ();
1453- ImplicitLocOpBuilder builder (loc, rewriter);
1454- auto input = op.getInput ();
1455- auto inputTy = dyn_cast<RankedTensorType>(input.getType ());
1456- auto resultTy = dyn_cast<RankedTensorType>(op.getType ());
1457-
1458- if (!inputTy || !resultTy)
1459- return rewriter.notifyMatchFailure (op,
1460- " requires ranked input/output types" );
1461-
1462- auto batch = inputTy.getDimSize (0 );
1463- auto channels = inputTy.getDimSize (3 );
1464- auto inputH = inputTy.getDimSize (1 );
1465- auto inputW = inputTy.getDimSize (2 );
1466- auto outputH = resultTy.getDimSize (1 );
1467- auto outputW = resultTy.getDimSize (2 );
1468-
1469- if ((inputH != 1 || outputH == 1 ) && (inputW != 1 || outputW == 1 ))
1470- return rewriter.notifyMatchFailure (
1471- op, " tosa.resize has no broadcasting behavior" );
1472-
1473- // For any dimension that is broadcastable we generate a width of 1
1474- // on the output.
1475- llvm::SmallVector<int64_t > resizeShape;
1476- resizeShape.push_back (batch);
1477- resizeShape.push_back (inputH == 1 ? 1 : outputH);
1478- resizeShape.push_back (inputW == 1 ? 1 : outputW);
1479- resizeShape.push_back (channels);
1480-
1481- auto resizeTy = resultTy.clone (resizeShape);
1482- auto resize =
1483- builder.create <tosa::ResizeOp>(resizeTy, input, op->getAttrs ());
1484-
1485- // Collapse an unit result dims.
1486- SmallVector<ReassociationExprs, 4 > reassociationMap (2 );
1487- reassociationMap[0 ].push_back (builder.getAffineDimExpr (0 ));
1488- reassociationMap.back ().push_back (builder.getAffineDimExpr (1 ));
1489- if (inputH != 1 )
1490- reassociationMap.push_back ({});
1491- reassociationMap.back ().push_back (builder.getAffineDimExpr (2 ));
1492- if (inputW != 1 )
1493- reassociationMap.push_back ({});
1494- reassociationMap.back ().push_back (builder.getAffineDimExpr (3 ));
1495-
1496- llvm::SmallVector<int64_t > collapseShape = {batch};
1497- if (inputH != 1 )
1498- collapseShape.push_back (outputH);
1499- if (inputW != 1 )
1500- collapseShape.push_back (outputW);
1501- collapseShape.push_back (channels);
1502-
1503- auto collapseTy = resultTy.clone (collapseShape);
1504- Value collapse = builder.create <tensor::CollapseShapeOp>(collapseTy, resize,
1505- reassociationMap);
1506-
1507- // Broadcast the collapsed shape to the output result.
1508- llvm::SmallVector<Value> outputDynSize;
1509- if (inputTy.isDynamicDim (0 ))
1510- outputDynSize.push_back (builder.create <tensor::DimOp>(input, 0 ));
1511- if (inputTy.isDynamicDim (3 ))
1512- outputDynSize.push_back (builder.create <tensor::DimOp>(input, 3 ));
1513-
1514- SmallVector<utils::IteratorType> iterators (resultTy.getRank (),
1515- utils::IteratorType::parallel);
1516- Value empty = builder.create <tensor::EmptyOp>(
1517- resultTy.getShape (), resultTy.getElementType (), outputDynSize);
1518-
1519- SmallVector<AffineExpr, 4 > inputExprs{rewriter.getAffineDimExpr (0 )};
1520- if (inputH != 1 )
1521- inputExprs.push_back (rewriter.getAffineDimExpr (1 ));
1522- if (inputW != 1 )
1523- inputExprs.push_back (rewriter.getAffineDimExpr (2 ));
1524- inputExprs.push_back (rewriter.getAffineDimExpr (3 ));
1525-
1526- auto inputMap = AffineMap::get (resultTy.getRank (), /* symbolCount=*/ 0 ,
1527- inputExprs, rewriter.getContext ());
1528-
1529- auto outputMap = rewriter.getMultiDimIdentityMap (resultTy.getRank ());
1530- rewriter.replaceOpWithNewOp <linalg::GenericOp>(
1531- op, resultTy, ValueRange{collapse}, ValueRange{empty},
1532- ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1533- [=](OpBuilder &b, Location loc, ValueRange args) {
1534- Value value = args[0 ];
1535- b.create <linalg::YieldOp>(loc, value);
1536- });
1537-
1538- return success ();
1539- }
1540- };
1541-
15421446class GenericResizeConverter : public OpRewritePattern <tosa::ResizeOp> {
15431447public:
15441448 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
@@ -1595,9 +1499,14 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
15951499 Value inY = b.create <arith::IndexCastOp>(b.getI32Type (), y);
15961500 Value inX = b.create <arith::IndexCastOp>(b.getI32Type (), x);
15971501
1598- ArrayRef<int64_t > offset = op.getOffset ();
1599- ArrayRef<int64_t > border = op.getBorder ();
1600- ArrayRef<int64_t > scale = op.getScale ();
1502+ SmallVector<int64_t > scale, offset, border;
1503+ if (!tosa::getConstShapeValue (op.getScale ().getDefiningOp (), scale) ||
1504+ !tosa::getConstShapeValue (op.getOffset ().getDefiningOp (), offset) ||
1505+ !tosa::getConstShapeValue (op.getBorder ().getDefiningOp (), border)) {
1506+ return rewriter.notifyMatchFailure (
1507+ op, " tosa.resize scale/offset/border should have compile time "
1508+ " constant values." );
1509+ }
16011510
16021511 Value yScaleN, yScaleD, xScaleN, xScaleD;
16031512 yScaleN = b.create <arith::ConstantOp>(b.getI32IntegerAttr (scale[0 ]));
@@ -2607,8 +2516,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
26072516 /* benefit=*/ 100 );
26082517 patterns->add <ResizeUnaryConverter>(patterns->getContext (),
26092518 /* benefit=*/ 200 );
2610- patterns->add <MaterializeResizeBroadcast>(patterns->getContext (),
2611- /* benefit=*/ 300 );
26122519
26132520 patterns->add <
26142521 // clang-format off
0 commit comments