Skip to content
288 changes: 202 additions & 86 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1392,6 +1392,137 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
}
};

// Collapse tensor<1xiN> into tensor<iN>
// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
// Collapse a rank-1 tensor<1xiN> into a rank-0 tensor<iN>.
// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
                                  Location loc) {
  // An empty reassociation list collapses all (unit) dimensions away.
  SmallVector<ReassociationExprs, 1> reassociation;
  // The result is a rank-0 tensor with the same element type.
  auto inputType = cast<RankedTensorType>(input.getType());
  auto collapsedType = RankedTensorType::get({}, inputType.getElementType());
  // Use the OpTy::create form for consistency with the rest of this file.
  return tensor::CollapseShapeOp::create(rewriter, loc, collapsedType, input,
                                         reassociation);
}

// Narrow a vector of i32 values to i8 (truncating cast per element).
static llvm::SmallVector<int8_t>
convertToI8(const llvm::SmallVector<int32_t> &input) {
  llvm::SmallVector<int8_t> result;
  result.reserve(input.size());
  for (int32_t value : input)
    result.push_back(static_cast<int8_t>(value));
  return result;
}

// The shift or multiplier may be either constant or non-constant, depending on
// whether dynamic extension is enabled.
// - If the shift or multiplier is non-constant, add it as an input to
// linalg::GenericOp by:
// 1. Pushing it into 'genericInputs'.
// 2. Appending a corresponding affine map to 'indexingMaps'.
// - If the shift or multiplier is constant, set 'constant' instead.
// The shift or multiplier may be either constant or non-constant, depending on
// whether dynamic extension is enabled.
// - If the shift or multiplier is non-constant, add it as an input to
//   linalg::GenericOp by:
//   1. Pushing it into 'genericInputs'.
//   2. Appending a corresponding affine map to 'indexingMaps'.
// - If the shift or multiplier is constant, set 'constant' instead.
// 'arg' receives the index of the generic-op input carrying the value
// (meaningful only when a new input/indexing map was appended).
static void setupLinalgGenericOpInputAndIndexingMap(
    PatternRewriter &rewriter, llvm::SmallVector<int32_t> &values,
    SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
    bool isConstant, tosa::RescaleOp op, Value &constant, int64_t &arg,
    bool isShift = false) {

  auto loc = op.getLoc();
  auto inputTy = cast<ShapedType>(op.getInput().getType());
  unsigned rank = inputTy.getRank();
  // Per-channel values index along the innermost (channel) dimension.
  SmallVector<AffineExpr, 2> exprs = {rewriter.getAffineDimExpr(rank - 1)};

  if (isConstant) {
    if (values.size() == 1) {
      // Broadcast case: a single scalar constant suffices.
      // Shift values are i8, multiplier values are i32.
      IntegerAttr intAttr = isShift
                                ? rewriter.getI8IntegerAttr(values.front())
                                : rewriter.getI32IntegerAttr(values.front());
      constant = arith::ConstantOp::create(rewriter, loc, intAttr);
    } else {
      // If we are rescaling per-channel then we need to store the
      // values in a buffer and feed it to the generic op as an input.
      auto elementType =
          isShift ? rewriter.getIntegerType(8) : rewriter.getI32Type();
      auto tensorType = RankedTensorType::get(
          {static_cast<int64_t>(values.size())}, elementType);
      DenseIntElementsAttr eltAttr;
      if (isShift)
        // Shift values live as i8; narrow the i32 staging values explicitly.
        eltAttr = DenseIntElementsAttr::get(tensorType, convertToI8(values));
      else
        eltAttr = DenseIntElementsAttr::get(tensorType, values);
      genericInputs.push_back(
          arith::ConstantOp::create(rewriter, loc, eltAttr));
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  } else {
    // Non-constant operand (dynamic extension enabled).
    auto operand = isShift ? op.getShift() : op.getMultiplier();
    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
    if (tensorType && tensorType.hasStaticShape() &&
        tensorType.getShape()[0] == 1) {
      // If we are not rescaling per-channel then we need to collapse 1xN to N
      // and push broadcastMap.
      // broadcastMap = affine_map<(d0, d1) -> ()>
      // It would affect as broadcast for scalar values in linalg::GenericOp.
      AffineMap broadcastMap =
          AffineMap::get(rank, 0, {}, rewriter.getContext());
      genericInputs.push_back(collapse1xNTensorToN(rewriter, operand, loc));
      indexingMaps.push_back(broadcastMap);
    } else {
      // Per-channel: pass the operand through, indexed by the innermost dim.
      genericInputs.push_back(operand);
      indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
                                            /*symbolCount=*/0, exprs,
                                            rewriter.getContext()));
    }
  }
  arg = indexingMaps.size() - 1;
}

// Return the extended Zp to be used in subsequent arithmetic operations.
// Return the extended Zp to be used in subsequent arithmetic operations.
// The Zp value can be either constant or non-constant, depending on
// whether dynamic extension is enabled:
// - If 'maybeZp' succeeded, the zero point is a compile-time constant and is
//   materialized directly at the extended width.
// - If 'maybeZp' failed, the zero point is non-constant and is read from the
//   linalg::GenericOp block argument at index 'zpArg'.
static Value getExtendZp(OpBuilder &builder, Type valueTy,
                         FailureOr<int64_t> maybeZp, Location loc,
                         ValueRange blockArgs, int64_t zpArg,
                         bool isOutputZp = false) {
  const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
  // Extend sub-32-bit zero points to 32 bits; the output zp is always 32-bit.
  const uint32_t attrBitwidth =
      isOutputZp ? 32 : (bitwidth > 32 ? bitwidth : 32);
  auto extendType = builder.getIntegerType(attrBitwidth);

  // Constant zero point: materialize it at the extended width.
  if (!failed(maybeZp))
    return arith::ConstantOp::create(
        builder, loc, IntegerAttr::get(extendType, *maybeZp));

  // Non-constant zero point: take it from the generic op's block argument.
  Value result = blockArgs[zpArg];
  auto zpTy = result.getType();
  // NOTE(review): when the block-arg width already matches or exceeds
  // attrBitwidth the value is returned unchanged — presumably the widths are
  // guaranteed equal by the TOSA verifier in that case.
  if (zpTy.getIntOrFloatBitWidth() >= attrBitwidth)
    return result;

  if (zpTy.isUnsignedInteger()) {
    // For ExtUIOp, the input must be signless.
    // UnrealizedConversionCastOp will cast the input to signless type.
    result = UnrealizedConversionCastOp::create(
                 builder, loc,
                 builder.getIntegerType(zpTy.getIntOrFloatBitWidth()), result)
                 .getResult(0);
    return arith::ExtUIOp::create(builder, loc, extendType, result);
  }
  return arith::ExtSIOp::create(builder, loc, extendType, result);
}

class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
Expand Down Expand Up @@ -1423,40 +1554,46 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
}
}

// The shift and multiplier values.
DenseElementsAttr shiftElems;
if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires constant shift input values");
bool isShiftConstant = false;
if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
isShiftConstant = true;

DenseElementsAttr multiplierElems;
if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires constant multiplier input values");

llvm::SmallVector<int8_t> shiftValues =
llvm::to_vector(shiftElems.getValues<int8_t>());
// explicit cast is required here
llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
[](IntegerAttr attr) -> int32_t {
return static_cast<int32_t>(attr.getInt());
}));

// If we shift by more than the bitwidth, this just sets to 0.
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
if (shiftValues[i] > 63) {
shiftValues[i] = 0;
multiplierValues[i] = 0;
bool isMultiplierConstant = false;
if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
isMultiplierConstant = true;

llvm::SmallVector<int32_t> shiftValues;
llvm::SmallVector<int32_t> multiplierValues;
bool doubleRound;

if (isMultiplierConstant && isShiftConstant) {
// explicit cast is required here
shiftValues = llvm::to_vector(llvm::map_range(
shiftElems.getValues<IntegerAttr>(), [](IntegerAttr attr) -> int32_t {
return static_cast<int32_t>(attr.getInt());
}));
multiplierValues = llvm::to_vector(
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
[](IntegerAttr attr) -> int32_t {
return static_cast<int32_t>(attr.getInt());
}));

// If we shift by more than the bitwidth, this just sets to 0.
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
if (shiftValues[i] > 63) {
shiftValues[i] = 0;
multiplierValues[i] = 0;
}
}
}
// Double round only occurs if shift is greater than 31, check that this
// is ever true.
doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
} else
doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;

// Double round only occurs if shift is greater than 31, check that this
// is ever true.

bool doubleRound =
op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
RoundingMode roundingMode =
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;

Expand All @@ -1468,45 +1605,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// values in a buffer.
Value multiplierConstant;
int64_t multiplierArg = 0;
if (multiplierValues.size() == 1) {
multiplierConstant = arith::ConstantOp::create(
rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
} else {
SmallVector<AffineExpr, 2> multiplierExprs{
rewriter.getAffineDimExpr(rank - 1)};
auto multiplierType =
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
rewriter.getI32Type());
genericInputs.push_back(arith::ConstantOp::create(
rewriter, loc,
DenseIntElementsAttr::get(multiplierType, multiplierValues)));

indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, multiplierExprs,
rewriter.getContext()));

multiplierArg = indexingMaps.size() - 1;
}
setupLinalgGenericOpInputAndIndexingMap(
rewriter, multiplierValues, genericInputs, indexingMaps,
isMultiplierConstant, op, multiplierConstant, multiplierArg);

// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
Value shiftConstant;
int64_t shiftArg = 0;
if (shiftValues.size() == 1) {
shiftConstant = arith::ConstantOp::create(
rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
} else {
SmallVector<AffineExpr, 2> shiftExprs = {
rewriter.getAffineDimExpr(rank - 1)};
auto shiftType =
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
rewriter.getIntegerType(8));
genericInputs.push_back(arith::ConstantOp::create(
rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, shiftExprs,
rewriter.getContext()));
shiftArg = indexingMaps.size() - 1;
setupLinalgGenericOpInputAndIndexingMap(
rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
shiftConstant, shiftArg, true);

// broadcastMap = affine_map<(d0, d1) -> ()>
// It would affect as broadcast for scalar values in linalg::GenericOp.
AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
// The inputZp and outputZp may be either constant or non-constant,
// depending on whether dynamic extension is enabled.
// - If the zp's are non-constant, add them as an inputs to
// linalg::GenericOp by:
// 1. Pushing it into 'genericInputs'.
// 2. Appending a corresponding affine map to 'indexingMaps'.
// - If the zp's are constant, they would be generated as arith.constant.
int64_t iZpArg = 0;
if (failed(maybeIZp)) {
genericInputs.push_back(
collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
indexingMaps.push_back(broadcastMap);
iZpArg = indexingMaps.size() - 1;
}
int64_t oZpArg = 0;
if (failed(maybeOZp)) {
genericInputs.push_back(
collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
indexingMaps.push_back(broadcastMap);
oZpArg = indexingMaps.size() - 1;
}

// Indexing maps for output values.
Expand All @@ -1526,36 +1661,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Type valueTy = value.getType();

FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp)) {
(void)rewriter.notifyMatchFailure(
op, "input zero point cannot be statically determined");
return;
}

const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
// Extend zeropoint for sub-32bits widths.
const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
auto inputZp = arith::ConstantOp::create(
nestedBuilder, loc,
IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
*maybeIZp));
auto inputZp = getExtendZp(nestedBuilder, valueTy, maybeIZp,
nestedLoc, blockArgs, iZpArg);

FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
if (failed(maybeOZp)) {
(void)rewriter.notifyMatchFailure(
op, "output zero point cannot be statically determined");
return;
};
auto outputZp = getExtendZp(nestedBuilder, valueTy, maybeOZp,
nestedLoc, blockArgs, oZpArg, true);

IntegerType outIntType =
cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
auto outputZp = arith::ConstantOp::create(
nestedBuilder, loc,
IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
*maybeOZp));

Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
Expand Down
Loading