Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
Original file line number Diff line number Diff line change
Expand Up @@ -222,15 +222,15 @@ profileComplianceMap = {
{fp32T, fp16T}}}}},
{"tosa.rescale",
{{{Profile::pro_int},
{{i8T, i8T},
{i8T, i16T},
{i8T, i32T},
{i16T, i8T},
{i16T, i16T},
{i16T, i32T},
{i32T, i8T},
{i32T, i16T},
{i32T, i32T}}}}},
{{i8T, i8T, i8T, i8T},
{i8T, i8T, i16T, i16T},
{i8T, i8T, i32T, i32T},
{i16T, i16T, i8T, i8T},
{i16T, i16T, i16T, i16T},
{i16T, i16T, i32T, i32T},
{i32T, i32T, i8T, i8T},
{i32T, i32T, i16T, i16T},
{i32T, i32T, i32T, i32T}}}}},
{"tosa.const",
{{{Profile::pro_int}, {{boolT}, {i8T}, {i16T}, {i32T}}},
{{Profile::pro_fp}, {{fp16T}, {fp32T}}}}},
Expand Down Expand Up @@ -390,7 +390,10 @@ extensionComplianceMap = {
{fp16T, fp8e5m2T},
{fp32T, fp8e5m2T}}}}},
{"tosa.rescale",
{{{Extension::int16}, {{i48T, i8T}, {i48T, i16T}, {i48T, i32T}}}}},
{{{Extension::int16},
{{i48T, i48T, i8T, i8T},
{i48T, i48T, i16T, i16T},
{i48T, i48T, i32T, i32T}}}}},
{"tosa.const",
{{{Extension::int4}, {{i4T}}},
{{Extension::int16}, {{i48T}}},
Expand Down
11 changes: 9 additions & 2 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2357,8 +2357,8 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
Tosa_Tensor:$input,
Tosa_1DInt16Or32Tensor:$multiplier,
Tosa_1DInt8Tensor:$shift,
I32Attr:$input_zp,
I32Attr:$output_zp,
Tosa_ScalarIntOrFloatTensor:$input_zp,
Tosa_ScalarIntOrFloatTensor:$output_zp,
BoolAttr:$scale32,
Tosa_RoundingTypeAttr:$rounding_mode,
BoolAttr:$per_channel,
Expand All @@ -2375,6 +2375,13 @@ def Tosa_RescaleOp : Tosa_InferShapedTypeOp<"rescale"> {
Extension<[Tosa_EXT_INT16]>,
];

let extraClassDeclaration = [{
FailureOr<int64_t> getInputZeroPoint();
FailureOr<int64_t> getOutputZeroPoint();
LogicalResult verifyInputZeroPoint(int64_t zp);
LogicalResult verifyOutputZeroPoint(int64_t zp);
}];

let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";
Expand Down
30 changes: 22 additions & 8 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,9 @@ materializeBinaryNanCheckIfRequired(OpTy op, PatternRewriter &rewriter,

/// Builds an `arith.constant` at `op`'s location whose value is an
/// IntegerAttr of type `requiredAttrType`.
///
/// The source integer is first converted with `static_cast<T>`, so it is
/// narrowed to `T` before the attribute is created.
///
/// NOTE(review): this span shows two signatures back to back -- one that reads
/// the integer from the attribute named `attrName` on `op`, and one that takes
/// the zero-point value `zp` directly. Both feed the same
/// `rewriter.create<arith::ConstantOp>` call at the end.
template <typename T>
static arith::ConstantOp
createConstFromIntAttribute(Operation *op, const std::string &attrName,
Type requiredAttrType, OpBuilder &rewriter) {
auto castedN = static_cast<T>(
cast<IntegerAttr>(op->getAttr(attrName)).getValue().getSExtValue());
createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType,
OpBuilder &rewriter) {
auto castedN = static_cast<T>(zp);
return rewriter.create<arith::ConstantOp>(
op->getLoc(), IntegerAttr::get(requiredAttrType, castedN));
}
Expand Down Expand Up @@ -1510,11 +1509,26 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// later.
int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32;

auto inputZp = createConstFromIntAttribute<int32_t>(
op, "input_zp", nestedBuilder.getIntegerType(inBitwidth),
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp)) {
(void)rewriter.notifyMatchFailure(
op, "input zero point cannot be statically determined");
return;
}

auto inputZp = createConstOpFromZpVal<int32_t>(
op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth),
nestedBuilder);
auto outputZp = createConstFromIntAttribute<int32_t>(
op, "output_zp", nestedBuilder.getI32Type(), nestedBuilder);

FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
if (failed(maybeOZp)) {
(void)rewriter.notifyMatchFailure(
op, "output zero point cannot be statically determined");
return;
};

auto outputZp = createConstOpFromZpVal<int32_t>(
op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder);

Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
Expand Down
97 changes: 62 additions & 35 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,25 @@ static Type getStorageElementTypeOrSelf(Type type) {
return elementType;
}

/// Checks that a rescale value and its zero-point operand have matching
/// element types.
///
/// Both `val` and `valZp` are mapped through getStorageElementTypeOrSelf; the
/// check succeeds only when both resulting types are integer types with the
/// same bit width.
///
/// \param op    operation on which the diagnostic is emitted on failure.
/// \param val   the value being rescaled.
/// \param valZp the zero-point value paired with `val`.
/// \param name  operand name used in the diagnostic (`<name>` and
///              `<name>_zp`).
/// \returns success() when the types agree, otherwise the emitted op error.
static LogicalResult verifyRescaleValueAndZpTypes(Operation *op, Value val,
                                                  Value valZp, StringRef name) {
  Type eType = getStorageElementTypeOrSelf(val.getType());
  Type eZpType = getStorageElementTypeOrSelf(valZp.getType());

  const bool bothInts =
      mlir::isa<IntegerType>(eType) && mlir::isa<IntegerType>(eZpType);

  // Only compare bit widths once both types are known to be integers:
  // getIntOrFloatBitWidth() asserts on types that are neither integer nor
  // float, so it must not be reached for arbitrary element types.
  const bool sameIntWidth =
      bothInts &&
      (eType.getIntOrFloatBitWidth() == eZpType.getIntOrFloatBitWidth());

  if (!sameIntWidth) {
    return op->emitOpError()
           << "expected " << name << " and " << name
           << "_zp to both be integer of the same bitwidth, but got " << eType
           << " vs. " << eZpType;
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TOSA Operator Verifiers.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1729,6 +1748,33 @@ static LogicalResult verifyZeroPoint(T op, Value val, const int64_t &zp,
return success();
}

/// Validates a known zero-point value of a tosa.rescale operand.
///
/// `operand` selects which side is checked: "Input" uses the op's
/// input-unsigned flag and the name "input"; any other string uses the
/// output-unsigned flag and the name "output".
///
/// Accepted values, based on the element type of `zpVal`:
///   - 0 is always accepted;
///   - any value when the element type is an 8-bit integer;
///   - 32768 when the element type is a 16-bit integer and the selected side
///     is flagged unsigned.
/// Everything else produces an op error.
static LogicalResult verifyZeroPoint(tosa::RescaleOp op, Value zpVal,
                                     const int64_t &zp,
                                     const std::string &operand) {
  // Zero is valid for every element type; nothing further to check.
  if (zp == 0)
    return success();

  const bool forInput = (operand == "Input");
  const bool sideIsUnsigned =
      forInput ? op.getInputUnsigned() : op.getOutputUnsigned();
  StringRef side = forInput ? "input" : "output";

  Type zpType = getElementTypeOrSelf(zpVal);
  const bool unsignedI16 = zpType.isInteger(16) && sideIsUnsigned;

  // Non-zero zero points are only permitted for i8 and unsigned i16.
  if (!zpType.isInteger(8) && !unsignedI16)
    return op.emitOpError() << "expect " << side << "_zp of 0, got " << zp;

  // For unsigned i16 the only non-zero value allowed is 32768.
  if (unsignedI16 && zp != 32768)
    return op.emitOpError()
           << "expect " << side << "_zp of 0 or 32768 for unsigned int16 "
           << side << ", got " << zp;

  return success();
}

#define ZERO_POINT_HELPER(OP, OPERAND_NAME) \
FailureOr<int64_t> tosa::OP::get##OPERAND_NAME##ZeroPoint() { \
return getZeroPoint(*this, get##OPERAND_NAME##Zp()); \
Expand All @@ -1751,7 +1797,8 @@ ZERO_POINT_HELPER(MatMulOp, A)
ZERO_POINT_HELPER(MatMulOp, B)
ZERO_POINT_HELPER(NegateOp, Input1)
ZERO_POINT_HELPER(NegateOp, Output)

ZERO_POINT_HELPER(RescaleOp, Input)
ZERO_POINT_HELPER(RescaleOp, Output)
#undef ZERO_POINT_HELPER

LogicalResult tosa::TransposeOp::inferReturnTypeComponents(
Expand Down Expand Up @@ -2784,41 +2831,21 @@ LogicalResult RescaleOp::verify() {
return failure();
}

auto input_zp = getInputZpAttr().getInt();
if (input_zp != 0) {
// only int8/uint8 and uint16 input can have non-zero input_zp
if (!inputElementType.isInteger(8) &&
!(inputElementType.isInteger(16) && getInputUnsigned())) {
emitOpError("expect input_zp of 0, got ") << input_zp;
return failure();
}
// input_zp must be either 0 or 32768 for uint16 input
if (inputElementType.isInteger(16) && getInputUnsigned() &&
input_zp != 32768) {
emitOpError(
"expect input_zp of 0 or 32768 for unsigned int16 input, got ")
<< input_zp;
return failure();
}
}
if (verifyRescaleValueAndZpTypes(*this, getInput(), getInputZp(), "input")
.failed())
return failure();

auto output_zp = getOutputZpAttr().getInt();
if (output_zp != 0) {
// only int8/uint8 and uint16 output can have non-zero output_zp
if (!outputElementType.isInteger(8) &&
!(outputElementType.isInteger(16) && getOutputUnsigned())) {
emitOpError("expect output_zp of 0, got ") << output_zp;
return failure();
}
// output_zp must be either 0 or 32768 for uint16 output
if (outputElementType.isInteger(16) && getOutputUnsigned() &&
output_zp != 32768) {
emitOpError(
"expect output_zp of 0 or 32768 for unsigned int16 output, got ")
<< output_zp;
return failure();
}
}
if (verifyRescaleValueAndZpTypes(*this, getOutput(), getOutputZp(), "output")
.failed())
return failure();

FailureOr<int64_t> maybeIZp = getInputZeroPoint();
if (succeeded(maybeIZp) && verifyInputZeroPoint(*maybeIZp).failed())
return failure();

FailureOr<int64_t> maybeOZp = getOutputZeroPoint();
if (succeeded(maybeOZp) && verifyOutputZeroPoint(*maybeOZp).failed())
return failure();

auto multiplierType = llvm::dyn_cast<ShapedType>(getMultiplier().getType());
if (!multiplierType) {
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Tosa/Transforms/TosaProfileCompliance.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ void ProfileInfoDepot::populateProfileInfo(tosa::SelectOp op) {
// Specialization for tosa.rescale: records the op's input, input zero point,
// output zero point and output through addValue, in that order.
// NOTE(review): the order presumably has to line up with the four-entry type
// tuples listed for "tosa.rescale" in the compliance tables -- confirm before
// reordering these calls.
template <>
void ProfileInfoDepot::populateProfileInfo(tosa::RescaleOp op) {
addValue(op.getInput());
// Zero points are operands of the op, so they are recorded as well.
addValue(op.getInputZp());
addValue(op.getOutputZp());
addValue(op.getOutput());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,10 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %
func.func @rescale_unsupported_type(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
%multiplier = "tosa.const"() {values = dense<1073741824> : tensor<1xi32> } : () -> tensor<1xi32>
%shift = "tosa.const"() {values = dense<30> : tensor<1xi8> } : () -> tensor<1xi8>
%input_zp = "tosa.const"() {values = dense<127> : tensor<1xi8>} : () -> tensor<1xi8>
%output_zp = "tosa.const"() {values = dense<-1> : tensor<1xi8>} : () -> tensor<1xi8>
// expected-error@+1 {{failed to legalize operation 'tosa.rescale'}}
%0 = tosa.rescale %arg0, %multiplier, %shift {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
%0 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {rounding_mode = "SINGLE_ROUND", input_zp = 127 : i32, output_zp = -1 : i32, per_channel = false, scale32 = true, input_unsigned = true, output_unsigned = false} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
}

Expand Down
Loading