Skip to content

Commit f792a33

Browse files
committed
[mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg
The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether dynamic extension is enabled. When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps.
1 parent c69a70b commit f792a33

File tree

2 files changed

+315
-86
lines changed

2 files changed

+315
-86
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 259 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,199 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
13921392
}
13931393
};
13941394

1395+
// Collapse tensor<1xiN> into tensor<iN>
1396+
// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
1397+
static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
1398+
Location loc) {
1399+
SmallVector<ReassociationExprs, 1> reassociation;
1400+
// Create the collapsed type
1401+
auto inputType = cast<RankedTensorType>(input.getType());
1402+
auto elemType = inputType.getElementType();
1403+
auto collapsedType = RankedTensorType::get({}, elemType);
1404+
// Emit the collapse op
1405+
return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
1406+
reassociation);
1407+
}
1408+
1409+
// The multiplier may be either constant or non-constant, depending on
1410+
// whether dynamic extension is enabled.
1411+
// - If the multiplier is non-constant, add it as an input to linalg::GenericOp
1412+
// by:
1413+
// 1. Pushing it into 'genericInputs'.
1414+
// 2. Appending a corresponding affine map to 'indexingMaps'.
1415+
// - If the multiplier is constant, set 'multiplierConstant' instead.
1416+
static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
1417+
PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
1418+
SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
1419+
bool isConstant, tosa::RescaleOp op, Value &multiplierConstant,
1420+
int64_t &multiplierArg) {
1421+
1422+
auto loc = op.getLoc();
1423+
auto inputTy = cast<ShapedType>(op.getInput().getType());
1424+
unsigned rank = inputTy.getRank();
1425+
SmallVector<AffineExpr, 2> multiplierExprs{
1426+
rewriter.getAffineDimExpr(rank - 1)};
1427+
1428+
if (isConstant) {
1429+
// If we are rescaling per-channel then we need to store the multiplier
1430+
// values in a buffer.
1431+
if (multiplierValues.size() == 1) {
1432+
multiplierConstant = rewriter.create<arith::ConstantOp>(
1433+
loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1434+
} else {
1435+
auto multiplierType =
1436+
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1437+
rewriter.getI32Type());
1438+
genericInputs.push_back(arith::ConstantOp::create(
1439+
rewriter, loc,
1440+
DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1441+
1442+
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1443+
/*symbolCount=*/0, multiplierExprs,
1444+
rewriter.getContext()));
1445+
}
1446+
} else {
1447+
// If we are not rescaling per-channel then we need to collapse 1xN to N
1448+
// and push broadcastMap.
1449+
auto tensorType = dyn_cast<RankedTensorType>(op.getMultiplier().getType());
1450+
if (tensorType && tensorType.hasStaticShape() &&
1451+
tensorType.getShape()[0] == 1) {
1452+
// broadcastMap = affine_map<(d0, d1) -> ()>
1453+
// It would affect as broadcast for scalar values in linalg::GenericOp.
1454+
AffineMap broadcastMap =
1455+
AffineMap::get(rank, 0, {}, rewriter.getContext());
1456+
genericInputs.push_back(
1457+
collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
1458+
indexingMaps.push_back(broadcastMap);
1459+
} else {
1460+
genericInputs.push_back(op.getMultiplier());
1461+
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1462+
/*symbolCount=*/0, multiplierExprs,
1463+
rewriter.getContext()));
1464+
}
1465+
}
1466+
multiplierArg = indexingMaps.size() - 1;
1467+
}
1468+
1469+
// The shift may be either constant or non-constant, depending on
1470+
// whether dynamic extension is enabled.
1471+
// - If the shift is non-constant, add it as an input to linalg::GenericOp by:
1472+
// 1. Pushing it into 'genericInputs'.
1473+
// 2. Appending a corresponding affine map to 'indexingMaps'.
1474+
// - If the shift is constant, set 'shiftConstant' instead.
1475+
static void setupLinalgGenericOpInputAndIndexingMapForShift(
1476+
PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
1477+
SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
1478+
bool isConstant, tosa::RescaleOp op, Value &shiftConstant,
1479+
int64_t &shiftArg) {
1480+
1481+
auto loc = op.getLoc();
1482+
auto inputTy = cast<ShapedType>(op.getInput().getType());
1483+
unsigned rank = inputTy.getRank();
1484+
SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};
1485+
1486+
if (isConstant) {
1487+
// If we are rescaling per-channel then we need to store the shift
1488+
// values in a buffer.
1489+
if (shiftValues.size() == 1) {
1490+
shiftConstant = rewriter.create<arith::ConstantOp>(
1491+
loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1492+
} else {
1493+
auto shiftType =
1494+
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1495+
rewriter.getIntegerType(8));
1496+
genericInputs.push_back(arith::ConstantOp::create(
1497+
rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1498+
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1499+
/*symbolCount=*/0, shiftExprs,
1500+
rewriter.getContext()));
1501+
}
1502+
} else {
1503+
// If we are not rescaling per-channel then we need to collapse 1xN to N
1504+
// and push broadcastMap.
1505+
auto tensorType = dyn_cast<RankedTensorType>(op.getShift().getType());
1506+
if (tensorType && tensorType.hasStaticShape() &&
1507+
tensorType.getShape()[0] == 1) {
1508+
// broadcastMap = affine_map<(d0, d1) -> ()>
1509+
// It would affect as broadcast for scalar values in linalg::GenericOp.
1510+
AffineMap broadcastMap =
1511+
AffineMap::get(rank, 0, {}, rewriter.getContext());
1512+
genericInputs.push_back(
1513+
collapse1xNTensorToN(rewriter, op.getShift(), loc));
1514+
indexingMaps.push_back(broadcastMap);
1515+
} else {
1516+
genericInputs.push_back(op.getShift());
1517+
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1518+
/*symbolCount=*/0, shiftExprs,
1519+
rewriter.getContext()));
1520+
}
1521+
}
1522+
shiftArg = indexingMaps.size() - 1;
1523+
}
1524+
1525+
// Return the extended Zp to be used in subsequent arithmetic operations.
1526+
static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
1527+
FailureOr<int64_t> maybeZp, Location loc,
1528+
ValueRange blockArgs, int64_t iZpArg) {
1529+
Value result;
1530+
// The Zp value can be either constant or non-constant, depending on
1531+
// whether dynamic extension is enabled.
1532+
// If 'maybeZp' fails, it indicates that Zp is non-constant and will
1533+
// be passed as an input to linalg::GenericOp.
1534+
if (failed(maybeZp)) {
1535+
result = blockArgs[iZpArg];
1536+
auto zpTy = result.getType();
1537+
if (zpTy.getIntOrFloatBitWidth() < 32) {
1538+
if (zpTy.isUnsignedInteger()) {
1539+
result =
1540+
builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
1541+
} else {
1542+
result =
1543+
builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
1544+
}
1545+
}
1546+
} else {
1547+
const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
1548+
// Extend zeropoint for sub-32bits widths.
1549+
const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
1550+
result = builder.create<arith::ConstantOp>(
1551+
loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
1552+
}
1553+
return result;
1554+
}
1555+
1556+
// Return the i32 outputZp to be used in subsequent arithmetic operations.
1557+
static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
1558+
FailureOr<int64_t> maybeZp, Location loc,
1559+
ValueRange blockArgs, int64_t oZpArg) {
1560+
Value result;
1561+
// The Zp value can be either constant or non-constant, depending on
1562+
// whether dynamic extension is enabled.
1563+
// If 'maybeZp' fails, it indicates that Zp is non-constant and will
1564+
// be passed as an input to linalg::GenericOp.
1565+
if (failed(maybeZp)) {
1566+
result = blockArgs[oZpArg];
1567+
auto zpTy = result.getType();
1568+
if (zpTy.getIntOrFloatBitWidth() < 32) {
1569+
if (zpTy.isUnsignedInteger()) {
1570+
result =
1571+
builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
1572+
} else {
1573+
result =
1574+
builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
1575+
}
1576+
} else if (zpTy.getIntOrFloatBitWidth() > 32) {
1577+
result =
1578+
builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
1579+
}
1580+
} else {
1581+
const int32_t attrBitwidth = 32;
1582+
result = builder.create<arith::ConstantOp>(
1583+
loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
1584+
}
1585+
return result;
1586+
}
1587+
13951588
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
13961589
public:
13971590
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1423,40 +1616,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14231616
}
14241617
}
14251618

1426-
// The shift and multiplier values.
14271619
DenseElementsAttr shiftElems;
1428-
if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
1429-
return rewriter.notifyMatchFailure(
1430-
op, "tosa.rescale requires constant shift input values");
1620+
bool isShiftConstant = false;
1621+
if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
1622+
isShiftConstant = true;
14311623

14321624
DenseElementsAttr multiplierElems;
1433-
if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1434-
return rewriter.notifyMatchFailure(
1435-
op, "tosa.rescale requires constant multiplier input values");
1436-
1437-
llvm::SmallVector<int8_t> shiftValues =
1438-
llvm::to_vector(shiftElems.getValues<int8_t>());
1439-
// explicit cast is required here
1440-
llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
1441-
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1442-
[](IntegerAttr attr) -> int32_t {
1443-
return static_cast<int32_t>(attr.getInt());
1444-
}));
1445-
1446-
// If we shift by more than the bitwidth, this just sets to 0.
1447-
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1448-
if (shiftValues[i] > 63) {
1449-
shiftValues[i] = 0;
1450-
multiplierValues[i] = 0;
1625+
bool isMultiplierConstant = false;
1626+
if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1627+
isMultiplierConstant = true;
1628+
1629+
llvm::SmallVector<int8_t> shiftValues;
1630+
llvm::SmallVector<int32_t> multiplierValues;
1631+
bool doubleRound;
1632+
1633+
if (isMultiplierConstant && isShiftConstant) {
1634+
shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
1635+
// explicit cast is required here
1636+
multiplierValues = llvm::to_vector(
1637+
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1638+
[](IntegerAttr attr) -> int32_t {
1639+
return static_cast<int32_t>(attr.getInt());
1640+
}));
1641+
1642+
// If we shift by more than the bitwidth, this just sets to 0.
1643+
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1644+
if (shiftValues[i] > 63) {
1645+
shiftValues[i] = 0;
1646+
multiplierValues[i] = 0;
1647+
}
14511648
}
1452-
}
1649+
// Double round only occurs if shift is greater than 31, check that this
1650+
// is ever true.
1651+
doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1652+
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1653+
} else
1654+
doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
14531655

1454-
// Double round only occurs if shift is greater than 31, check that this
1455-
// is ever true.
1456-
1457-
bool doubleRound =
1458-
op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1459-
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
14601656
RoundingMode roundingMode =
14611657
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
14621658

@@ -1468,45 +1664,41 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14681664
// values in a buffer.
14691665
Value multiplierConstant;
14701666
int64_t multiplierArg = 0;
1471-
if (multiplierValues.size() == 1) {
1472-
multiplierConstant = arith::ConstantOp::create(
1473-
rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1474-
} else {
1475-
SmallVector<AffineExpr, 2> multiplierExprs{
1476-
rewriter.getAffineDimExpr(rank - 1)};
1477-
auto multiplierType =
1478-
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1479-
rewriter.getI32Type());
1480-
genericInputs.push_back(arith::ConstantOp::create(
1481-
rewriter, loc,
1482-
DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1483-
1484-
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1485-
/*symbolCount=*/0, multiplierExprs,
1486-
rewriter.getContext()));
1487-
1488-
multiplierArg = indexingMaps.size() - 1;
1489-
}
1667+
setupLinalgGenericOpInputAndIndexingMapForMultiplier(
1668+
rewriter, multiplierValues, genericInputs, indexingMaps,
1669+
isMultiplierConstant, op, multiplierConstant, multiplierArg);
14901670

14911671
// If we are rescaling per-channel then we need to store the shift
14921672
// values in a buffer.
14931673
Value shiftConstant;
14941674
int64_t shiftArg = 0;
1495-
if (shiftValues.size() == 1) {
1496-
shiftConstant = arith::ConstantOp::create(
1497-
rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1498-
} else {
1499-
SmallVector<AffineExpr, 2> shiftExprs = {
1500-
rewriter.getAffineDimExpr(rank - 1)};
1501-
auto shiftType =
1502-
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1503-
rewriter.getIntegerType(8));
1504-
genericInputs.push_back(arith::ConstantOp::create(
1505-
rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1506-
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1507-
/*symbolCount=*/0, shiftExprs,
1508-
rewriter.getContext()));
1509-
shiftArg = indexingMaps.size() - 1;
1675+
setupLinalgGenericOpInputAndIndexingMapForShift(
1676+
rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1677+
shiftConstant, shiftArg);
1678+
1679+
// broadcastMap = affine_map<(d0, d1) -> ()>
1680+
// It would affect as broadcast for scalar values in linalg::GenericOp.
1681+
AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
1682+
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1683+
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1684+
// The inputZp and outputZp may be either constant or non-constant,
1685+
// depending on whether dynamic extension is enabled.
1686+
// - If the zp is non-constant, add it as an input to linalg::GenericOp by:
1687+
// 1. Pushing it into 'genericInputs'.
1688+
// 2. Appending a corresponding affine map to 'indexingMaps'.
1689+
int64_t iZpArg = 0;
1690+
if (failed(maybeIZp)) {
1691+
genericInputs.push_back(
1692+
collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
1693+
indexingMaps.push_back(broadcastMap);
1694+
iZpArg = indexingMaps.size() - 1;
1695+
}
1696+
int64_t oZpArg = 0;
1697+
if (failed(maybeOZp)) {
1698+
genericInputs.push_back(
1699+
collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
1700+
indexingMaps.push_back(broadcastMap);
1701+
oZpArg = indexingMaps.size() - 1;
15101702
}
15111703

15121704
// Indexing maps for output values.
@@ -1526,36 +1718,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
15261718
Type valueTy = value.getType();
15271719

15281720
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1529-
if (failed(maybeIZp)) {
1530-
(void)rewriter.notifyMatchFailure(
1531-
op, "input zero point cannot be statically determined");
1532-
return;
1533-
}
1534-
1535-
const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
1536-
// Extend zeropoint for sub-32bits widths.
1537-
const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
1538-
auto inputZp = arith::ConstantOp::create(
1539-
nestedBuilder, loc,
1540-
IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
1541-
*maybeIZp));
1721+
auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
1722+
nestedLoc, blockArgs, iZpArg);
15421723

15431724
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1544-
if (failed(maybeOZp)) {
1545-
(void)rewriter.notifyMatchFailure(
1546-
op, "output zero point cannot be statically determined");
1547-
return;
1548-
};
1725+
auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
1726+
nestedLoc, blockArgs, oZpArg);
15491727

15501728
IntegerType outIntType =
15511729
cast<IntegerType>(blockArgs.back().getType());
15521730
unsigned outBitWidth = outIntType.getWidth();
1553-
const int32_t outAttrBitwidth = 32;
15541731
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1555-
auto outputZp = arith::ConstantOp::create(
1556-
nestedBuilder, loc,
1557-
IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
1558-
*maybeOZp));
15591732

15601733
Value multiplier = multiplierConstant ? multiplierConstant
15611734
: blockArgs[multiplierArg];

0 commit comments

Comments
 (0)