Skip to content

Commit 8f72666

Browse files
committed
introduce setupLinalgGenericOpInputAndIndexingMap for both shift and multiplier
1 parent f792a33 commit 8f72666

File tree

1 file changed

+49
-85
lines changed

1 file changed

+49
-85
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 49 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -1406,120 +1406,81 @@ static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
14061406
reassociation);
14071407
}
14081408

1409-
// The multiplier may be either constant or non-constant, depending on
1410-
// whether dynamic extension is enabled.
1411-
// - If the multiplier is non-constant, add it as an input to linalg::GenericOp
1412-
// by:
1413-
// 1. Pushing it into 'genericInputs'.
1414-
// 2. Appending a corresponding affine map to 'indexingMaps'.
1415-
// - If the multiplier is constant, set 'multiplierConstant' instead.
1416-
static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
1417-
PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
1418-
SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
1419-
bool isConstant, tosa::RescaleOp op, Value &multiplierConstant,
1420-
int64_t &multiplierArg) {
1421-
1422-
auto loc = op.getLoc();
1423-
auto inputTy = cast<ShapedType>(op.getInput().getType());
1424-
unsigned rank = inputTy.getRank();
1425-
SmallVector<AffineExpr, 2> multiplierExprs{
1426-
rewriter.getAffineDimExpr(rank - 1)};
1427-
1428-
if (isConstant) {
1429-
// If we are rescaling per-channel then we need to store the multiplier
1430-
// values in a buffer.
1431-
if (multiplierValues.size() == 1) {
1432-
multiplierConstant = rewriter.create<arith::ConstantOp>(
1433-
loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1434-
} else {
1435-
auto multiplierType =
1436-
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1437-
rewriter.getI32Type());
1438-
genericInputs.push_back(arith::ConstantOp::create(
1439-
rewriter, loc,
1440-
DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1409+
static llvm::SmallVector<int8_t>
1410+
convertToI8(const llvm::SmallVector<int32_t> &input) {
1411+
llvm::SmallVector<int8_t> output;
1412+
output.reserve(input.size());
14411413

1442-
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1443-
/*symbolCount=*/0, multiplierExprs,
1444-
rewriter.getContext()));
1445-
}
1446-
} else {
1447-
// If we are not rescaling per-channel then we need to collapse 1xN to N
1448-
// and push broadcastMap.
1449-
auto tensorType = dyn_cast<RankedTensorType>(op.getMultiplier().getType());
1450-
if (tensorType && tensorType.hasStaticShape() &&
1451-
tensorType.getShape()[0] == 1) {
1452-
// broadcastMap = affine_map<(d0, d1) -> ()>
1453-
// It would affect as broadcast for scalar values in linalg::GenericOp.
1454-
AffineMap broadcastMap =
1455-
AffineMap::get(rank, 0, {}, rewriter.getContext());
1456-
genericInputs.push_back(
1457-
collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
1458-
indexingMaps.push_back(broadcastMap);
1459-
} else {
1460-
genericInputs.push_back(op.getMultiplier());
1461-
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1462-
/*symbolCount=*/0, multiplierExprs,
1463-
rewriter.getContext()));
1464-
}
1414+
for (auto v : llvm::map_range(
1415+
input, [](int32_t val) { return static_cast<int8_t>(val); })) {
1416+
output.push_back(v);
14651417
}
1466-
multiplierArg = indexingMaps.size() - 1;
1418+
return output;
14671419
}
14681420

1469-
// The shift may be either constant or non-constant, depending on
1421+
// The shift or multiplier may be either constant or non-constant, depending on
14701422
// whether dynamic extension is enabled.
1471-
// - If the shift is non-constant, add it as an input to linalg::GenericOp by:
1423+
// - If the shift or multiplier is non-constant, add it as an input to
1424+
// linalg::GenericOp by:
14721425
// 1. Pushing it into 'genericInputs'.
14731426
// 2. Appending a corresponding affine map to 'indexingMaps'.
1474-
// - If the shift is constant, set 'shiftConstant' instead.
1475-
static void setupLinalgGenericOpInputAndIndexingMapForShift(
1476-
PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
1427+
// - If the shift or multiplier is constant, set 'constant' instead.
1428+
static void setupLinalgGenericOpInputAndIndexingMap(
1429+
PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
14771430
SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
1478-
bool isConstant, tosa::RescaleOp op, Value &shiftConstant,
1479-
int64_t &shiftArg) {
1431+
bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
1432+
bool isShift = false) {
14801433

14811434
auto loc = op.getLoc();
14821435
auto inputTy = cast<ShapedType>(op.getInput().getType());
14831436
unsigned rank = inputTy.getRank();
1484-
SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};
1437+
SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};
14851438

14861439
if (isConstant) {
1487-
// If we are rescaling per-channel then we need to store the shift
1440+
// If we are rescaling per-channel then we need to store the
14881441
// values in a buffer.
1489-
if (shiftValues.size() == 1) {
1490-
shiftConstant = rewriter.create<arith::ConstantOp>(
1491-
loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1442+
if (values.size() == 1) {
1443+
IntegerAttr intAttr = isShift
1444+
? rewriter.getI8IntegerAttr(values.front())
1445+
: rewriter.getI32IntegerAttr(values.front());
1446+
constant = rewriter.create<arith::ConstantOp>(loc, intAttr);
14921447
} else {
1493-
auto shiftType =
1494-
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1495-
rewriter.getIntegerType(8));
1496-
genericInputs.push_back(arith::ConstantOp::create(
1497-
rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1448+
auto elementType =
1449+
isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
1450+
auto tensorType = RankedTensorType::get(
1451+
{static_cast<int64_t>(values.size())}, elementType);
1452+
DenseIntElementsAttr EltAttr;
1453+
if (isShift)
1454+
EltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
1455+
else
1456+
EltAttr = DenseIntElementsAttr::get(tensorType, values);
1457+
genericInputs.push_back(
1458+
arith::ConstantOp::create(rewriter, loc, EltAttr));
14981459
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1499-
/*symbolCount=*/0, shiftExprs,
1460+
/*symbolCount=*/0, exprs,
15001461
rewriter.getContext()));
15011462
}
15021463
} else {
15031464
// If we are not rescaling per-channel then we need to collapse 1xN to N
15041465
// and push broadcastMap.
1505-
auto tensorType = dyn_cast<RankedTensorType>(op.getShift().getType());
1466+
auto operand = isShift ? op.getShift() : op.getMultiplier();
1467+
auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
15061468
if (tensorType && tensorType.hasStaticShape() &&
15071469
tensorType.getShape()[0] == 1) {
15081470
// broadcastMap = affine_map<(d0, d1) -> ()>
15091471
// It acts as a broadcast of the scalar value in linalg::GenericOp.
15101472
AffineMap broadcastMap =
15111473
AffineMap::get(rank, 0, {}, rewriter.getContext());
1512-
genericInputs.push_back(
1513-
collapse1xNTensorToN(rewriter, op.getShift(), loc));
1474+
genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
15141475
indexingMaps.push_back(broadcastMap);
15151476
} else {
1516-
genericInputs.push_back(op.getShift());
1477+
genericInputs.push_back(operand);
15171478
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1518-
/*symbolCount=*/0, shiftExprs,
1479+
/*symbolCount=*/0, exprs,
15191480
rewriter.getContext()));
15201481
}
15211482
}
1522-
shiftArg = indexingMaps.size() - 1;
1483+
arg = indexingMaps.size() - 1;
15231484
}
15241485

15251486
// Return the extended Zp to be used in subsequent arithmetic operations.
@@ -1626,13 +1587,16 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
16261587
if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
16271588
isMultiplierConstant = true;
16281589

1629-
llvm::SmallVector<int8_t> shiftValues;
1590+
llvm::SmallVector<int32_t> shiftValues;
16301591
llvm::SmallVector<int32_t> multiplierValues;
16311592
bool doubleRound;
16321593

16331594
if (isMultiplierConstant && isShiftConstant) {
1634-
shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
16351595
// An explicit cast to int32_t is required here, since the shift
// elements are stored as i8 but shiftValues now holds int32_t.
1596+
shiftValues = llvm::to_vector(llvm::map_range(
1597+
shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
1598+
return static_cast<int32_t>(attr.getInt());
1599+
}));
16361600
multiplierValues = llvm::to_vector(
16371601
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
16381602
[](IntegerAttr attr) -> int32_t {
@@ -1664,17 +1628,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
16641628
// values in a buffer.
16651629
Value multiplierConstant;
16661630
int64_t multiplierArg = 0;
1667-
setupLinalgGenericOpInputAndIndexingMapForMultiplier(
1631+
setupLinalgGenericOpInputAndIndexingMap(
16681632
rewriter, multiplierValues, genericInputs, indexingMaps,
16691633
isMultiplierConstant, op, multiplierConstant, multiplierArg);
16701634

16711635
// If we are rescaling per-channel then we need to store the shift
16721636
// values in a buffer.
16731637
Value shiftConstant;
16741638
int64_t shiftArg = 0;
1675-
setupLinalgGenericOpInputAndIndexingMapForShift(
1639+
setupLinalgGenericOpInputAndIndexingMap(
16761640
rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1677-
shiftConstant, shiftArg);
1641+
shiftConstant, shiftArg, true);
16781642

16791643
// broadcastMap = affine_map<(d0, d1) -> ()>
16801644
// It acts as a broadcast of the scalar value in linalg::GenericOp.

0 commit comments

Comments
 (0)