@@ -1406,120 +1406,81 @@ static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
14061406 reassociation);
14071407}
14081408
1409- // The multiplier may be either constant or non-constant, depending on
1410- // whether dynamic extension is enabled.
1411- // - If the multiplier is non-constant, add it as an input to linalg::GenericOp
1412- // by:
1413- // 1. Pushing it into 'genericInputs'.
1414- // 2. Appending a corresponding affine map to 'indexingMaps'.
1415- // - If the multiplier is constant, set 'multiplierConstant' instead.
1416- static void setupLinalgGenericOpInputAndIndexingMapForMultiplier (
1417- PatternRewriter &rewriter, llvm::SmallVector<int32_t > &multiplierValues,
1418- SmallVector<Value, 4 > &genericInputs, SmallVector<AffineMap> &indexingMaps,
1419- bool isConstant, tosa::RescaleOp op, Value &multiplierConstant,
1420- int64_t &multiplierArg) {
1421-
1422- auto loc = op.getLoc ();
1423- auto inputTy = cast<ShapedType>(op.getInput ().getType ());
1424- unsigned rank = inputTy.getRank ();
1425- SmallVector<AffineExpr, 2 > multiplierExprs{
1426- rewriter.getAffineDimExpr (rank - 1 )};
1427-
1428- if (isConstant) {
1429- // If we are rescaling per-channel then we need to store the multiplier
1430- // values in a buffer.
1431- if (multiplierValues.size () == 1 ) {
1432- multiplierConstant = rewriter.create <arith::ConstantOp>(
1433- loc, rewriter.getI32IntegerAttr (multiplierValues.front ()));
1434- } else {
1435- auto multiplierType =
1436- RankedTensorType::get ({static_cast <int64_t >(multiplierValues.size ())},
1437- rewriter.getI32Type ());
1438- genericInputs.push_back (arith::ConstantOp::create (
1439- rewriter, loc,
1440- DenseIntElementsAttr::get (multiplierType, multiplierValues)));
1409+ static llvm::SmallVector<int8_t >
1410+ convertToI8 (const llvm::SmallVector<int32_t > &input) {
1411+ llvm::SmallVector<int8_t > output;
1412+ output.reserve (input.size ());
14411413
1442- indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1443- /* symbolCount=*/ 0 , multiplierExprs,
1444- rewriter.getContext ()));
1445- }
1446- } else {
1447- // If we are not rescaling per-channel then we need to collapse 1xN to N
1448- // and push broadcastMap.
1449- auto tensorType = dyn_cast<RankedTensorType>(op.getMultiplier ().getType ());
1450- if (tensorType && tensorType.hasStaticShape () &&
1451- tensorType.getShape ()[0 ] == 1 ) {
1452- // broadcastMap = affine_map<(d0, d1) -> ()>
1453- // It would affect as broadcast for scalar values in linalg::GenericOp.
1454- AffineMap broadcastMap =
1455- AffineMap::get (rank, 0 , {}, rewriter.getContext ());
1456- genericInputs.push_back (
1457- collapse1xNTensorToN (rewriter, op.getMultiplier (), loc));
1458- indexingMaps.push_back (broadcastMap);
1459- } else {
1460- genericInputs.push_back (op.getMultiplier ());
1461- indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1462- /* symbolCount=*/ 0 , multiplierExprs,
1463- rewriter.getContext ()));
1464- }
1414+ for (auto v : llvm::map_range (
1415+ input, [](int32_t val) { return static_cast <int8_t >(val); })) {
1416+ output.push_back (v);
14651417 }
1466- multiplierArg = indexingMaps. size () - 1 ;
1418+ return output ;
14671419}
14681420
1469- // The shift may be either constant or non-constant, depending on
1421+ // The shift or multiplier may be either constant or non-constant, depending on
14701422// whether dynamic extension is enabled.
1471- // - If the shift is non-constant, add it as an input to linalg::GenericOp by:
1423+ // - If the shift or multiplier is non-constant, add it as an input to
1424+ // linalg::GenericOp by:
14721425// 1. Pushing it into 'genericInputs'.
14731426// 2. Appending a corresponding affine map to 'indexingMaps'.
1474- // - If the shift is constant, set 'shiftConstant ' instead.
1475- static void setupLinalgGenericOpInputAndIndexingMapForShift (
1476- PatternRewriter &rewriter, llvm::SmallVector<int8_t > &shiftValues ,
1427+ // - If the shift or multiplier is constant, set 'constant ' instead.
1428+ static void setupLinalgGenericOpInputAndIndexingMap (
1429+ PatternRewriter &rewriter, llvm::SmallVector<int32_t > &values ,
14771430 SmallVector<Value, 4 > &genericInputs, SmallVector<AffineMap> &indexingMaps,
1478- bool isConstant, tosa::RescaleOp op, Value &shiftConstant ,
1479- int64_t &shiftArg ) {
1431+ bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg ,
1432+ bool isShift = false ) {
14801433
14811434 auto loc = op.getLoc ();
14821435 auto inputTy = cast<ShapedType>(op.getInput ().getType ());
14831436 unsigned rank = inputTy.getRank ();
1484- SmallVector<AffineExpr, 2 > shiftExprs = {rewriter.getAffineDimExpr (rank - 1 )};
1437+ SmallVector<AffineExpr, 2 > exprs = {rewriter.getAffineDimExpr (rank - 1 )};
14851438
14861439 if (isConstant) {
1487- // If we are rescaling per-channel then we need to store the shift
1440+ // If we are rescaling per-channel then we need to store the
14881441 // values in a buffer.
1489- if (shiftValues.size () == 1 ) {
1490- shiftConstant = rewriter.create <arith::ConstantOp>(
1491- loc, rewriter.getI8IntegerAttr (shiftValues.front ()));
1442+ if (values.size () == 1 ) {
1443+ IntegerAttr intAttr = isShift
1444+ ? rewriter.getI8IntegerAttr (values.front ())
1445+ : rewriter.getI32IntegerAttr (values.front ());
1446+ constant = rewriter.create <arith::ConstantOp>(loc, intAttr);
14921447 } else {
1493- auto shiftType =
1494- RankedTensorType::get ({static_cast <int64_t >(shiftValues.size ())},
1495- rewriter.getIntegerType (8 ));
1496- genericInputs.push_back (arith::ConstantOp::create (
1497- rewriter, loc, DenseIntElementsAttr::get (shiftType, shiftValues)));
1448+ auto elementType =
1449+ isShift ? rewriter.getIntegerType (8 ) : rewriter.getI32Type ();
1450+ auto tensorType = RankedTensorType::get (
1451+ {static_cast <int64_t >(values.size ())}, elementType);
1452+ DenseIntElementsAttr EltAttr;
1453+ if (isShift)
1454+ EltAttr = DenseIntElementsAttr::get (tensorType, convertToI8 (values));
1455+ else
1456+ EltAttr = DenseIntElementsAttr::get (tensorType, values);
1457+ genericInputs.push_back (
1458+ arith::ConstantOp::create (rewriter, loc, EltAttr));
14981459 indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1499- /* symbolCount=*/ 0 , shiftExprs ,
1460+ /* symbolCount=*/ 0 , exprs ,
15001461 rewriter.getContext ()));
15011462 }
15021463 } else {
15031464 // If we are not rescaling per-channel then we need to collapse 1xN to N
15041465 // and push broadcastMap.
1505- auto tensorType = dyn_cast<RankedTensorType>(op.getShift ().getType ());
1466+ auto operand = isShift ? op.getShift () : op.getMultiplier ();
1467+ auto tensorType = dyn_cast<RankedTensorType>(operand.getType ());
15061468 if (tensorType && tensorType.hasStaticShape () &&
15071469 tensorType.getShape ()[0 ] == 1 ) {
15081470 // broadcastMap = affine_map<(d0, d1) -> ()>
15091471 // It would affect as broadcast for scalar values in linalg::GenericOp.
15101472 AffineMap broadcastMap =
15111473 AffineMap::get (rank, 0 , {}, rewriter.getContext ());
1512- genericInputs.push_back (
1513- collapse1xNTensorToN (rewriter, op.getShift (), loc));
1474+ genericInputs.push_back (collapse1xNTensorToN (rewriter, operand, loc));
15141475 indexingMaps.push_back (broadcastMap);
15151476 } else {
1516- genericInputs.push_back (op. getShift () );
1477+ genericInputs.push_back (operand );
15171478 indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1518- /* symbolCount=*/ 0 , shiftExprs ,
1479+ /* symbolCount=*/ 0 , exprs ,
15191480 rewriter.getContext ()));
15201481 }
15211482 }
1522- shiftArg = indexingMaps.size () - 1 ;
1483+ arg = indexingMaps.size () - 1 ;
15231484}
15241485
15251486// Return the extended Zp to be used in subsequent arithmetic operations.
@@ -1626,13 +1587,16 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
16261587 if (matchPattern (op.getMultiplier (), m_Constant (&multiplierElems)))
16271588 isMultiplierConstant = true ;
16281589
1629- llvm::SmallVector<int8_t > shiftValues;
1590+ llvm::SmallVector<int32_t > shiftValues;
16301591 llvm::SmallVector<int32_t > multiplierValues;
16311592 bool doubleRound;
16321593
16331594 if (isMultiplierConstant && isShiftConstant) {
1634- shiftValues = llvm::to_vector (shiftElems.getValues <int8_t >());
16351595 // explicit cast is required here
1596+ shiftValues = llvm::to_vector (llvm::map_range (
1597+ shiftElems.getValues <IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
1598+ return static_cast <int32_t >(attr.getInt ());
1599+ }));
16361600 multiplierValues = llvm::to_vector (
16371601 llvm::map_range (multiplierElems.getValues <IntegerAttr>(),
16381602 [](IntegerAttr attr) -> int32_t {
@@ -1664,17 +1628,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
16641628 // values in a buffer.
16651629 Value multiplierConstant;
16661630 int64_t multiplierArg = 0 ;
1667- setupLinalgGenericOpInputAndIndexingMapForMultiplier (
1631+ setupLinalgGenericOpInputAndIndexingMap (
16681632 rewriter, multiplierValues, genericInputs, indexingMaps,
16691633 isMultiplierConstant, op, multiplierConstant, multiplierArg);
16701634
16711635 // If we are rescaling per-channel then we need to store the shift
16721636 // values in a buffer.
16731637 Value shiftConstant;
16741638 int64_t shiftArg = 0 ;
1675- setupLinalgGenericOpInputAndIndexingMapForShift (
1639+ setupLinalgGenericOpInputAndIndexingMap (
16761640 rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1677- shiftConstant, shiftArg);
1641+ shiftConstant, shiftArg, true );
16781642
16791643 // broadcastMap = affine_map<(d0, d1) -> ()>
16801644 // It would affect as broadcast for scalar values in linalg::GenericOp.
0 commit comments