Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -264,4 +264,11 @@ class Tosa_InferShapedTypeOp<string mnemonic, list<Trait> traits = []>
"operands attr-dict `:` functional-type(operands, results)";
}

// The "SameVariadicOperandSize" trait allows us to pass optional arguments
// for multiple zero points in convolution ops.
class Tosa_ConvOp<string mnemonic, list<Trait> traits = []>
: Tosa_InferShapedTypeOp<mnemonic, !listconcat(traits,
[SameVariadicOperandSize])> {
}

#endif // TOSA_OP_BASE
118 changes: 118 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Quant/IR/QuantTypes.h"
#include "mlir/Dialect/Traits.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
Expand All @@ -29,6 +30,7 @@
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Tosa/IR/TosaOpsDialect.h.inc"
#include "mlir/Transforms/DialectConversion.h"

namespace mlir {
class PatternRewriter;
Expand Down Expand Up @@ -152,4 +154,120 @@ bool isa_tosa_shape_type(mlir::Type t);
#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"

namespace mlir {
namespace tosa {

// Create a rank-1 const tensor for zero point of the source tensor.
std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
Type srcElemType, int64_t zp = 0);

// Get zero point value from the attribute argument.
LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);

// Verify if zero point falls into valid range.
template <typename T>
LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
if constexpr (!std::is_same_v<T, Conv2DOp> && !std::is_same_v<T, Conv3DOp> &&
!std::is_same_v<T, DepthwiseConv2DOp> &&
!std::is_same_v<T, TransposeConv2DOp>) {
return failure();
}

if (!zpElemType.isIntOrFloat())
return failure();

if (!zpElemType.isInteger(8) && zp != 0)
return failure();

if (zpElemType.isSignedInteger(8) && (zp < -128 || zp > 127))
return failure();

if (zpElemType.isUnsignedInteger(8) && (zp < 0 || zp > 255))
return failure();

return success();
}

// Helper type trait to determine if an operation is a tosa convolution.
template <typename Op>
struct IsTosaConv : std::false_type {};

template <>
struct IsTosaConv<tosa::Conv2DOp> : std::true_type {};
template <>
struct IsTosaConv<tosa::DepthwiseConv2DOp> : std::true_type {};
template <>
struct IsTosaConv<tosa::TransposeConv2DOp> : std::true_type {};
template <>
struct IsTosaConv<tosa::Conv3DOp> : std::true_type {};

template <typename Op>
constexpr bool is_tosa_conv_v = IsTosaConv<Op>::value;

// Helper struct to hold the zero points of a TOSA convolution operation as
// named 64-bit integer fields.
struct ConvZpPair {
ConvZpPair(std::int64_t inputZp, std::int64_t weightZp)
: inputZp(inputZp), weightZp(weightZp) {}
std::int64_t inputZp;
std::int64_t weightZp;
};

// Helper function which attempts to extract the zero points from a TOSA
// convolution by matching them against defining ops which should be tosa.const
// operations.
//
// There are three possible results:
// 1. Failed to extract the zero-points i.e. they should exist and don't or they
// do exist but are invalid.
// 2. Succeeded in extracting zero-points.
// 3. Zero points are "empty" and meaningless for this op i.e. non-quantized
// convolution.
using FailOrMaybeZP = llvm::FailureOr<std::optional<ConvZpPair>>;
template <typename TosaConvOp>
std::enable_if_t<is_tosa_conv_v<TosaConvOp>, FailOrMaybeZP>
extractConvZpPair(TosaConvOp op, PatternRewriter &rewriter) {
// Strictly speaking the base TOSA spec requires that for non int8 types
// zero points must be zero. However, in the dialect these operands are
// optional and only required for int8. They have no semantic meaning for
// non-quantized types and can therefore be safely ignored. This is case 3.
if (auto opElementTY =
cast<ShapedType>(op->getOperand(0).getType()).getElementType();
!opElementTY.isInteger(8))
return FailOrMaybeZP(std::nullopt);

// Now we know we should have a zero point check it is valid.
if (!op.getInputZp())
return rewriter.notifyMatchFailure(op, "missing input zero point");

// Helper to extract the zero point by matching its definition against a
// constant.
auto extractZeroPoint = [](Value zpValue) -> std::optional<int64_t> {
ElementsAttr zpAttr;
if (!matchPattern(zpValue, m_Constant(&zpAttr)))
return std::nullopt;

int64_t zp;
if (tosa::getZeroPoint(zpAttr, zp).failed())
return std::nullopt;

return std::make_optional(zp);
};

auto maybeInputZp = extractZeroPoint(op.getInputZp());
if (!maybeInputZp)
return rewriter.notifyMatchFailure(op, "unable to extract input zp");

if (!op.getWeightZp())
return rewriter.notifyMatchFailure(op, "missing weight zero point");

auto maybeWeightZp = extractZeroPoint(op.getWeightZp());
if (!maybeWeightZp)
return rewriter.notifyMatchFailure(op, "unable to extract weight zp");

return std::make_optional<ConvZpPair>(*maybeInputZp, *maybeWeightZp);
}
} // namespace tosa
} // namespace mlir

#endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H
22 changes: 13 additions & 9 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def Tosa_AvgPool2dOp : Tosa_InferShapedTypeOp<"avg_pool2d"> {
//===----------------------------------------------------------------------===//
// Operator: conv2d
//===----------------------------------------------------------------------===//
def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> {
let summary = "2D Convolution Operator";

let description = [{
Expand All @@ -104,11 +104,12 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);

Expand All @@ -123,7 +124,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
//===----------------------------------------------------------------------===//
// Operator: conv3d
//===----------------------------------------------------------------------===//
def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> {
let summary = "3D Convolution operator";

let description = [{
Expand All @@ -134,11 +135,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
Tosa_Tensor5D:$input,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);

Expand All @@ -153,7 +155,7 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
//===----------------------------------------------------------------------===//
// Operator: depthwise_conv2d
//===----------------------------------------------------------------------===//
def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> {
let summary = "Depthwise 2D Convolution operator";

let description = [{
Expand All @@ -165,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);

Expand Down Expand Up @@ -338,7 +341,7 @@ def Tosa_RFFT2dOp : Tosa_InferShapedTypeOp<"rfft2d"> {
//===----------------------------------------------------------------------===//
// Operator: transpose_conv2d
//===----------------------------------------------------------------------===//
def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> {
let summary = "Transpose 2D Convolution operator.";

let description = [{
Expand All @@ -348,13 +351,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {

let arguments = (ins
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
Optional<Tosa_ZeroPointTensor>:$input_zp,
Optional<Tosa_ZeroPointTensor>:$weight_zp,
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
TypeAttrOf<Tosa_AccType>:$acc_type,
OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);

Expand Down
5 changes: 5 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -288,4 +288,9 @@ def Rank1TosaShape : TosaShapeOfRank<1>;
def Rank2TosaShape : TosaShapeOfRank<2>;
def Rank4TosaShape : TosaShapeOfRank<4>;

// NOTE: Tosa_ScalarTensor is currently defined as rank-0. If and when this
// becomes rank-1 it can be used in place of Tosa_ZeroPointTensor and the
// following def can be removed.
def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>;

#endif // TOSA_TYPES_BASE
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier,
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
Value input, Value weight);

std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
Value weight);

//// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
Value a, Value b);
Expand Down
57 changes: 26 additions & 31 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,12 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
DenseI64ArrayAttr padAttr = op.getPadAttr();
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
bool isQuantized = op.getQuantizationInfo().has_value();

auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
if (llvm::failed(failureOrMaybeZps))
return failure();

auto maybeZps = failureOrMaybeZps.value();

if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
Expand All @@ -284,22 +289,19 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {

// Apply padding as necessary.
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (isQuantized) {
auto quantizationInfo = *op.getQuantizationInfo();
int64_t iZp = quantizationInfo.getInputZp();

if (maybeZps) {
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
int64_t intMax =
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();

if (iZp < intMin || iZp > intMax)
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
return rewriter.notifyMatchFailure(
op, "tosa.conv op quantization has zp outside of input range");

zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
}

llvm::SmallVector<int64_t> pad;
Expand All @@ -312,8 +314,8 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
// For 2D convolutions, we need to check if the target convolution op
// wants a HWCF kernel layout.
bool wantHwcf =
isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
maybeZps ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
if (wantHwcf) {
// Transpose the kernel to match dimension ordering of the linalg
// convolution operation.
Expand Down Expand Up @@ -374,10 +376,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);

if (isQuantized) {
auto quantizationInfo = *op.getQuantizationInfo();
auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
if (maybeZps) {
auto iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
auto kZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);

auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
Expand Down Expand Up @@ -440,39 +441,31 @@ class DepthwiseConvConverter
/*inputSizeDims=*/{1, 2},
/*kernelSizeDims=*/{0, 1}, rewriter);

bool isQuantized = op->hasAttr("quantization_info");
IntegerAttr iZp;
IntegerAttr kZp;
if (isQuantized) {
auto quantizationInfo =
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
}
auto failureOrMaybeZps = extractConvZpPair(op, rewriter);
if (llvm::failed(failureOrMaybeZps))
return failure();

auto maybeZps = failureOrMaybeZps.value();

auto weightShape = weightTy.getShape();
auto resultShape = resultTy.getShape();

// Apply padding as necessary.
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
if (isQuantized) {
auto quantizationInfo =
cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
int64_t iZp = quantizationInfo.getInputZp();

if (maybeZps) {
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
int64_t intMax =
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();

if (iZp < intMin || iZp > intMax)
if (maybeZps->inputZp < intMin || maybeZps->inputZp > intMax)
return rewriter.notifyMatchFailure(
op, "tosa.depthwise_conv op quantization has zp outside of input "
"range");

zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
zeroAttr = rewriter.getIntegerAttr(inputETy, maybeZps->inputZp);
}

llvm::SmallVector<int64_t> pad;
Expand Down Expand Up @@ -512,7 +505,7 @@ class DepthwiseConvConverter
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));

if (!isQuantized) {
if (!maybeZps) {
Value conv = rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
loc, linalgConvTy, ValueRange{input, weight},
Expand All @@ -539,8 +532,10 @@ class DepthwiseConvConverter
.getResult(0);
rewriter.replaceOp(op, result);
} else {
IntegerAttr iZp = rewriter.getI32IntegerAttr(maybeZps->inputZp);
IntegerAttr wZp = rewriter.getI32IntegerAttr(maybeZps->weightZp);
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
Value conv =
rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
Expand Down
Loading