-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[mlir][tosa] Make Convolution Zero Points Inputs #122939
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Jack Frankland (FranklandJack) ChangesThe TOSA-v1.0 specification moves the the "zero point" parameters of the convolution opertors CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs. Make the zero points of the convolutions in the MLIR TOSA dialect inputs and update any transformations, materializations and lit tests appropriately. Patch is 172.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122939.diff 21 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 66512cbe350ec8..1c1c65f48cd0af 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -101,4 +101,39 @@ class TosaElementwiseOperator
#define GET_OP_CLASSES
#include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
+namespace mlir {
+namespace tosa {
+
+// Create a rank-0 const tensor for zero point of the source tensor.
+std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
+ Type srcElemType, int64_t zp = 0);
+
+// Get zero point value from the attribute argument.
+LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
+
+// Verify if zero point falls into valid range.
+template <typename T>
+LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
+ if constexpr (!std::is_same_v<T, Conv2DOp> &&
+ !std::is_same_v<T, Conv3DOp> &&
+ !std::is_same_v<T, DepthwiseConv2DOp> &&
+ !std::is_same_v<T, TransposeConv2DOp>) {
+ return failure();
+ }
+
+ if (!zpElemType.isIntOrFloat())
+ return failure();
+
+ if (!zpElemType.isInteger(8) && zp != 0)
+ return failure();
+
+ if (zp < -128 || zp > 127)
+ return failure();
+
+ return success();
+}
+
+} // namespace tosa
+} // namespace mlir
+
#endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6b43c9a259b108..372447e278c4e0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -103,11 +103,13 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
+ Tosa_ScalarTensor:$input_zp,
+ Tosa_ScalarTensor:$weight_zp,
+
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
- OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -133,11 +135,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
Tosa_Tensor5D:$input,
TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
Tosa_Tensor1D:$bias,
+ Tosa_ScalarTensor:$input_zp,
+ Tosa_ScalarTensor:$weight_zp,
Tosa_IntArrayAttr6:$pad,
Tosa_IntArrayAttr3:$stride,
Tosa_IntArrayAttr3:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
- OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -164,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
Tosa_Tensor4D:$input,
TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
+ Tosa_ScalarTensor:$input_zp,
+ Tosa_ScalarTensor:$weight_zp,
Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
TypeAttrOf<Tosa_AccType>:$acc_type,
- OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
@@ -346,13 +350,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
let arguments = (ins
Tosa_Tensor4D:$input,
- TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
+ TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
Tosa_Tensor1D:$bias,
+ Tosa_ScalarTensor:$input_zp,
+ Tosa_ScalarTensor:$weight_zp,
Tosa_IntArrayAttr4:$out_pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr4:$out_shape,
TypeAttrOf<Tosa_AccType>:$acc_type,
- OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
);
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 5e80745777b3b3..10dc5dd36cfa96 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier,
ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
Value input, Value weight);
+std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
+ Value weight);
+
//// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
Value a, Value b);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index d537aef5791031..fb213461b2acb7 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -21,6 +21,7 @@
#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -258,7 +259,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
DenseI64ArrayAttr padAttr = op.getPadAttr();
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
- bool isQuantized = op.getQuantizationInfo().has_value();
+
+ ElementsAttr inputZpAttr;
+ ElementsAttr weightZpAttr;
+ if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
+ !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr)))
+ return rewriter.notifyMatchFailure(
+ op,
+ "bail out if the actual value of zero points cannot be determined");
+
+ // Get and verify explicit zero points.
+ int64_t inputZpVal;
+ int64_t weightZpVal;
+
+ if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
+ tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf(inputZpAttr),
+ inputZpVal)
+ .failed())
+ return rewriter.notifyMatchFailure(
+ op, "input zero point must be zero for non-int8 integer types");
+
+ if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
+ tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf(weightZpAttr),
+ weightZpVal)
+ .failed())
+ return rewriter.notifyMatchFailure(
+ op, "weight zero point must be zero for non-int8 integer types");
+
+ const bool hasZp =
+ (inputZpVal != 0) || (weightZpVal != 0) || isa<IntegerType>(inputETy);
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
@@ -284,10 +313,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
// Apply padding as necessary.
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
- if (isQuantized) {
- auto quantizationInfo = *op.getQuantizationInfo();
- int64_t iZp = quantizationInfo.getInputZp();
-
+ if (hasZp) {
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
@@ -295,11 +321,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
.getSExtValue();
- if (iZp < intMin || iZp > intMax)
+ if (inputZpVal < intMin || inputZpVal > intMax)
return rewriter.notifyMatchFailure(
op, "tosa.conv op quantization has zp outside of input range");
- zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
+ zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
}
llvm::SmallVector<int64_t> pad;
@@ -312,7 +338,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
// For 2D convolutions, we need to check if the target convolution op
// wants a HWCF kernel layout.
bool wantHwcf =
- isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+ hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
: std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
if (wantHwcf) {
// Transpose the kernel to match dimension ordering of the linalg
@@ -374,10 +400,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
- if (isQuantized) {
- auto quantizationInfo = *op.getQuantizationInfo();
- auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
- auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
+ if (hasZp) {
+ auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
+ auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -440,25 +465,40 @@ class DepthwiseConvConverter
/*inputSizeDims=*/{1, 2},
/*kernelSizeDims=*/{0, 1}, rewriter);
- bool isQuantized = op->hasAttr("quantization_info");
- IntegerAttr iZp;
- IntegerAttr kZp;
- if (isQuantized) {
- auto quantizationInfo =
- cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
- iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
- kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
- }
+ ElementsAttr inputZpAttr;
+ ElementsAttr weightZpAttr;
+ if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
+ !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr)))
+ return rewriter.notifyMatchFailure(
+ op,
+ "bail out if the actual value of zero points cannot be determined");
+
+ // Get and verify explicit zero points.
+ int64_t inputZpVal;
+ int64_t weightZpVal;
+
+ if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
+ tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
+ getElementTypeOrSelf(inputZpAttr), inputZpVal)
+ .failed())
+ return rewriter.notifyMatchFailure(
+ op, "input zero point must be zero for non-int8 integer types");
+
+ if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
+ tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
+ getElementTypeOrSelf(weightZpAttr), weightZpVal)
+ .failed())
+ return rewriter.notifyMatchFailure(
+ op, "weight zero point must be zero for non-int8 integer types");
+ bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
auto weightShape = weightTy.getShape();
auto resultShape = resultTy.getShape();
// Apply padding as necessary.
TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
- if (isQuantized) {
- auto quantizationInfo =
- cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
- int64_t iZp = quantizationInfo.getInputZp();
+ if (inputZpVal) {
+ const int64_t iZp = inputZpVal;
int64_t intMin =
APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
@@ -512,7 +552,7 @@ class DepthwiseConvConverter
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
- if (!isQuantized) {
+ if (!hasZp && isa<FloatType>(inputETy)) {
Value conv = rewriter
.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
loc, linalgConvTy, ValueRange{input, weight},
@@ -539,6 +579,8 @@ class DepthwiseConvConverter
.getResult(0);
rewriter.replaceOp(op, result);
} else {
+ IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
+ IntegerAttr kZp = rewriter.getI32IntegerAttr(weightZpVal);
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
Value conv =
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 764a5db48e0787..c5c145a0eb329b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,11 +211,7 @@ static LogicalResult verifyConvOp(T op) {
// All TOSA conv ops have an input() and weight().
auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
- RankedTensorType weightType;
- if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
- weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
- else
- weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+ auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
// Must be ranked tensor types
if (!inputType) {
@@ -223,18 +219,49 @@ static LogicalResult verifyConvOp(T op) {
return failure();
}
if (!weightType) {
- if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
- op.emitOpError("expect a ranked tensor for filter, got ")
- << op.getFilter();
- } else {
- op.emitOpError("expect a ranked tensor for weight, got ")
- << op.getWeight();
- }
+ op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
return failure();
}
auto inputEType = inputType.getElementType();
auto weightEType = weightType.getElementType();
+ auto biasEType =
+ llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
+ auto resultEType =
+ llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+ bool biasIsFloat = llvm::isa<FloatType>(biasEType);
+ bool resultIsFloat = llvm::isa<FloatType>(resultEType);
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+ inputEType = quantType.getStorageType();
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+ biasEType = quantType.getStorageType();
+
+ if (auto quantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+ resultEType = quantType.getStorageType();
+
+ if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
+ // for now, only enforce bias element type == result element type for
+ // float types.
+ op.emitOpError(
+ "expect both bias and result to have same element type, got ")
+ << biasEType << " and " << resultEType;
+ return failure();
+ }
+
+ if (inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3FN() ||
+ weightEType.isFloat8E5M2() || weightEType.isFloat8E4M3FN()) {
+ if (inputEType != weightEType) {
+ op.emitOpError(
+ "expect both input and weight to have same element type, got ")
+ << inputEType << " and " << weightEType;
+ return failure();
+ }
+ }
bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
@@ -247,14 +274,33 @@ static LogicalResult verifyConvOp(T op) {
return failure();
}
- // Quantized type must have constructed the quantizationattr, and unquantized
- // types should not have a quantizationattr.
- if ((inputIsQuant && !op.getQuantizationInfo()) ||
- (!inputIsQuant && op.getQuantizationInfo())) {
- op.emitOpError("quantizationattr is required for quantized type, and not "
- "allowed for float type");
+ ElementsAttr inputZpAttr;
+ ElementsAttr weightZpAttr;
+ if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
+ !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) {
+ op.emitOpError(
+ "bail out if the actual value of zero points cannot be determined");
+ return failure();
+ }
+
+ // Get and verify explicit zero points.
+ int64_t inputZpVal;
+ int64_t weightZpVal;
+
+ if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
+ tosa::verifyZeroPoint<T>(getElementTypeOrSelf(inputZpAttr), inputZpVal)
+ .failed()) {
+ op.emitOpError("input zero point must be zero for non-int8 integer types");
return failure();
}
+
+ if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
+ tosa::verifyZeroPoint<T>(getElementTypeOrSelf(weightZpAttr), weightZpVal)
+ .failed()) {
+ op.emitOpError("weight zero point must be zero for non-int8 integer types");
+ return failure();
+ }
+
return success();
}
@@ -314,6 +360,39 @@ static LogicalResult verifyConvOpModes(T op) {
return success();
}
+// verify that inType and outType have same element types
+template <typename T>
+static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
+ auto inputType = llvm::dyn_cast<TensorType>(inType);
+ auto outputType = llvm::dyn_cast<TensorType>(outType);
+ if (!inputType) {
+ op.emitOpError("expect shaped tensor for input, got ") << inType;
+ return failure();
+ }
+ if (!outputType) {
+ op.emitOpError("expect shaped tensor for output, got ") << outType;
+ return failure();
+ }
+ auto inputElementType = inputType.getElementType();
+ auto outputElementType = outputType.getElementType();
+ auto inputQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
+ auto outputQuantType =
+ llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
+ if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
+ (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
+ inputElementType != outputElementType) {
+ // only check if both element types are int/index/float/UniformQuantized
+ // eg, not sure how to check quant::QuantizedType
+ // this happens in test_conv2d_q_grouped_convolution in
+ // tfl-to-tosa-pipeline.mlir
+ op.emitOpError("expect input and output to have same element type, got ")
+ << inputElementType << " and " << outputElementType;
+ return failure();
+ }
+ return success();
+}
+
LogicalResult tosa::ArgMaxOp::verify() {
// Ensure output is of 32-bit integer
const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -413,21 +492,13 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
DenseI64ArrayAttr stride,
DenseI64ArrayAttr dilation,
TypeAttr accType) {
-
- result.addOperands({input, weight, bias});
+ auto zps = createZPsAsConst(builder, input, weight);
+ result.addOperands({input, weight, bias, zps.first, zps.second});
result.addAttribute("pad", pad);
result.addAttribute("stride", stride);
result.addAttribute("dilation", dilation);
result.addAttribute("acc_type", accType);
-
- auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
- if (quantAttr) {
- result.addAttribute("quantization_info", quantAttr);
- result.addTypes(
- buildConvOpResultTypeInfo(builder, outputType, input, weight));
- } else {
- result.addTypes(outputType);
- }
+ result.addTypes(outputType);
}
/// Handles tosa.transpose_conv2d which has outpad and output shape
@@ -782,7 +853,7 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
return success();
}
-LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
+LogicalResult FullyConnectedOp::verify() { return success(); }
LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
MLIRContext *context, ::std::optional<Location> location,
@@ -1850,7 +1921,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
}
// Weight shapes describes the filter width/height and the output channels.
- ShapeAdaptor weightShape(adaptor.getFilter().getType());
+ ShapeAdaptor weightShape(adaptor.getWeight().getType());
if (weightShape.hasRank()) {
outputShape[3] = ShapedType::isDynamic(outputShape[3])
? weightShape.getDimSize(0)
@@ -1906,10 +1977,8 @@ LogicalResult IfOp::inferReturnTypeComponents(
if (...
[truncated]
|
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
lhutton1
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM modulo formatting issues
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to review the implementation and check if it reflects the spec, and if is layered as expected?
Happy to approve but I feel it doesn't align with the final spec but might be missing something.
@sjarus @eric-k256 could you please have a look as well?
c70c291 to
8b2cbe2
Compare
8b2cbe2 to
2624044
Compare
It looked ok to me. What layering do you have concerns with @GeorgeARM ? |
Need to review the new patch. The first instance from what I recall wasn't aligned with the spec. |
2624044 to
bbd925b
Compare
lhutton1
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apologies, I had another look at this and noticed something I didn't previously
4b6a660 to
bd7e1c3
Compare
lhutton1
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! I added a couple of nitpick comments, not blocking on them though (I'll be away for 10 days from tomorrow). I think it would be good to add a note in the commit message about the other changes here just to highlight these changes to other users of TOSA:
- transpose_conv2d: filter -> weight
- removal of quantization_info on the conv ops
The TOSA-v1.0 specification moves the "zero point" parameters of the convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs. Make the zero points of the convolutions in the MLIR TOSA dialect inputs and update any transformations, materializations and lit tests appropriately. Rename the "filter" argument of `tosa.transpose_conv2d` to weight to align with the TOSA specification. Remove the quantization_info attribute on the convolution operations.
bd7e1c3 to
9e8465e
Compare
lhutton1
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Thanks for all the changes
The TOSA-v1.0 specification moves the "zero point" parameters of the convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs. Make the zero points of the convolutions in the MLIR TOSA dialect inputs and update any transformations, materializations and lit tests appropriately. Rename the "filter" argument of `tosa.transpose_conv2d` to weight to align with the TOSA specification. Remove the quantization_info attribute on the convolution operations. Co-authored-by: TatWai Chong <[email protected]>
The TOSA-v1.0 specification moves the the "zero point" parameters of the convolution opertors CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs.
Make the zero points of the convolutions in the MLIR TOSA dialect inputs and update any transformations, materializations and lit tests appropriately.