Skip to content

Commit 327c8fc

Browse files
committed
[AutoBump] Merge with fixes of 3430bc3 (Feb 18)
2 parents 7ed352e + 3430bc3 commit 327c8fc

File tree

13 files changed

+558
-72
lines changed

13 files changed

+558
-72
lines changed

mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1821,9 +1821,9 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
18211821

18221822
let arguments = (ins
18231823
Tosa_Tensor4D:$input,
1824-
Tosa_IntArrayAttr4:$scale,
1825-
Tosa_IntArrayAttr2:$offset,
1826-
Tosa_IntArrayAttr2:$border,
1824+
Rank4TosaShape:$scale,
1825+
Rank2TosaShape:$offset,
1826+
Rank2TosaShape:$border,
18271827
Tosa_ResizeTypeAttr:$mode
18281828
);
18291829

@@ -1832,6 +1832,7 @@ def Tosa_ResizeOp : Tosa_InferShapedTypeOp<"resize"> {
18321832
);
18331833

18341834
let hasFolder = 1;
1835+
let hasVerifier = 1;
18351836
}
18361837

18371838
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,9 @@ SmallVector<int64_t> convertFromMlirShape(ArrayRef<int64_t> shape);
256256
bool getConstShapeValue(Operation *op,
257257
llvm::SmallVector<int64_t> &result_shape);
258258

259+
// returns a small vector of int64_t values that attr contains
260+
SmallVector<int64_t> convertFromIntAttr(const DenseElementsAttr &attr,
261+
const int rank);
259262
} // namespace tosa
260263
} // namespace mlir
261264

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1505,7 +1505,10 @@ class ResizeUnaryConverter : public OpConversionPattern<tosa::ResizeOp> {
15051505
return success();
15061506
}
15071507

1508-
ArrayRef<int64_t> scale = operands.getScale();
1508+
SmallVector<int64_t> scale;
1509+
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale)) {
1510+
return failure();
1511+
}
15091512

15101513
// Collapse the unit width and height away.
15111514
SmallVector<ReassociationExprs, 4> reassociationMap(2);
@@ -1611,8 +1614,9 @@ class MaterializeResizeBroadcast : public OpConversionPattern<tosa::ResizeOp> {
16111614
resizeShape.push_back(channels);
16121615

16131616
auto resizeTy = resultTy.clone(resizeShape);
1614-
auto resize =
1615-
builder.create<tosa::ResizeOp>(resizeTy, input, op->getAttrs());
1617+
auto resize = builder.create<tosa::ResizeOp>(resizeTy, input, op.getScale(),
1618+
op.getOffset(), op.getBorder(),
1619+
op.getMode());
16161620

16171621
// Collapse an unit result dims.
16181622
SmallVector<ReassociationExprs, 4> reassociationMap(2);
@@ -1733,9 +1737,14 @@ class GenericResizeConverter : public OpConversionPattern<tosa::ResizeOp> {
17331737
Value inY = b.create<arith::IndexCastOp>(b.getI32Type(), y);
17341738
Value inX = b.create<arith::IndexCastOp>(b.getI32Type(), x);
17351739

1736-
ArrayRef<int64_t> offset = operands.getOffset();
1737-
ArrayRef<int64_t> border = operands.getBorder();
1738-
ArrayRef<int64_t> scale = operands.getScale();
1740+
SmallVector<int64_t> scale, offset, border;
1741+
if (!tosa::getConstShapeValue(op.getScale().getDefiningOp(), scale) ||
1742+
!tosa::getConstShapeValue(op.getOffset().getDefiningOp(), offset) ||
1743+
!tosa::getConstShapeValue(op.getBorder().getDefiningOp(), border)) {
1744+
return rewriter.notifyMatchFailure(
1745+
op, "tosa.resize scale/offset/border should have compile time "
1746+
"constant values.");
1747+
}
17391748

17401749
Value yScaleN, yScaleD, xScaleN, xScaleD;
17411750
yScaleN = b.create<arith::ConstantOp>(b.getI32IntegerAttr(scale[0]));

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1625,9 +1625,22 @@ OpFoldResult PadOp::fold(FoldAdaptor adaptor) {
16251625
// Fold away cases where a tosa.resize operation returns a copy
16261626
// of the input image.
16271627
OpFoldResult ResizeOp::fold(FoldAdaptor adaptor) {
1628-
ArrayRef<int64_t> offset = getOffset();
1629-
ArrayRef<int64_t> border = getBorder();
1630-
ArrayRef<int64_t> scale = getScale();
1628+
auto scaleAttr =
1629+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getScale());
1630+
auto offsetAttr =
1631+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getOffset());
1632+
auto borderAttr =
1633+
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getBorder());
1634+
if (!scaleAttr || !offsetAttr || !borderAttr) {
1635+
return {};
1636+
}
1637+
1638+
auto scale = tosa::convertFromIntAttr(scaleAttr, /* rank = */ 4);
1639+
auto offset = tosa::convertFromIntAttr(offsetAttr, /* rank = */ 2);
1640+
auto border = tosa::convertFromIntAttr(borderAttr, /* rank = */ 2);
1641+
if (scale.size() != 4 || offset.size() != 2 || border.size() != 2) {
1642+
return {};
1643+
}
16311644

16321645
// Check unit scaling.
16331646
if (scale[0] != scale[1] || scale[2] != scale[3]) {

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 100 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1689,9 +1689,14 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
16891689
(inputWidth == ShapedType::kDynamic))
16901690
return failure();
16911691

1692-
llvm::ArrayRef<int64_t> scaleInt = adaptor.getScale();
1693-
llvm::ArrayRef<int64_t> offsetInt = adaptor.getOffset();
1694-
llvm::ArrayRef<int64_t> borderInt = adaptor.getBorder();
1692+
SmallVector<int64_t> scaleInt, offsetInt, borderInt;
1693+
if (!tosa::getConstShapeValue(adaptor.getScale().getDefiningOp(), scaleInt) ||
1694+
!tosa::getConstShapeValue(adaptor.getOffset().getDefiningOp(),
1695+
offsetInt) ||
1696+
!tosa::getConstShapeValue(adaptor.getBorder().getDefiningOp(),
1697+
borderInt)) {
1698+
return failure();
1699+
}
16951700

16961701
// Compute the output shape based on attributes: scale, offset, and border.
16971702
outputShape[1] =
@@ -1708,6 +1713,98 @@ LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
17081713
return success();
17091714
}
17101715

1716+
LogicalResult tosa::ResizeOp::verify() {
1717+
const Value input = getInput();
1718+
const Value output = getOutput();
1719+
const RankedTensorType inputType =
1720+
llvm::dyn_cast<RankedTensorType>(input.getType());
1721+
const RankedTensorType outputType =
1722+
llvm::dyn_cast<RankedTensorType>(output.getType());
1723+
1724+
if (!inputType)
1725+
return emitOpError("expect a ranked input tensor");
1726+
if (!outputType)
1727+
return emitOpError("expect a ranked output tensor");
1728+
1729+
const int64_t oh = outputType.getDimSize(1);
1730+
const int64_t ow = outputType.getDimSize(2);
1731+
const int64_t ih = inputType.getDimSize(1);
1732+
const int64_t iw = inputType.getDimSize(2);
1733+
1734+
SmallVector<int64_t> scaleValues;
1735+
SmallVector<int64_t> offsetValues;
1736+
SmallVector<int64_t> borderValues;
1737+
if (!tosa::getConstShapeValue(getScale().getDefiningOp(), scaleValues) ||
1738+
!tosa::getConstShapeValue(getOffset().getDefiningOp(), offsetValues) ||
1739+
!tosa::getConstShapeValue(getBorder().getDefiningOp(), borderValues)) {
1740+
// Skip following checks if shape is not constant
1741+
return success();
1742+
}
1743+
1744+
if (llvm::any_of(scaleValues, [](int64_t s) { return s <= 0; }))
1745+
return emitOpError("expect all scale values to be > 0, got ")
1746+
<< scaleValues;
1747+
1748+
const int64_t scaleYN = scaleValues[0];
1749+
const int64_t scaleYD = scaleValues[1];
1750+
const int64_t scaleXN = scaleValues[2];
1751+
const int64_t scaleXD = scaleValues[3];
1752+
1753+
const int64_t offsetY = offsetValues[0];
1754+
const int64_t offsetX = offsetValues[1];
1755+
1756+
const int64_t borderY = borderValues[0];
1757+
const int64_t borderX = borderValues[1];
1758+
1759+
auto idivCheck = [](const int64_t lhs,
1760+
const int64_t rhs) -> std::optional<int64_t> {
1761+
if (lhs % rhs != 0)
1762+
return std::nullopt;
1763+
return lhs / rhs;
1764+
};
1765+
1766+
// Don't check with input height that could be broadcast (ih != 1)
1767+
// since Linalg, a consumer of TOSA, expects broadcasting support
1768+
// in resize to be available. Taking the cautious approach for now,
1769+
// we can consider removing support for broadcasting later.
1770+
if (ih != ShapedType::kDynamic && ih != 1) {
1771+
const std::optional<int64_t> calculatedOutHeightMinusOne =
1772+
idivCheck((ih - 1) * scaleYN - offsetY + borderY, scaleYD);
1773+
if (!calculatedOutHeightMinusOne.has_value())
1774+
return emitOpError("expected (input_height - 1) * scale_y_n - offset_y + "
1775+
"border_y ")
1776+
<< "to be wholly divisible by scale_y_d, got ((" << ih
1777+
<< " - 1) * " << scaleYN << " - " << offsetY << " + " << borderY
1778+
<< ") / " << scaleYD;
1779+
const int64_t calculatedOutHeight = calculatedOutHeightMinusOne.value() + 1;
1780+
if (oh != ShapedType::kDynamic && calculatedOutHeight != oh)
1781+
return emitOpError("calculated output height did not match expected: ")
1782+
<< "calculated=" << calculatedOutHeight << ", expected=" << oh;
1783+
}
1784+
1785+
// Don't check with input width that could be broadcast (iw != 1)
1786+
// since Linalg, a consumer of TOSA, expects broadcasting support
1787+
// in resize to be available. Taking the cautious approach for now,
1788+
// we can consider removing support for broadcasting later.
1789+
if (iw != ShapedType::kDynamic && iw != 1) {
1790+
const int64_t scaledInWidth = (iw - 1) * scaleXN - offsetX + borderX;
1791+
const std::optional<int64_t> calculatedOutWidthMinusOne =
1792+
idivCheck(scaledInWidth, scaleXD);
1793+
if (!calculatedOutWidthMinusOne.has_value())
1794+
return emitOpError("expected (input_width - 1) * scale_x_n - offset_x + "
1795+
"border_x ")
1796+
<< "to be wholly divisible by scale_x_d, got ((" << iw
1797+
<< " - 1) * " << scaleXN << " - " << offsetX << " + " << borderX
1798+
<< ") / " << scaleXD;
1799+
const int64_t calculatedOutWidth = calculatedOutWidthMinusOne.value() + 1;
1800+
if (ow != ShapedType::kDynamic && calculatedOutWidth != ow)
1801+
return emitOpError("calculated output width did not match expected: ")
1802+
<< "calculated=" << calculatedOutWidth << ", expected=" << ow;
1803+
}
1804+
1805+
return success();
1806+
}
1807+
17111808
LogicalResult tosa::ScatterOp::inferReturnTypeComponents(
17121809
MLIRContext *context, ::std::optional<Location> location,
17131810
ScatterOp::Adaptor adaptor,

0 commit comments

Comments
 (0)