@@ -1382,7 +1382,10 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
13821382 return success ();
13831383 }
13841384
1385- ArrayRef<int64_t > scale = op.getScale ();
1385+ SmallVector<int64_t > scale;
1386+ if (!tosa::getConstShapeValue (op.getScale ().getDefiningOp (), scale)) {
1387+ return failure ();
1388+ }
13861389
13871390 // Collapse the unit width and height away.
13881391 SmallVector<ReassociationExprs, 4 > reassociationMap (2 );
@@ -1444,105 +1447,6 @@ class ResizeUnaryConverter : public OpRewritePattern<tosa::ResizeOp> {
14441447 }
14451448};
14461449
1447- // TOSA resize with width or height of 1 may be broadcasted to a wider
1448- // dimension. This is done by materializing a new tosa.resize without
1449- // the broadcasting behavior, and an explicit broadcast afterwards.
1450- class MaterializeResizeBroadcast : public OpRewritePattern <tosa::ResizeOp> {
1451- public:
1452- using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
1453-
1454- LogicalResult matchAndRewrite (tosa::ResizeOp op,
1455- PatternRewriter &rewriter) const final {
1456- Location loc = op.getLoc ();
1457- ImplicitLocOpBuilder builder (loc, rewriter);
1458- auto input = op.getInput ();
1459- auto inputTy = dyn_cast<RankedTensorType>(input.getType ());
1460- auto resultTy = dyn_cast<RankedTensorType>(op.getType ());
1461-
1462- if (!inputTy || !resultTy)
1463- return rewriter.notifyMatchFailure (op,
1464- " requires ranked input/output types" );
1465-
1466- auto batch = inputTy.getDimSize (0 );
1467- auto channels = inputTy.getDimSize (3 );
1468- auto inputH = inputTy.getDimSize (1 );
1469- auto inputW = inputTy.getDimSize (2 );
1470- auto outputH = resultTy.getDimSize (1 );
1471- auto outputW = resultTy.getDimSize (2 );
1472-
1473- if ((inputH != 1 || outputH == 1 ) && (inputW != 1 || outputW == 1 ))
1474- return rewriter.notifyMatchFailure (
1475- op, " tosa.resize has no broadcasting behavior" );
1476-
1477- // For any dimension that is broadcastable we generate a width of 1
1478- // on the output.
1479- llvm::SmallVector<int64_t > resizeShape;
1480- resizeShape.push_back (batch);
1481- resizeShape.push_back (inputH == 1 ? 1 : outputH);
1482- resizeShape.push_back (inputW == 1 ? 1 : outputW);
1483- resizeShape.push_back (channels);
1484-
1485- auto resizeTy = resultTy.clone (resizeShape);
1486- auto resize =
1487- builder.create <tosa::ResizeOp>(resizeTy, input, op->getAttrs ());
1488-
1489- // Collapse an unit result dims.
1490- SmallVector<ReassociationExprs, 4 > reassociationMap (2 );
1491- reassociationMap[0 ].push_back (builder.getAffineDimExpr (0 ));
1492- reassociationMap.back ().push_back (builder.getAffineDimExpr (1 ));
1493- if (inputH != 1 )
1494- reassociationMap.push_back ({});
1495- reassociationMap.back ().push_back (builder.getAffineDimExpr (2 ));
1496- if (inputW != 1 )
1497- reassociationMap.push_back ({});
1498- reassociationMap.back ().push_back (builder.getAffineDimExpr (3 ));
1499-
1500- llvm::SmallVector<int64_t > collapseShape = {batch};
1501- if (inputH != 1 )
1502- collapseShape.push_back (outputH);
1503- if (inputW != 1 )
1504- collapseShape.push_back (outputW);
1505- collapseShape.push_back (channels);
1506-
1507- auto collapseTy = resultTy.clone (collapseShape);
1508- Value collapse = builder.create <tensor::CollapseShapeOp>(collapseTy, resize,
1509- reassociationMap);
1510-
1511- // Broadcast the collapsed shape to the output result.
1512- llvm::SmallVector<Value> outputDynSize;
1513- if (inputTy.isDynamicDim (0 ))
1514- outputDynSize.push_back (builder.create <tensor::DimOp>(input, 0 ));
1515- if (inputTy.isDynamicDim (3 ))
1516- outputDynSize.push_back (builder.create <tensor::DimOp>(input, 3 ));
1517-
1518- SmallVector<utils::IteratorType> iterators (resultTy.getRank (),
1519- utils::IteratorType::parallel);
1520- Value empty = builder.create <tensor::EmptyOp>(
1521- resultTy.getShape (), resultTy.getElementType (), outputDynSize);
1522-
1523- SmallVector<AffineExpr, 4 > inputExprs{rewriter.getAffineDimExpr (0 )};
1524- if (inputH != 1 )
1525- inputExprs.push_back (rewriter.getAffineDimExpr (1 ));
1526- if (inputW != 1 )
1527- inputExprs.push_back (rewriter.getAffineDimExpr (2 ));
1528- inputExprs.push_back (rewriter.getAffineDimExpr (3 ));
1529-
1530- auto inputMap = AffineMap::get (resultTy.getRank (), /* symbolCount=*/ 0 ,
1531- inputExprs, rewriter.getContext ());
1532-
1533- auto outputMap = rewriter.getMultiDimIdentityMap (resultTy.getRank ());
1534- rewriter.replaceOpWithNewOp <linalg::GenericOp>(
1535- op, resultTy, ValueRange{collapse}, ValueRange{empty},
1536- ArrayRef<AffineMap>{inputMap, outputMap}, iterators,
1537- [=](OpBuilder &b, Location loc, ValueRange args) {
1538- Value value = args[0 ];
1539- b.create <linalg::YieldOp>(loc, value);
1540- });
1541-
1542- return success ();
1543- }
1544- };
1545-
15461450class GenericResizeConverter : public OpRewritePattern <tosa::ResizeOp> {
15471451public:
15481452 using OpRewritePattern<tosa::ResizeOp>::OpRewritePattern;
@@ -1599,9 +1503,14 @@ class GenericResizeConverter : public OpRewritePattern<tosa::ResizeOp> {
15991503 Value inY = b.create <arith::IndexCastOp>(b.getI32Type (), y);
16001504 Value inX = b.create <arith::IndexCastOp>(b.getI32Type (), x);
16011505
1602- ArrayRef<int64_t > offset = op.getOffset ();
1603- ArrayRef<int64_t > border = op.getBorder ();
1604- ArrayRef<int64_t > scale = op.getScale ();
1506+ SmallVector<int64_t > scale, offset, border;
1507+ if (!tosa::getConstShapeValue (op.getScale ().getDefiningOp (), scale) ||
1508+ !tosa::getConstShapeValue (op.getOffset ().getDefiningOp (), offset) ||
1509+ !tosa::getConstShapeValue (op.getBorder ().getDefiningOp (), border)) {
1510+ return rewriter.notifyMatchFailure (
1511+ op, " tosa.resize scale/offset/border should have compile time "
1512+ " constant values." );
1513+ }
16051514
16061515 Value yScaleN, yScaleD, xScaleN, xScaleD;
16071516 yScaleN = b.create <arith::ConstantOp>(b.getI32IntegerAttr (scale[0 ]));
@@ -2612,8 +2521,6 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
26122521 /* benefit=*/ 100 );
26132522 patterns->add <ResizeUnaryConverter>(patterns->getContext (),
26142523 /* benefit=*/ 200 );
2615- patterns->add <MaterializeResizeBroadcast>(patterns->getContext (),
2616- /* benefit=*/ 300 );
26172524
26182525 patterns->add <
26192526 // clang-format off
0 commit comments