@@ -1392,6 +1392,137 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
13921392 }
13931393};
13941394
1395+ // Collapse tensor<1xiN> into tensor<iN>
1396+ // E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
1397+ static Value collapse1xNTensorToN (PatternRewriter &rewriter, Value input,
1398+ Location loc) {
1399+ SmallVector<ReassociationExprs, 1 > reassociation;
1400+ // Create the collapsed type
1401+ auto inputType = cast<RankedTensorType>(input.getType ());
1402+ auto elemType = inputType.getElementType ();
1403+ auto collapsedType = RankedTensorType::get ({}, elemType);
1404+ // Emit the collapse op
1405+ return rewriter.create <tensor::CollapseShapeOp>(loc, collapsedType, input,
1406+ reassociation);
1407+ }
1408+
1409+ static llvm::SmallVector<int8_t >
1410+ convertToI8 (const llvm::SmallVector<int32_t > &input) {
1411+ llvm::SmallVector<int8_t > output;
1412+ output.reserve (input.size ());
1413+
1414+ for (auto v : llvm::map_range (
1415+ input, [](int32_t val) { return static_cast <int8_t >(val); })) {
1416+ output.push_back (v);
1417+ }
1418+ return output;
1419+ }
1420+
1421+ // The shift or multiplier may be either constant or non-constant, depending on
1422+ // whether dynamic extension is enabled.
1423+ // - If the shift or multiplier is non-constant, add it as an input to
1424+ // linalg::GenericOp by:
1425+ // 1. Pushing it into 'genericInputs'.
1426+ // 2. Appending a corresponding affine map to 'indexingMaps'.
1427+ // - If the shift or multiplier is constant, set 'constant' instead.
1428+ static void setupLinalgGenericOpInputAndIndexingMap (
1429+ PatternRewriter &rewriter, llvm::SmallVector<int32_t > &values,
1430+ SmallVector<Value, 4 > &genericInputs, SmallVector<AffineMap> &indexingMaps,
1431+ bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
1432+ bool isShift = false ) {
1433+
1434+ auto loc = op.getLoc ();
1435+ auto inputTy = cast<ShapedType>(op.getInput ().getType ());
1436+ unsigned rank = inputTy.getRank ();
1437+ SmallVector<AffineExpr, 2 > exprs = {rewriter.getAffineDimExpr (rank - 1 )};
1438+
1439+ if (isConstant) {
1440+ // If we are rescaling per-channel then we need to store the
1441+ // values in a buffer.
1442+ if (values.size () == 1 ) {
1443+ IntegerAttr intAttr = isShift
1444+ ? rewriter.getI8IntegerAttr (values.front ())
1445+ : rewriter.getI32IntegerAttr (values.front ());
1446+ constant = rewriter.create <arith::ConstantOp>(loc, intAttr);
1447+ } else {
1448+ auto elementType =
1449+ isShift ? rewriter.getIntegerType (8 ) : rewriter.getI32Type ();
1450+ auto tensorType = RankedTensorType::get (
1451+ {static_cast <int64_t >(values.size ())}, elementType);
1452+ DenseIntElementsAttr EltAttr;
1453+ if (isShift)
1454+ EltAttr = DenseIntElementsAttr::get (tensorType, convertToI8 (values));
1455+ else
1456+ EltAttr = DenseIntElementsAttr::get (tensorType, values);
1457+ genericInputs.push_back (
1458+ arith::ConstantOp::create (rewriter, loc, EltAttr));
1459+ indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1460+ /* symbolCount=*/ 0 , exprs,
1461+ rewriter.getContext ()));
1462+ }
1463+ } else {
1464+ // If we are not rescaling per-channel then we need to collapse 1xN to N
1465+ // and push broadcastMap.
1466+ auto operand = isShift ? op.getShift () : op.getMultiplier ();
1467+ auto tensorType = dyn_cast<RankedTensorType>(operand.getType ());
1468+ if (tensorType && tensorType.hasStaticShape () &&
1469+ tensorType.getShape ()[0 ] == 1 ) {
1470+ // broadcastMap = affine_map<(d0, d1) -> ()>
1471+ // It would affect as broadcast for scalar values in linalg::GenericOp.
1472+ AffineMap broadcastMap =
1473+ AffineMap::get (rank, 0 , {}, rewriter.getContext ());
1474+ genericInputs.push_back (collapse1xNTensorToN (rewriter, operand, loc));
1475+ indexingMaps.push_back (broadcastMap);
1476+ } else {
1477+ genericInputs.push_back (operand);
1478+ indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1479+ /* symbolCount=*/ 0 , exprs,
1480+ rewriter.getContext ()));
1481+ }
1482+ }
1483+ arg = indexingMaps.size () - 1 ;
1484+ }
1485+
1486+ // Return the extended Zp to be used in subsequent arithmetic operations.
1487+ static Value getExtendZp (OpBuilder &builder, Type valueTy,
1488+ FailureOr<int64_t > maybeZp, Location loc,
1489+ ValueRange blockArgs, int64_t zpArg,
1490+ bool isOutputZp = false ) {
1491+ Value result;
1492+ const int32_t bitwidth = valueTy.getIntOrFloatBitWidth ();
1493+ const uint32_t attrBitwidth =
1494+ isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32 );
1495+ auto extendType = builder.getIntegerType (attrBitwidth);
1496+ // The Zp value can be either constant or non-constant, depending on
1497+ // whether dynamic extension is enabled.
1498+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
1499+ // be passed as an input to linalg::GenericOp.
1500+ if (failed (maybeZp)) {
1501+ result = blockArgs[zpArg];
1502+ auto zpTy = result.getType ();
1503+ if (zpTy.getIntOrFloatBitWidth () < attrBitwidth) {
1504+ // For ExtUIOp, the input must be signless.
1505+ // UnrealizedConversionCastOp will cast the input to signless type.
1506+ if (zpTy.isUnsignedInteger ()) {
1507+ result =
1508+ UnrealizedConversionCastOp::create (
1509+ builder, loc,
1510+ builder.getIntegerType (zpTy.getIntOrFloatBitWidth ()), result)
1511+ .getResult (0 );
1512+ }
1513+ if (zpTy.isUnsignedInteger ()) {
1514+ return builder.create <arith::ExtUIOp>(loc, extendType, result);
1515+ } else {
1516+ return builder.create <arith::ExtSIOp>(loc, extendType, result);
1517+ }
1518+ }
1519+ } else {
1520+ return builder.create <arith::ConstantOp>(
1521+ loc, IntegerAttr::get (extendType, *maybeZp));
1522+ }
1523+ return result;
1524+ }
1525+
13951526class RescaleConverter : public OpRewritePattern <tosa::RescaleOp> {
13961527public:
13971528 using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1423,40 +1554,46 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14231554 }
14241555 }
14251556
1426- // The shift and multiplier values.
14271557 DenseElementsAttr shiftElems;
1428- if (! matchPattern (op. getShift (), m_Constant (&shiftElems)))
1429- return rewriter. notifyMatchFailure (
1430- op, " tosa.rescale requires constant shift input values " ) ;
1558+ bool isShiftConstant = false ;
1559+ if ( matchPattern (op. getShift (), m_Constant (&shiftElems)))
1560+ isShiftConstant = true ;
14311561
14321562 DenseElementsAttr multiplierElems;
1433- if (!matchPattern (op.getMultiplier (), m_Constant (&multiplierElems)))
1434- return rewriter.notifyMatchFailure (
1435- op, " tosa.rescale requires constant multiplier input values" );
1436-
1437- llvm::SmallVector<int8_t > shiftValues =
1438- llvm::to_vector (shiftElems.getValues <int8_t >());
1439- // explicit cast is required here
1440- llvm::SmallVector<int32_t > multiplierValues = llvm::to_vector (
1441- llvm::map_range (multiplierElems.getValues <IntegerAttr>(),
1442- [](IntegerAttr attr) -> int32_t {
1443- return static_cast <int32_t >(attr.getInt ());
1444- }));
1445-
1446- // If we shift by more than the bitwidth, this just sets to 0.
1447- for (int i = 0 , s = multiplierValues.size (); i < s; i++) {
1448- if (shiftValues[i] > 63 ) {
1449- shiftValues[i] = 0 ;
1450- multiplierValues[i] = 0 ;
1563+ bool isMultiplierConstant = false ;
1564+ if (matchPattern (op.getMultiplier (), m_Constant (&multiplierElems)))
1565+ isMultiplierConstant = true ;
1566+
1567+ llvm::SmallVector<int32_t > shiftValues;
1568+ llvm::SmallVector<int32_t > multiplierValues;
1569+ bool doubleRound;
1570+
1571+ if (isMultiplierConstant && isShiftConstant) {
1572+ // explicit cast is required here
1573+ shiftValues = llvm::to_vector (llvm::map_range (
1574+ shiftElems.getValues <IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
1575+ return static_cast <int32_t >(attr.getInt ());
1576+ }));
1577+ multiplierValues = llvm::to_vector (
1578+ llvm::map_range (multiplierElems.getValues <IntegerAttr>(),
1579+ [](IntegerAttr attr) -> int32_t {
1580+ return static_cast <int32_t >(attr.getInt ());
1581+ }));
1582+
1583+ // If we shift by more than the bitwidth, this just sets to 0.
1584+ for (int i = 0 , s = multiplierValues.size (); i < s; i++) {
1585+ if (shiftValues[i] > 63 ) {
1586+ shiftValues[i] = 0 ;
1587+ multiplierValues[i] = 0 ;
1588+ }
14511589 }
1452- }
1590+ // Double round only occurs if shift is greater than 31, check that this
1591+ // is ever true.
1592+ doubleRound = op.getRoundingMode () == RoundingMode::DOUBLE_ROUND &&
1593+ llvm::any_of (shiftValues, [](int32_t v) { return v > 31 ; });
1594+ } else
1595+ doubleRound = op.getRoundingMode () == RoundingMode::DOUBLE_ROUND;
14531596
1454- // Double round only occurs if shift is greater than 31, check that this
1455- // is ever true.
1456-
1457- bool doubleRound =
1458- op.getRoundingMode () == RoundingMode::DOUBLE_ROUND &&
1459- llvm::any_of (shiftValues, [](int32_t v) { return v > 31 ; });
14601597 RoundingMode roundingMode =
14611598 doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
14621599
@@ -1468,45 +1605,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14681605 // values in a buffer.
14691606 Value multiplierConstant;
14701607 int64_t multiplierArg = 0 ;
1471- if (multiplierValues.size () == 1 ) {
1472- multiplierConstant = arith::ConstantOp::create (
1473- rewriter, loc, rewriter.getI32IntegerAttr (multiplierValues.front ()));
1474- } else {
1475- SmallVector<AffineExpr, 2 > multiplierExprs{
1476- rewriter.getAffineDimExpr (rank - 1 )};
1477- auto multiplierType =
1478- RankedTensorType::get ({static_cast <int64_t >(multiplierValues.size ())},
1479- rewriter.getI32Type ());
1480- genericInputs.push_back (arith::ConstantOp::create (
1481- rewriter, loc,
1482- DenseIntElementsAttr::get (multiplierType, multiplierValues)));
1483-
1484- indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1485- /* symbolCount=*/ 0 , multiplierExprs,
1486- rewriter.getContext ()));
1487-
1488- multiplierArg = indexingMaps.size () - 1 ;
1489- }
1608+ setupLinalgGenericOpInputAndIndexingMap (
1609+ rewriter, multiplierValues, genericInputs, indexingMaps,
1610+ isMultiplierConstant, op, multiplierConstant, multiplierArg);
14901611
14911612 // If we are rescaling per-channel then we need to store the shift
14921613 // values in a buffer.
14931614 Value shiftConstant;
14941615 int64_t shiftArg = 0 ;
1495- if (shiftValues.size () == 1 ) {
1496- shiftConstant = arith::ConstantOp::create (
1497- rewriter, loc, rewriter.getI8IntegerAttr (shiftValues.front ()));
1498- } else {
1499- SmallVector<AffineExpr, 2 > shiftExprs = {
1500- rewriter.getAffineDimExpr (rank - 1 )};
1501- auto shiftType =
1502- RankedTensorType::get ({static_cast <int64_t >(shiftValues.size ())},
1503- rewriter.getIntegerType (8 ));
1504- genericInputs.push_back (arith::ConstantOp::create (
1505- rewriter, loc, DenseIntElementsAttr::get (shiftType, shiftValues)));
1506- indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1507- /* symbolCount=*/ 0 , shiftExprs,
1508- rewriter.getContext ()));
1509- shiftArg = indexingMaps.size () - 1 ;
1616+ setupLinalgGenericOpInputAndIndexingMap (
1617+ rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1618+ shiftConstant, shiftArg, true );
1619+
1620+ // broadcastMap = affine_map<(d0, d1) -> ()>
1621+ // It would affect as broadcast for scalar values in linalg::GenericOp.
1622+ AffineMap broadcastMap = AffineMap::get (rank, 0 , {}, rewriter.getContext ());
1623+ FailureOr<int64_t > maybeIZp = op.getInputZeroPoint ();
1624+ FailureOr<int64_t > maybeOZp = op.getOutputZeroPoint ();
1625+ // The inputZp and outputZp may be either constant or non-constant,
1626+ // depending on whether dynamic extension is enabled.
1627+ // - If the zp's are non-constant, add them as an inputs to
1628+ // linalg::GenericOp by:
1629+ // 1. Pushing it into 'genericInputs'.
1630+ // 2. Appending a corresponding affine map to 'indexingMaps'.
1631+ // - If the zp's are constant, they would be generated as arith.constant.
1632+ int64_t iZpArg = 0 ;
1633+ if (failed (maybeIZp)) {
1634+ genericInputs.push_back (
1635+ collapse1xNTensorToN (rewriter, op->getOperand (3 ), loc));
1636+ indexingMaps.push_back (broadcastMap);
1637+ iZpArg = indexingMaps.size () - 1 ;
1638+ }
1639+ int64_t oZpArg = 0 ;
1640+ if (failed (maybeOZp)) {
1641+ genericInputs.push_back (
1642+ collapse1xNTensorToN (rewriter, op->getOperand (4 ), loc));
1643+ indexingMaps.push_back (broadcastMap);
1644+ oZpArg = indexingMaps.size () - 1 ;
15101645 }
15111646
15121647 // Indexing maps for output values.
@@ -1526,36 +1661,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
15261661 Type valueTy = value.getType ();
15271662
15281663 FailureOr<int64_t > maybeIZp = op.getInputZeroPoint ();
1529- if (failed (maybeIZp)) {
1530- (void )rewriter.notifyMatchFailure (
1531- op, " input zero point cannot be statically determined" );
1532- return ;
1533- }
1534-
1535- const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth ();
1536- // Extend zeropoint for sub-32bits widths.
1537- const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32 ;
1538- auto inputZp = arith::ConstantOp::create (
1539- nestedBuilder, loc,
1540- IntegerAttr::get (rewriter.getIntegerType (inAttrBitwidth),
1541- *maybeIZp));
1664+ auto inputZp = getExtendZp (nestedBuilder, valueTy, maybeIZp,
1665+ nestedLoc, blockArgs, iZpArg);
15421666
15431667 FailureOr<int64_t > maybeOZp = op.getOutputZeroPoint ();
1544- if (failed (maybeOZp)) {
1545- (void )rewriter.notifyMatchFailure (
1546- op, " output zero point cannot be statically determined" );
1547- return ;
1548- };
1668+ auto outputZp = getExtendZp (nestedBuilder, valueTy, maybeOZp,
1669+ nestedLoc, blockArgs, oZpArg, true );
15491670
15501671 IntegerType outIntType =
15511672 cast<IntegerType>(blockArgs.back ().getType ());
15521673 unsigned outBitWidth = outIntType.getWidth ();
1553- const int32_t outAttrBitwidth = 32 ;
15541674 assert (outBitWidth <= 32 && " Unexpected output zeropoint bitwidth" );
1555- auto outputZp = arith::ConstantOp::create (
1556- nestedBuilder, loc,
1557- IntegerAttr::get (rewriter.getIntegerType (outAttrBitwidth),
1558- *maybeOZp));
15591675
15601676 Value multiplier = multiplierConstant ? multiplierConstant
15611677 : blockArgs[multiplierArg];
0 commit comments