Skip to content

Commit 107ca63

Browse files
authored
[mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg (#155967)
The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether dynamic extension is enabled. When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps. The commit help to pass following Tosa conformance tests. rescale_22x20_i32_outi8_sc0_rmS_pc0_iu0_ou0_dyn rescale_31x18_i8_outi8_sc0_rmS_pc0_iu1_ou0_dyn rescale_20x19_i16_outi8_sc0_rmS_pc0_iu1_ou0_dyn
1 parent 6cec362 commit 107ca63

File tree

2 files changed

+288
-86
lines changed

2 files changed

+288
-86
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 202 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -1392,6 +1392,137 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
13921392
}
13931393
};
13941394

1395+
// Collapse tensor<1xiN> into tensor<iN>
1396+
// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
1397+
static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
1398+
Location loc) {
1399+
SmallVector<ReassociationExprs, 1> reassociation;
1400+
// Create the collapsed type
1401+
auto inputType = cast<RankedTensorType>(input.getType());
1402+
auto elemType = inputType.getElementType();
1403+
auto collapsedType = RankedTensorType::get({}, elemType);
1404+
// Emit the collapse op
1405+
return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
1406+
reassociation);
1407+
}
1408+
1409+
static llvm::SmallVector<int8_t>
1410+
convertToI8(const llvm::SmallVector<int32_t> &input) {
1411+
llvm::SmallVector<int8_t> output;
1412+
output.reserve(input.size());
1413+
1414+
for (auto v : llvm::map_range(
1415+
input, [](int32_t val) { return static_cast<int8_t>(val); })) {
1416+
output.push_back(v);
1417+
}
1418+
return output;
1419+
}
1420+
1421+
// The shift or multiplier may be either constant or non-constant, depending on
1422+
// whether dynamic extension is enabled.
1423+
// - If the shift or multiplier is non-constant, add it as an input to
1424+
// linalg::GenericOp by:
1425+
// 1. Pushing it into 'genericInputs'.
1426+
// 2. Appending a corresponding affine map to 'indexingMaps'.
1427+
// - If the shift or multiplier is constant, set 'constant' instead.
1428+
static void setupLinalgGenericOpInputAndIndexingMap(
1429+
PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
1430+
SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
1431+
bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
1432+
bool isShift = false) {
1433+
1434+
auto loc = op.getLoc();
1435+
auto inputTy = cast<ShapedType>(op.getInput().getType());
1436+
unsigned rank = inputTy.getRank();
1437+
SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
1438+
1439+
if (isConstant) {
1440+
// If we are rescaling per-channel then we need to store the
1441+
// values in a buffer.
1442+
if (values.size() == 1) {
1443+
IntegerAttr intAttr = isShift
1444+
? rewriter.getI8IntegerAttr(values.front())
1445+
: rewriter.getI32IntegerAttr(values.front());
1446+
constant = rewriter.create<arith::ConstantOp>(loc, intAttr);
1447+
} else {
1448+
auto elementType =
1449+
isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
1450+
auto tensorType = RankedTensorType::get(
1451+
{static_cast<int64_t>(values.size())}, elementType);
1452+
DenseIntElementsAttr EltAttr;
1453+
if (isShift)
1454+
EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
1455+
else
1456+
EltAttr = DenseIntElementsAttr::get(tensorType, values);
1457+
genericInputs.push_back(
1458+
arith::ConstantOp::create(rewriter, loc, EltAttr));
1459+
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1460+
/*symbolCount=*/0, exprs,
1461+
rewriter.getContext()));
1462+
}
1463+
} else {
1464+
// If we are not rescaling per-channel then we need to collapse 1xN to N
1465+
// and push broadcastMap.
1466+
auto operand = isShift ? op.getShift() : op.getMultiplier();
1467+
auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
1468+
if (tensorType && tensorType.hasStaticShape() &&
1469+
tensorType.getShape()[0] == 1) {
1470+
// broadcastMap = affine_map<(d0, d1) -> ()>
1471+
// It would affect as broadcast for scalar values in linalg::GenericOp.
1472+
AffineMap broadcastMap =
1473+
AffineMap::get(rank, 0, {}, rewriter.getContext());
1474+
genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
1475+
indexingMaps.push_back(broadcastMap);
1476+
} else {
1477+
genericInputs.push_back(operand);
1478+
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1479+
/*symbolCount=*/0, exprs,
1480+
rewriter.getContext()));
1481+
}
1482+
}
1483+
arg = indexingMaps.size() - 1;
1484+
}
1485+
1486+
// Return the extended Zp to be used in subsequent arithmetic operations.
1487+
static Value getExtendZp(OpBuilder &builder, Type valueTy,
1488+
FailureOr<int64_t> maybeZp, Location loc,
1489+
ValueRange blockArgs, int64_t zpArg,
1490+
bool isOutputZp = false) {
1491+
Value result;
1492+
const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
1493+
const uint32_t attrBitwidth =
1494+
isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
1495+
auto extendType = builder.getIntegerType(attrBitwidth);
1496+
// The Zp value can be either constant or non-constant, depending on
1497+
// whether dynamic extension is enabled.
1498+
// If 'maybeZp' fails, it indicates that Zp is non-constant and will
1499+
// be passed as an input to linalg::GenericOp.
1500+
if (failed(maybeZp)) {
1501+
result = blockArgs[zpArg];
1502+
auto zpTy = result.getType();
1503+
if (zpTy.getIntOrFloatBitWidth() < attrBitwidth) {
1504+
// For ExtUIOp, the input must be signless.
1505+
// UnrealizedConversionCastOp will cast the input to signless type.
1506+
if (zpTy.isUnsignedInteger()) {
1507+
result =
1508+
UnrealizedConversionCastOp::create(
1509+
builder, loc,
1510+
builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
1511+
.getResult(0);
1512+
}
1513+
if (zpTy.isUnsignedInteger()) {
1514+
return builder.create<arith::ExtUIOp>(loc, extendType, result);
1515+
} else {
1516+
return builder.create<arith::ExtSIOp>(loc, extendType, result);
1517+
}
1518+
}
1519+
} else {
1520+
return builder.create<arith::ConstantOp>(
1521+
loc, IntegerAttr::get(extendType, *maybeZp));
1522+
}
1523+
return result;
1524+
}
1525+
13951526
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
13961527
public:
13971528
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1423,40 +1554,46 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14231554
}
14241555
}
14251556

1426-
// The shift and multiplier values.
14271557
DenseElementsAttr shiftElems;
1428-
if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
1429-
return rewriter.notifyMatchFailure(
1430-
op, "tosa.rescale requires constant shift input values");
1558+
bool isShiftConstant = false;
1559+
if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
1560+
isShiftConstant = true;
14311561

14321562
DenseElementsAttr multiplierElems;
1433-
if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1434-
return rewriter.notifyMatchFailure(
1435-
op, "tosa.rescale requires constant multiplier input values");
1436-
1437-
llvm::SmallVector<int8_t> shiftValues =
1438-
llvm::to_vector(shiftElems.getValues<int8_t>());
1439-
// explicit cast is required here
1440-
llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
1441-
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1442-
[](IntegerAttr attr) -> int32_t {
1443-
return static_cast<int32_t>(attr.getInt());
1444-
}));
1445-
1446-
// If we shift by more than the bitwidth, this just sets to 0.
1447-
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1448-
if (shiftValues[i] > 63) {
1449-
shiftValues[i] = 0;
1450-
multiplierValues[i] = 0;
1563+
bool isMultiplierConstant = false;
1564+
if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1565+
isMultiplierConstant = true;
1566+
1567+
llvm::SmallVector<int32_t> shiftValues;
1568+
llvm::SmallVector<int32_t> multiplierValues;
1569+
bool doubleRound;
1570+
1571+
if (isMultiplierConstant && isShiftConstant) {
1572+
// explicit cast is required here
1573+
shiftValues = llvm::to_vector(llvm::map_range(
1574+
shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
1575+
return static_cast<int32_t>(attr.getInt());
1576+
}));
1577+
multiplierValues = llvm::to_vector(
1578+
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1579+
[](IntegerAttr attr) -> int32_t {
1580+
return static_cast<int32_t>(attr.getInt());
1581+
}));
1582+
1583+
// If we shift by more than the bitwidth, this just sets to 0.
1584+
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1585+
if (shiftValues[i] > 63) {
1586+
shiftValues[i] = 0;
1587+
multiplierValues[i] = 0;
1588+
}
14511589
}
1452-
}
1590+
// Double round only occurs if shift is greater than 31, check that this
1591+
// is ever true.
1592+
doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1593+
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1594+
} else
1595+
doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
14531596

1454-
// Double round only occurs if shift is greater than 31, check that this
1455-
// is ever true.
1456-
1457-
bool doubleRound =
1458-
op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1459-
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
14601597
RoundingMode roundingMode =
14611598
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
14621599

@@ -1468,45 +1605,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14681605
// values in a buffer.
14691606
Value multiplierConstant;
14701607
int64_t multiplierArg = 0;
1471-
if (multiplierValues.size() == 1) {
1472-
multiplierConstant = arith::ConstantOp::create(
1473-
rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1474-
} else {
1475-
SmallVector<AffineExpr, 2> multiplierExprs{
1476-
rewriter.getAffineDimExpr(rank - 1)};
1477-
auto multiplierType =
1478-
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1479-
rewriter.getI32Type());
1480-
genericInputs.push_back(arith::ConstantOp::create(
1481-
rewriter, loc,
1482-
DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1483-
1484-
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1485-
/*symbolCount=*/0, multiplierExprs,
1486-
rewriter.getContext()));
1487-
1488-
multiplierArg = indexingMaps.size() - 1;
1489-
}
1608+
setupLinalgGenericOpInputAndIndexingMap(
1609+
rewriter, multiplierValues, genericInputs, indexingMaps,
1610+
isMultiplierConstant, op, multiplierConstant, multiplierArg);
14901611

14911612
// If we are rescaling per-channel then we need to store the shift
14921613
// values in a buffer.
14931614
Value shiftConstant;
14941615
int64_t shiftArg = 0;
1495-
if (shiftValues.size() == 1) {
1496-
shiftConstant = arith::ConstantOp::create(
1497-
rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1498-
} else {
1499-
SmallVector<AffineExpr, 2> shiftExprs = {
1500-
rewriter.getAffineDimExpr(rank - 1)};
1501-
auto shiftType =
1502-
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1503-
rewriter.getIntegerType(8));
1504-
genericInputs.push_back(arith::ConstantOp::create(
1505-
rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1506-
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1507-
/*symbolCount=*/0, shiftExprs,
1508-
rewriter.getContext()));
1509-
shiftArg = indexingMaps.size() - 1;
1616+
setupLinalgGenericOpInputAndIndexingMap(
1617+
rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1618+
shiftConstant, shiftArg, true);
1619+
1620+
// broadcastMap = affine_map<(d0, d1) -> ()>
1621+
// It would affect as broadcast for scalar values in linalg::GenericOp.
1622+
AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
1623+
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1624+
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1625+
// The inputZp and outputZp may be either constant or non-constant,
1626+
// depending on whether dynamic extension is enabled.
1627+
// - If the zp's are non-constant, add them as an inputs to
1628+
// linalg::GenericOp by:
1629+
// 1. Pushing it into 'genericInputs'.
1630+
// 2. Appending a corresponding affine map to 'indexingMaps'.
1631+
// - If the zp's are constant, they would be generated as arith.constant.
1632+
int64_t iZpArg = 0;
1633+
if (failed(maybeIZp)) {
1634+
genericInputs.push_back(
1635+
collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
1636+
indexingMaps.push_back(broadcastMap);
1637+
iZpArg = indexingMaps.size() - 1;
1638+
}
1639+
int64_t oZpArg = 0;
1640+
if (failed(maybeOZp)) {
1641+
genericInputs.push_back(
1642+
collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
1643+
indexingMaps.push_back(broadcastMap);
1644+
oZpArg = indexingMaps.size() - 1;
15101645
}
15111646

15121647
// Indexing maps for output values.
@@ -1526,36 +1661,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
15261661
Type valueTy = value.getType();
15271662

15281663
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1529-
if (failed(maybeIZp)) {
1530-
(void)rewriter.notifyMatchFailure(
1531-
op, "input zero point cannot be statically determined");
1532-
return;
1533-
}
1534-
1535-
const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
1536-
// Extend zeropoint for sub-32bits widths.
1537-
const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
1538-
auto inputZp = arith::ConstantOp::create(
1539-
nestedBuilder, loc,
1540-
IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
1541-
*maybeIZp));
1664+
auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
1665+
nestedLoc, blockArgs, iZpArg);
15421666

15431667
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1544-
if (failed(maybeOZp)) {
1545-
(void)rewriter.notifyMatchFailure(
1546-
op, "output zero point cannot be statically determined");
1547-
return;
1548-
};
1668+
auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
1669+
nestedLoc, blockArgs, oZpArg, true);
15491670

15501671
IntegerType outIntType =
15511672
cast<IntegerType>(blockArgs.back().getType());
15521673
unsigned outBitWidth = outIntType.getWidth();
1553-
const int32_t outAttrBitwidth = 32;
15541674
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1555-
auto outputZp = arith::ConstantOp::create(
1556-
nestedBuilder, loc,
1557-
IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
1558-
*maybeOZp));
15591675

15601676
Value multiplier = multiplierConstant ? multiplierConstant
15611677
: blockArgs[multiplierArg];

0 commit comments

Comments
 (0)