@@ -1392,6 +1392,199 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
13921392 }
13931393};
13941394
1395+ // Collapse tensor<1xiN> into tensor<iN>
1396+ // E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
1397+ static Value collapse1xNTensorToN (PatternRewriter &rewriter, Value input,
1398+ Location loc) {
1399+ SmallVector<ReassociationExprs, 1 > reassociation;
1400+ // Create the collapsed type
1401+ auto inputType = cast<RankedTensorType>(input.getType ());
1402+ auto elemType = inputType.getElementType ();
1403+ auto collapsedType = RankedTensorType::get ({}, elemType);
1404+ // Emit the collapse op
1405+ return rewriter.create <tensor::CollapseShapeOp>(loc, collapsedType, input,
1406+ reassociation);
1407+ }
1408+
1409+ // The multiplier may be either constant or non-constant, depending on
1410+ // whether dynamic extension is enabled.
1411+ // - If the multiplier is non-constant, add it as an input to linalg::GenericOp
1412+ // by:
1413+ // 1. Pushing it into 'genericInputs'.
1414+ // 2. Appending a corresponding affine map to 'indexingMaps'.
1415+ // - If the multiplier is constant, set 'multiplierConstant' instead.
1416+ static void setupLinalgGenericOpInputAndIndexingMapForMultiplier (
1417+ PatternRewriter &rewriter, llvm::SmallVector<int32_t > &multiplierValues,
1418+ SmallVector<Value, 4 > &genericInputs, SmallVector<AffineMap> &indexingMaps,
1419+ bool isConstant, tosa::RescaleOp op, Value &multiplierConstant,
1420+ int64_t &multiplierArg) {
1421+
1422+ auto loc = op.getLoc ();
1423+ auto inputTy = cast<ShapedType>(op.getInput ().getType ());
1424+ unsigned rank = inputTy.getRank ();
1425+ SmallVector<AffineExpr, 2 > multiplierExprs{
1426+ rewriter.getAffineDimExpr (rank - 1 )};
1427+
1428+ if (isConstant) {
1429+ // If we are rescaling per-channel then we need to store the multiplier
1430+ // values in a buffer.
1431+ if (multiplierValues.size () == 1 ) {
1432+ multiplierConstant = rewriter.create <arith::ConstantOp>(
1433+ loc, rewriter.getI32IntegerAttr (multiplierValues.front ()));
1434+ } else {
1435+ auto multiplierType =
1436+ RankedTensorType::get ({static_cast <int64_t >(multiplierValues.size ())},
1437+ rewriter.getI32Type ());
1438+ genericInputs.push_back (arith::ConstantOp::create (
1439+ rewriter, loc,
1440+ DenseIntElementsAttr::get (multiplierType, multiplierValues)));
1441+
1442+ indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1443+ /* symbolCount=*/ 0 , multiplierExprs,
1444+ rewriter.getContext ()));
1445+ }
1446+ } else {
1447+ // If we are not rescaling per-channel then we need to collapse 1xN to N
1448+ // and push broadcastMap.
1449+ auto tensorType = dyn_cast<RankedTensorType>(op.getMultiplier ().getType ());
1450+ if (tensorType && tensorType.hasStaticShape () &&
1451+ tensorType.getShape ()[0 ] == 1 ) {
1452+ // broadcastMap = affine_map<(d0, d1) -> ()>
1453+ // It would affect as broadcast for scalar values in linalg::GenericOp.
1454+ AffineMap broadcastMap =
1455+ AffineMap::get (rank, 0 , {}, rewriter.getContext ());
1456+ genericInputs.push_back (
1457+ collapse1xNTensorToN (rewriter, op.getMultiplier (), loc));
1458+ indexingMaps.push_back (broadcastMap);
1459+ } else {
1460+ genericInputs.push_back (op.getMultiplier ());
1461+ indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1462+ /* symbolCount=*/ 0 , multiplierExprs,
1463+ rewriter.getContext ()));
1464+ }
1465+ }
1466+ multiplierArg = indexingMaps.size () - 1 ;
1467+ }
1468+
1469+ // The shift may be either constant or non-constant, depending on
1470+ // whether dynamic extension is enabled.
1471+ // - If the shift is non-constant, add it as an input to linalg::GenericOp by:
1472+ // 1. Pushing it into 'genericInputs'.
1473+ // 2. Appending a corresponding affine map to 'indexingMaps'.
1474+ // - If the shift is constant, set 'shiftConstant' instead.
1475+ static void setupLinalgGenericOpInputAndIndexingMapForShift (
1476+ PatternRewriter &rewriter, llvm::SmallVector<int8_t > &shiftValues,
1477+ SmallVector<Value, 4 > &genericInputs, SmallVector<AffineMap> &indexingMaps,
1478+ bool isConstant, tosa::RescaleOp op, Value &shiftConstant,
1479+ int64_t &shiftArg) {
1480+
1481+ auto loc = op.getLoc ();
1482+ auto inputTy = cast<ShapedType>(op.getInput ().getType ());
1483+ unsigned rank = inputTy.getRank ();
1484+ SmallVector<AffineExpr, 2 > shiftExprs = {rewriter.getAffineDimExpr (rank - 1 )};
1485+
1486+ if (isConstant) {
1487+ // If we are rescaling per-channel then we need to store the shift
1488+ // values in a buffer.
1489+ if (shiftValues.size () == 1 ) {
1490+ shiftConstant = rewriter.create <arith::ConstantOp>(
1491+ loc, rewriter.getI8IntegerAttr (shiftValues.front ()));
1492+ } else {
1493+ auto shiftType =
1494+ RankedTensorType::get ({static_cast <int64_t >(shiftValues.size ())},
1495+ rewriter.getIntegerType (8 ));
1496+ genericInputs.push_back (arith::ConstantOp::create (
1497+ rewriter, loc, DenseIntElementsAttr::get (shiftType, shiftValues)));
1498+ indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1499+ /* symbolCount=*/ 0 , shiftExprs,
1500+ rewriter.getContext ()));
1501+ }
1502+ } else {
1503+ // If we are not rescaling per-channel then we need to collapse 1xN to N
1504+ // and push broadcastMap.
1505+ auto tensorType = dyn_cast<RankedTensorType>(op.getShift ().getType ());
1506+ if (tensorType && tensorType.hasStaticShape () &&
1507+ tensorType.getShape ()[0 ] == 1 ) {
1508+ // broadcastMap = affine_map<(d0, d1) -> ()>
1509+ // It would affect as broadcast for scalar values in linalg::GenericOp.
1510+ AffineMap broadcastMap =
1511+ AffineMap::get (rank, 0 , {}, rewriter.getContext ());
1512+ genericInputs.push_back (
1513+ collapse1xNTensorToN (rewriter, op.getShift (), loc));
1514+ indexingMaps.push_back (broadcastMap);
1515+ } else {
1516+ genericInputs.push_back (op.getShift ());
1517+ indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1518+ /* symbolCount=*/ 0 , shiftExprs,
1519+ rewriter.getContext ()));
1520+ }
1521+ }
1522+ shiftArg = indexingMaps.size () - 1 ;
1523+ }
1524+
1525+ // Return the extended Zp to be used in subsequent arithmetic operations.
1526+ static Value getExtendInputZp (OpBuilder &builder, Type valueTy,
1527+ FailureOr<int64_t > maybeZp, Location loc,
1528+ ValueRange blockArgs, int64_t iZpArg) {
1529+ Value result;
1530+ // The Zp value can be either constant or non-constant, depending on
1531+ // whether dynamic extension is enabled.
1532+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
1533+ // be passed as an input to linalg::GenericOp.
1534+ if (failed (maybeZp)) {
1535+ result = blockArgs[iZpArg];
1536+ auto zpTy = result.getType ();
1537+ if (zpTy.getIntOrFloatBitWidth () < 32 ) {
1538+ if (zpTy.isUnsignedInteger ()) {
1539+ result =
1540+ builder.create <arith::ExtUIOp>(loc, builder.getI32Type (), result);
1541+ } else {
1542+ result =
1543+ builder.create <arith::ExtSIOp>(loc, builder.getI32Type (), result);
1544+ }
1545+ }
1546+ } else {
1547+ const int32_t bitwidth = valueTy.getIntOrFloatBitWidth ();
1548+ // Extend zeropoint for sub-32bits widths.
1549+ const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32 ;
1550+ result = builder.create <arith::ConstantOp>(
1551+ loc, IntegerAttr::get (builder.getIntegerType (attrBitwidth), *maybeZp));
1552+ }
1553+ return result;
1554+ }
1555+
1556+ // Return the i32 outputZp to be used in subsequent arithmetic operations.
1557+ static Value getI32OutputZp (OpBuilder &builder, Type valueTy,
1558+ FailureOr<int64_t > maybeZp, Location loc,
1559+ ValueRange blockArgs, int64_t oZpArg) {
1560+ Value result;
1561+ // The Zp value can be either constant or non-constant, depending on
1562+ // whether dynamic extension is enabled.
1563+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
1564+ // be passed as an input to linalg::GenericOp.
1565+ if (failed (maybeZp)) {
1566+ result = blockArgs[oZpArg];
1567+ auto zpTy = result.getType ();
1568+ if (zpTy.getIntOrFloatBitWidth () < 32 ) {
1569+ if (zpTy.isUnsignedInteger ()) {
1570+ result =
1571+ builder.create <arith::ExtUIOp>(loc, builder.getI32Type (), result);
1572+ } else {
1573+ result =
1574+ builder.create <arith::ExtSIOp>(loc, builder.getI32Type (), result);
1575+ }
1576+ } else if (zpTy.getIntOrFloatBitWidth () > 32 ) {
1577+ result =
1578+ builder.create <arith::TruncIOp>(loc, builder.getI32Type (), result);
1579+ }
1580+ } else {
1581+ const int32_t attrBitwidth = 32 ;
1582+ result = builder.create <arith::ConstantOp>(
1583+ loc, IntegerAttr::get (builder.getIntegerType (attrBitwidth), *maybeZp));
1584+ }
1585+ return result;
1586+ }
1587+
13951588class RescaleConverter : public OpRewritePattern <tosa::RescaleOp> {
13961589public:
13971590 using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1423,40 +1616,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14231616 }
14241617 }
14251618
1426- // The shift and multiplier values.
14271619 DenseElementsAttr shiftElems;
1428- if (! matchPattern (op. getShift (), m_Constant (&shiftElems)))
1429- return rewriter. notifyMatchFailure (
1430- op, " tosa.rescale requires constant shift input values " ) ;
1620+ bool isShiftConstant = false ;
1621+ if ( matchPattern (op. getShift (), m_Constant (&shiftElems)))
1622+ isShiftConstant = true ;
14311623
14321624 DenseElementsAttr multiplierElems;
1433- if (!matchPattern (op.getMultiplier (), m_Constant (&multiplierElems)))
1434- return rewriter.notifyMatchFailure (
1435- op, " tosa.rescale requires constant multiplier input values" );
1436-
1437- llvm::SmallVector<int8_t > shiftValues =
1438- llvm::to_vector (shiftElems.getValues <int8_t >());
1439- // explicit cast is required here
1440- llvm::SmallVector<int32_t > multiplierValues = llvm::to_vector (
1441- llvm::map_range (multiplierElems.getValues <IntegerAttr>(),
1442- [](IntegerAttr attr) -> int32_t {
1443- return static_cast <int32_t >(attr.getInt ());
1444- }));
1445-
1446- // If we shift by more than the bitwidth, this just sets to 0.
1447- for (int i = 0 , s = multiplierValues.size (); i < s; i++) {
1448- if (shiftValues[i] > 63 ) {
1449- shiftValues[i] = 0 ;
1450- multiplierValues[i] = 0 ;
1625+ bool isMultiplierConstant = false ;
1626+ if (matchPattern (op.getMultiplier (), m_Constant (&multiplierElems)))
1627+ isMultiplierConstant = true ;
1628+
1629+ llvm::SmallVector<int8_t > shiftValues;
1630+ llvm::SmallVector<int32_t > multiplierValues;
1631+ bool doubleRound;
1632+
1633+ if (isMultiplierConstant && isShiftConstant) {
1634+ shiftValues = llvm::to_vector (shiftElems.getValues <int8_t >());
1635+ // explicit cast is required here
1636+ multiplierValues = llvm::to_vector (
1637+ llvm::map_range (multiplierElems.getValues <IntegerAttr>(),
1638+ [](IntegerAttr attr) -> int32_t {
1639+ return static_cast <int32_t >(attr.getInt ());
1640+ }));
1641+
1642+ // If we shift by more than the bitwidth, this just sets to 0.
1643+ for (int i = 0 , s = multiplierValues.size (); i < s; i++) {
1644+ if (shiftValues[i] > 63 ) {
1645+ shiftValues[i] = 0 ;
1646+ multiplierValues[i] = 0 ;
1647+ }
14511648 }
1452- }
1649+ // Double round only occurs if shift is greater than 31, check that this
1650+ // is ever true.
1651+ doubleRound = op.getRoundingMode () == RoundingMode::DOUBLE_ROUND &&
1652+ llvm::any_of (shiftValues, [](int32_t v) { return v > 31 ; });
1653+ } else
1654+ doubleRound = op.getRoundingMode () == RoundingMode::DOUBLE_ROUND;
14531655
1454- // Double round only occurs if shift is greater than 31, check that this
1455- // is ever true.
1456-
1457- bool doubleRound =
1458- op.getRoundingMode () == RoundingMode::DOUBLE_ROUND &&
1459- llvm::any_of (shiftValues, [](int32_t v) { return v > 31 ; });
14601656 RoundingMode roundingMode =
14611657 doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
14621658
@@ -1468,45 +1664,41 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14681664 // values in a buffer.
14691665 Value multiplierConstant;
14701666 int64_t multiplierArg = 0 ;
1471- if (multiplierValues.size () == 1 ) {
1472- multiplierConstant = arith::ConstantOp::create (
1473- rewriter, loc, rewriter.getI32IntegerAttr (multiplierValues.front ()));
1474- } else {
1475- SmallVector<AffineExpr, 2 > multiplierExprs{
1476- rewriter.getAffineDimExpr (rank - 1 )};
1477- auto multiplierType =
1478- RankedTensorType::get ({static_cast <int64_t >(multiplierValues.size ())},
1479- rewriter.getI32Type ());
1480- genericInputs.push_back (arith::ConstantOp::create (
1481- rewriter, loc,
1482- DenseIntElementsAttr::get (multiplierType, multiplierValues)));
1483-
1484- indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1485- /* symbolCount=*/ 0 , multiplierExprs,
1486- rewriter.getContext ()));
1487-
1488- multiplierArg = indexingMaps.size () - 1 ;
1489- }
1667+ setupLinalgGenericOpInputAndIndexingMapForMultiplier (
1668+ rewriter, multiplierValues, genericInputs, indexingMaps,
1669+ isMultiplierConstant, op, multiplierConstant, multiplierArg);
14901670
14911671 // If we are rescaling per-channel then we need to store the shift
14921672 // values in a buffer.
14931673 Value shiftConstant;
14941674 int64_t shiftArg = 0 ;
1495- if (shiftValues.size () == 1 ) {
1496- shiftConstant = arith::ConstantOp::create (
1497- rewriter, loc, rewriter.getI8IntegerAttr (shiftValues.front ()));
1498- } else {
1499- SmallVector<AffineExpr, 2 > shiftExprs = {
1500- rewriter.getAffineDimExpr (rank - 1 )};
1501- auto shiftType =
1502- RankedTensorType::get ({static_cast <int64_t >(shiftValues.size ())},
1503- rewriter.getIntegerType (8 ));
1504- genericInputs.push_back (arith::ConstantOp::create (
1505- rewriter, loc, DenseIntElementsAttr::get (shiftType, shiftValues)));
1506- indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1507- /* symbolCount=*/ 0 , shiftExprs,
1508- rewriter.getContext ()));
1509- shiftArg = indexingMaps.size () - 1 ;
1675+ setupLinalgGenericOpInputAndIndexingMapForShift (
1676+ rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1677+ shiftConstant, shiftArg);
1678+
1679+ // broadcastMap = affine_map<(d0, d1) -> ()>
1680+ // It would affect as broadcast for scalar values in linalg::GenericOp.
1681+ AffineMap broadcastMap = AffineMap::get (rank, 0 , {}, rewriter.getContext ());
1682+ FailureOr<int64_t > maybeIZp = op.getInputZeroPoint ();
1683+ FailureOr<int64_t > maybeOZp = op.getOutputZeroPoint ();
1684+ // The inputZp and outputZp may be either constant or non-constant,
1685+ // depending on whether dynamic extension is enabled.
1686+ // - If the zp is non-constant, add it as an input to linalg::GenericOp by:
1687+ // 1. Pushing it into 'genericInputs'.
1688+ // 2. Appending a corresponding affine map to 'indexingMaps'.
1689+ int64_t iZpArg = 0 ;
1690+ if (failed (maybeIZp)) {
1691+ genericInputs.push_back (
1692+ collapse1xNTensorToN (rewriter, op->getOperand (3 ), loc));
1693+ indexingMaps.push_back (broadcastMap);
1694+ iZpArg = indexingMaps.size () - 1 ;
1695+ }
1696+ int64_t oZpArg = 0 ;
1697+ if (failed (maybeOZp)) {
1698+ genericInputs.push_back (
1699+ collapse1xNTensorToN (rewriter, op->getOperand (4 ), loc));
1700+ indexingMaps.push_back (broadcastMap);
1701+ oZpArg = indexingMaps.size () - 1 ;
15101702 }
15111703
15121704 // Indexing maps for output values.
@@ -1526,36 +1718,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
15261718 Type valueTy = value.getType ();
15271719
15281720 FailureOr<int64_t > maybeIZp = op.getInputZeroPoint ();
1529- if (failed (maybeIZp)) {
1530- (void )rewriter.notifyMatchFailure (
1531- op, " input zero point cannot be statically determined" );
1532- return ;
1533- }
1534-
1535- const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth ();
1536- // Extend zeropoint for sub-32bits widths.
1537- const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32 ;
1538- auto inputZp = arith::ConstantOp::create (
1539- nestedBuilder, loc,
1540- IntegerAttr::get (rewriter.getIntegerType (inAttrBitwidth),
1541- *maybeIZp));
1721+ auto inputZp = getExtendInputZp (nestedBuilder, valueTy, maybeIZp,
1722+ nestedLoc, blockArgs, iZpArg);
15421723
15431724 FailureOr<int64_t > maybeOZp = op.getOutputZeroPoint ();
1544- if (failed (maybeOZp)) {
1545- (void )rewriter.notifyMatchFailure (
1546- op, " output zero point cannot be statically determined" );
1547- return ;
1548- };
1725+ auto outputZp = getI32OutputZp (nestedBuilder, valueTy, maybeOZp,
1726+ nestedLoc, blockArgs, oZpArg);
15491727
15501728 IntegerType outIntType =
15511729 cast<IntegerType>(blockArgs.back ().getType ());
15521730 unsigned outBitWidth = outIntType.getWidth ();
1553- const int32_t outAttrBitwidth = 32 ;
15541731 assert (outBitWidth <= 32 && " Unexpected output zeropoint bitwidth" );
1555- auto outputZp = arith::ConstantOp::create (
1556- nestedBuilder, loc,
1557- IntegerAttr::get (rewriter.getIntegerType (outAttrBitwidth),
1558- *maybeOZp));
15591732
15601733 Value multiplier = multiplierConstant ? multiplierConstant
15611734 : blockArgs[multiplierArg];
0 commit comments