
Conversation

@FranklandJack
Contributor

The TOSA-v1.0 specification moves the "zero point" parameters of the convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs.

Make the zero points of the convolutions in the MLIR TOSA dialect inputs, and update the affected transformations, materializations, and lit tests accordingly.
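For illustration, here is a minimal sketch of how the new operands can be materialized using the `createZeroPointTensor` helper this patch declares in TosaOps.h; the wrapper function itself is hypothetical, not part of the patch:

```cpp
#include "mlir/Dialect/Tosa/IR/TosaOps.h"

#include <cassert>
#include <optional>
#include <utility>

using namespace mlir;

// Hypothetical helper: build the two zero points as constant tensor operands
// for a quantized convolution, where they previously lived in the
// quantization_info attribute.
static std::pair<Value, Value>
materializeConvZps(OpBuilder &builder, Location loc, Type inputElemType,
                   Type weightElemType, int64_t inputZp, int64_t weightZp) {
  // createZeroPointTensor creates a rank-0 const tensor holding the zero
  // point, or returns std::nullopt on failure.
  std::optional<Value> inZp =
      tosa::createZeroPointTensor(builder, loc, inputElemType, inputZp);
  std::optional<Value> wZp =
      tosa::createZeroPointTensor(builder, loc, weightElemType, weightZp);
  assert(inZp && wZp && "failed to materialize zero-point constants");
  // The results are passed as the new input_zp/weight_zp operands of
  // tosa.conv2d, tosa.conv3d, tosa.depthwise_conv2d and tosa.transpose_conv2d.
  return {*inZp, *wZp};
}
```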

@llvmbot
Member

llvmbot commented Jan 14, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tosa

Author: Jack Frankland (FranklandJack)

Changes

The TOSA-v1.0 specification moves the "zero point" parameters of the convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and TRANSPOSE_CONV2D from attributes to inputs.

Make the zero points of the convolutions in the MLIR TOSA dialect inputs, and update the affected transformations, materializations, and lit tests accordingly.


Patch is 172.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/122939.diff

21 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h (+35)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+10-5)
  • (modified) mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h (+3)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+68-26)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+150-37)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeConv2D.cpp (+39-5)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+36-8)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+57-42)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaValidation.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Tosa/Utils/QuantUtils.cpp (+58-10)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+44-22)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+8-4)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+45-38)
  • (modified) mlir/test/Dialect/Tosa/level_check.mlir (+108-72)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+16-8)
  • (modified) mlir/test/Dialect/Tosa/quant-test.mlir (+3-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-conv2d.mlir (+11-4)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+9-3)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+34-17)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+61-31)
  • (modified) mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir (+4-3)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
index 66512cbe350ec8..1c1c65f48cd0af 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h
@@ -101,4 +101,39 @@ class TosaElementwiseOperator
 #define GET_OP_CLASSES
 #include "mlir/Dialect/Tosa/IR/TosaOps.h.inc"
 
+namespace mlir {
+namespace tosa {
+
+// Create a rank-0 const tensor for zero point of the source tensor.
+std::optional<Value> createZeroPointTensor(OpBuilder &builder, Location loc,
+                                           Type srcElemType, int64_t zp = 0);
+
+// Get zero point value from the attribute argument.
+LogicalResult getZeroPoint(ElementsAttr zpAttr, int64_t &zp);
+
+// Verify if zero point falls into valid range.
+template <typename T>
+LogicalResult verifyZeroPoint(Type zpElemType, int64_t zp) {
+  if constexpr (!std::is_same_v<T, Conv2DOp> &&
+                !std::is_same_v<T, Conv3DOp> &&
+                !std::is_same_v<T, DepthwiseConv2DOp> &&
+                !std::is_same_v<T, TransposeConv2DOp>) {
+    return failure();
+  }
+
+  if (!zpElemType.isIntOrFloat())
+    return failure();
+
+  if (!zpElemType.isInteger(8) && zp != 0)
+    return failure();
+
+  if (zp < -128 || zp > 127)
+    return failure();
+
+  return success();
+}
+
+} // namespace tosa
+} // namespace mlir
+
 #endif // MLIR_DIALECT_TOSA_IR_TOSAOPS_H
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 6b43c9a259b108..372447e278c4e0 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -103,11 +103,13 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
+
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
@@ -133,11 +135,12 @@ def Tosa_Conv3DOp : Tosa_InferShapedTypeOp<"conv3d"> {
     Tosa_Tensor5D:$input,
     TosaTensorRankOf<[Tosa_Weight], [5]>:$weight,
     Tosa_Tensor1D:$bias,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
     Tosa_IntArrayAttr6:$pad,
     Tosa_IntArrayAttr3:$stride,
     Tosa_IntArrayAttr3:$dilation,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
@@ -164,11 +167,12 @@ def Tosa_DepthwiseConv2DOp : Tosa_InferShapedTypeOp<"depthwise_conv2d"> {
     Tosa_Tensor4D:$input,
     TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
@@ -346,13 +350,14 @@ def Tosa_TransposeConv2DOp : Tosa_InferShapedTypeOp<"transpose_conv2d"> {
 
   let arguments = (ins
     Tosa_Tensor4D:$input,
-    TosaTensorRankOf<[Tosa_Weight], [4]>:$filter,
+    TosaTensorRankOf<[Tosa_Weight], [4]>:$weight,
     Tosa_Tensor1D:$bias,
+    Tosa_ScalarTensor:$input_zp,
+    Tosa_ScalarTensor:$weight_zp,
     Tosa_IntArrayAttr4:$out_pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr4:$out_shape,
     TypeAttrOf<Tosa_AccType>:$acc_type,
-    OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
 
diff --git a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
index 5e80745777b3b3..10dc5dd36cfa96 100644
--- a/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
+++ b/mlir/include/mlir/Dialect/Tosa/Utils/QuantUtils.h
@@ -35,6 +35,9 @@ void computeMultiplierAndShift(double scale, int32_t &multiplier,
 ConvOpQuantizationAttr buildConvOpQuantizationAttr(OpBuilder &builder,
                                                    Value input, Value weight);
 
+std::pair<Value, Value> createZPsAsConst(OpBuilder &builder, Value input,
+                                         Value weight);
+
 //// Builds MatMulOpQuantizationAttr for MatMul operations from A and B.
 MatMulOpQuantizationAttr buildMatMulOpQuantizationAttr(OpBuilder &builder,
                                                        Value a, Value b);
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index d537aef5791031..fb213461b2acb7 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -21,6 +21,7 @@
 #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Transforms/DialectConversion.h"
@@ -258,7 +259,35 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     DenseI64ArrayAttr padAttr = op.getPadAttr();
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
-    bool isQuantized = op.getQuantizationInfo().has_value();
+
+    ElementsAttr inputZpAttr;
+    ElementsAttr weightZpAttr;
+    if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
+        !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr)))
+      return rewriter.notifyMatchFailure(
+          op,
+          "bail out if the actual value of zero points cannot be determined");
+
+    // Get and verify explicit zero points.
+    int64_t inputZpVal;
+    int64_t weightZpVal;
+
+    if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
+        tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf(inputZpAttr),
+                                          inputZpVal)
+            .failed())
+      return rewriter.notifyMatchFailure(
+          op, "input zero point must be zero for non-int8 integer types");
+
+    if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
+        tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf(weightZpAttr),
+                                          weightZpVal)
+            .failed())
+      return rewriter.notifyMatchFailure(
+          op, "weight zero point must be zero for non-int8 integer types");
+
+    const bool hasZp =
+        (inputZpVal != 0) || (weightZpVal != 0) || isa<IntegerType>(inputETy);
 
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
@@ -284,10 +313,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (isQuantized) {
-      auto quantizationInfo = *op.getQuantizationInfo();
-      int64_t iZp = quantizationInfo.getInputZp();
-
+    if (hasZp) {
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
@@ -295,11 +321,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
           APInt::getSignedMaxValue(inputETy.getIntOrFloatBitWidth())
               .getSExtValue();
 
-      if (iZp < intMin || iZp > intMax)
+      if (inputZpVal < intMin || inputZpVal > intMax)
         return rewriter.notifyMatchFailure(
             op, "tosa.conv op quantization has zp outside of input range");
 
-      zeroAttr = rewriter.getIntegerAttr(inputETy, iZp);
+      zeroAttr = rewriter.getIntegerAttr(inputETy, inputZpVal);
     }
 
     llvm::SmallVector<int64_t> pad;
@@ -312,7 +338,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       // For 2D convolutions, we need to check if the target convolution op
       // wants a HWCF kernel layout.
       bool wantHwcf =
-          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+          hasZp ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
                       : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
       if (wantHwcf) {
         // Transpose the kernel to match dimension ordering of the linalg
@@ -374,10 +400,9 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
-    if (isQuantized) {
-      auto quantizationInfo = *op.getQuantizationInfo();
-      auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
-      auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
+    if (hasZp) {
+      auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
+      auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
@@ -440,25 +465,40 @@ class DepthwiseConvConverter
         /*inputSizeDims=*/{1, 2},
         /*kernelSizeDims=*/{0, 1}, rewriter);
 
-    bool isQuantized = op->hasAttr("quantization_info");
-    IntegerAttr iZp;
-    IntegerAttr kZp;
-    if (isQuantized) {
-      auto quantizationInfo =
-          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
-      iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
-      kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
-    }
+    ElementsAttr inputZpAttr;
+    ElementsAttr weightZpAttr;
+    if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
+        !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr)))
+      return rewriter.notifyMatchFailure(
+          op,
+          "bail out if the actual value of zero points cannot be determined");
+
+    // Get and verify explicit zero points.
+    int64_t inputZpVal;
+    int64_t weightZpVal;
+
+    if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
+        tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
+            getElementTypeOrSelf(inputZpAttr), inputZpVal)
+            .failed())
+      return rewriter.notifyMatchFailure(
+          op, "input zero point must be zero for non-int8 integer types");
+
+    if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
+        tosa::verifyZeroPoint<tosa::DepthwiseConv2DOp>(
+            getElementTypeOrSelf(weightZpAttr), weightZpVal)
+            .failed())
+      return rewriter.notifyMatchFailure(
+          op, "weight zero point must be zero for non-int8 integer types");
 
+    bool hasZp = (inputZpVal != 0) || (weightZpVal != 0);
     auto weightShape = weightTy.getShape();
     auto resultShape = resultTy.getShape();
 
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
-    if (isQuantized) {
-      auto quantizationInfo =
-          cast<tosa::ConvOpQuantizationAttr>(op->getAttr("quantization_info"));
-      int64_t iZp = quantizationInfo.getInputZp();
+    if (inputZpVal) {
+      const int64_t iZp = inputZpVal;
 
       int64_t intMin =
           APInt::getSignedMinValue(inputETy.getIntOrFloatBitWidth())
@@ -512,7 +552,7 @@ class DepthwiseConvConverter
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
-    if (!isQuantized) {
+    if (!hasZp && isa<FloatType>(inputETy)) {
       Value conv = rewriter
                        .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
                            loc, linalgConvTy, ValueRange{input, weight},
@@ -539,6 +579,8 @@ class DepthwiseConvConverter
               .getResult(0);
       rewriter.replaceOp(op, result);
     } else {
+      IntegerAttr iZp = rewriter.getI32IntegerAttr(inputZpVal);
+      IntegerAttr kZp = rewriter.getI32IntegerAttr(weightZpVal);
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
       Value conv =
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 764a5db48e0787..c5c145a0eb329b 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -211,11 +211,7 @@ static LogicalResult verifyConvOp(T op) {
   // All TOSA conv ops have an input() and weight().
   auto inputType = llvm::dyn_cast<RankedTensorType>(op.getInput().getType());
 
-  RankedTensorType weightType;
-  if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>)
-    weightType = llvm::dyn_cast<RankedTensorType>(op.getFilter().getType());
-  else
-    weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
+  auto weightType = llvm::dyn_cast<RankedTensorType>(op.getWeight().getType());
 
   // Must be ranked tensor types
   if (!inputType) {
@@ -223,18 +219,49 @@ static LogicalResult verifyConvOp(T op) {
     return failure();
   }
   if (!weightType) {
-    if constexpr (std::is_same_v<T, tosa::TransposeConv2DOp>) {
-      op.emitOpError("expect a ranked tensor for filter, got ")
-          << op.getFilter();
-    } else {
-      op.emitOpError("expect a ranked tensor for weight, got ")
-          << op.getWeight();
-    }
+    op.emitOpError("expect a ranked tensor for weight, got ") << op.getWeight();
     return failure();
   }
 
   auto inputEType = inputType.getElementType();
   auto weightEType = weightType.getElementType();
+  auto biasEType =
+      llvm::cast<ShapedType>(op.getBias().getType()).getElementType();
+  auto resultEType =
+      llvm::cast<ShapedType>(op.getResult().getType()).getElementType();
+  bool biasIsFloat = llvm::isa<FloatType>(biasEType);
+  bool resultIsFloat = llvm::isa<FloatType>(resultEType);
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType))
+    inputEType = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(biasEType))
+    biasEType = quantType.getStorageType();
+
+  if (auto quantType =
+          llvm::dyn_cast<mlir::quant::UniformQuantizedType>(resultEType))
+    resultEType = quantType.getStorageType();
+
+  if (biasIsFloat && resultIsFloat && (biasEType != resultEType)) {
+    // for now, only enforce bias element type == result element type for
+    // float types.
+    op.emitOpError(
+        "expect both bias and result to have same element type, got ")
+        << biasEType << " and " << resultEType;
+    return failure();
+  }
+
+  if (inputEType.isFloat8E5M2() || inputEType.isFloat8E4M3FN() ||
+      weightEType.isFloat8E5M2() || weightEType.isFloat8E4M3FN()) {
+    if (inputEType != weightEType) {
+      op.emitOpError(
+          "expect both input and weight to have same element type, got ")
+          << inputEType << " and " << weightEType;
+      return failure();
+    }
+  }
 
   bool inputIsQuant = !llvm::isa<FloatType>(inputEType);
   bool weightIsQuant = !llvm::isa<FloatType>(weightEType);
@@ -247,14 +274,33 @@ static LogicalResult verifyConvOp(T op) {
     return failure();
   }
 
-  // Quantized type must have constructed the quantizationattr, and unquantized
-  // types should not have a quantizationattr.
-  if ((inputIsQuant && !op.getQuantizationInfo()) ||
-      (!inputIsQuant && op.getQuantizationInfo())) {
-    op.emitOpError("quantizationattr is required for quantized type, and not "
-                   "allowed for float type");
+  ElementsAttr inputZpAttr;
+  ElementsAttr weightZpAttr;
+  if (!matchPattern(op.getInputZp(), m_Constant(&inputZpAttr)) ||
+      !matchPattern(op.getWeightZp(), m_Constant(&weightZpAttr))) {
+    op.emitOpError(
+        "bail out if the actual value of zero points cannot be determined");
+    return failure();
+  }
+
+  // Get and verify explicit zero points.
+  int64_t inputZpVal;
+  int64_t weightZpVal;
+
+  if (tosa::getZeroPoint(inputZpAttr, inputZpVal).failed() ||
+      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(inputZpAttr), inputZpVal)
+          .failed()) {
+    op.emitOpError("input zero point must be zero for non-int8 integer types");
     return failure();
   }
+
+  if (tosa::getZeroPoint(weightZpAttr, weightZpVal).failed() ||
+      tosa::verifyZeroPoint<T>(getElementTypeOrSelf(weightZpAttr), weightZpVal)
+          .failed()) {
+    op.emitOpError("weight zero point must be zero for non-int8 integer types");
+    return failure();
+  }
+
   return success();
 }
 
@@ -314,6 +360,39 @@ static LogicalResult verifyConvOpModes(T op) {
   return success();
 }
 
+// verify that inType and outType have same element types
+template <typename T>
+static LogicalResult verifySameElementTypes(T op, Type inType, Type outType) {
+  auto inputType = llvm::dyn_cast<TensorType>(inType);
+  auto outputType = llvm::dyn_cast<TensorType>(outType);
+  if (!inputType) {
+    op.emitOpError("expect shaped tensor for input, got ") << inType;
+    return failure();
+  }
+  if (!outputType) {
+    op.emitOpError("expect shaped tensor for output, got ") << outType;
+    return failure();
+  }
+  auto inputElementType = inputType.getElementType();
+  auto outputElementType = outputType.getElementType();
+  auto inputQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputElementType);
+  auto outputQuantType =
+      llvm::dyn_cast<mlir::quant::UniformQuantizedType>(outputElementType);
+  if ((inputElementType.isIntOrIndexOrFloat() || inputQuantType) &&
+      (outputElementType.isIntOrIndexOrFloat() || outputQuantType) &&
+      inputElementType != outputElementType) {
+    // only check if both element types are int/index/float/UniformQuantized
+    // eg, not sure how to check quant::QuantizedType
+    // this happens in test_conv2d_q_grouped_convolution in
+    // tfl-to-tosa-pipeline.mlir
+    op.emitOpError("expect input and output to have same element type, got ")
+        << inputElementType << " and " << outputElementType;
+    return failure();
+  }
+  return success();
+}
+
 LogicalResult tosa::ArgMaxOp::verify() {
   // Ensure output is of 32-bit integer
   const auto resultETy = llvm::cast<ShapedType>(getType()).getElementType();
@@ -413,21 +492,13 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
                                      DenseI64ArrayAttr stride,
                                      DenseI64ArrayAttr dilation,
                                      TypeAttr accType) {
-
-  result.addOperands({input, weight, bias});
+  auto zps = createZPsAsConst(builder, input, weight);
+  result.addOperands({input, weight, bias, zps.first, zps.second});
   result.addAttribute("pad", pad);
   result.addAttribute("stride", stride);
   result.addAttribute("dilation", dilation);
   result.addAttribute("acc_type", accType);
-
-  auto quantAttr = buildConvOpQuantizationAttr(builder, input, weight);
-  if (quantAttr) {
-    result.addAttribute("quantization_info", quantAttr);
-    result.addTypes(
-        buildConvOpResultTypeInfo(builder, outputType, input, weight));
-  } else {
-    result.addTypes(outputType);
-  }
+  result.addTypes(outputType);
 }
 
 /// Handles tosa.transpose_conv2d which has outpad and output shape
@@ -782,7 +853,7 @@ LogicalResult tosa::FullyConnectedOp::inferReturnTypeComponents(
   return success();
 }
 
-LogicalResult FullyConnectedOp::verify() { return verifyConvOp(*this); }
+LogicalResult FullyConnectedOp::verify() { return success(); }
 
 LogicalResult tosa::MatMulOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
@@ -1850,7 +1921,7 @@ LogicalResult TransposeConv2DOp::inferReturnTypeComponents(
   }
 
   // Weight shapes describes the filter width/height and the output channels.
-  ShapeAdaptor weightShape(adaptor.getFilter().getType());
+  ShapeAdaptor weightShape(adaptor.getWeight().getType());
   if (weightShape.hasRank()) {
     outputShape[3] = ShapedType::isDynamic(outputShape[3])
                          ? weightShape.getDimSize(0)
@@ -1906,10 +1977,8 @@ LogicalResult IfOp::inferReturnTypeComponents(
       if (...
[truncated]
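Condensed from the converter hunks above, the recurring recover-and-validate sequence for the new zero-point operands looks roughly like this (a sketch; the helper name and factoring are hypothetical, the calls mirror the diff):

```cpp
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"

using namespace mlir;

// Hypothetical helper mirroring the pattern in ConvConverter and
// DepthwiseConvConverter: fold a zero-point operand to a constant and check
// it against the TOSA validity rules.
template <typename TosaConvOp>
static FailureOr<int64_t> foldConstantZp(TosaConvOp op, Value zpOperand,
                                         PatternRewriter &rewriter) {
  ElementsAttr zpAttr;
  // Bail out if the zero point is not known at compile time.
  if (!matchPattern(zpOperand, m_Constant(&zpAttr)))
    return rewriter.notifyMatchFailure(op, "zero point is not constant");

  int64_t zpVal;
  // Extract the scalar value, then enforce the range/type rules from
  // tosa::verifyZeroPoint (zero unless the element type is int8).
  if (tosa::getZeroPoint(zpAttr, zpVal).failed() ||
      tosa::verifyZeroPoint<TosaConvOp>(getElementTypeOrSelf(zpAttr), zpVal)
          .failed())
    return rewriter.notifyMatchFailure(
        op, "zero point must be zero for non-int8 integer types");

  return zpVal;
}
```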

@github-actions

github-actions bot commented Jan 14, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@lhutton1 lhutton1 left a comment

LGTM modulo formatting issues

Contributor

@GeorgeARM GeorgeARM left a comment

Is it possible to review the implementation and check whether it reflects the spec and is layered as expected?
Happy to approve, but I feel it doesn't align with the final spec; I might be missing something.
@sjarus @eric-k256 could you please have a look as well?

@sjarus
Contributor

sjarus commented Jan 23, 2025

Is it possible to review the implementation and check whether it reflects the spec and is layered as expected?

Happy to approve, but I feel it doesn't align with the final spec; I might be missing something.

@sjarus @eric-k256 could you please have a look as well?

It looked ok to me. What layering do you have concerns with, @GeorgeARM?

@GeorgeARM
Contributor

Is it possible to review the implementation and check whether it reflects the spec and is layered as expected?
Happy to approve, but I feel it doesn't align with the final spec; I might be missing something.
@sjarus @eric-k256 could you please have a look as well?

It looked ok to me. What layering do you have concerns with, @GeorgeARM?

Need to review the new patch; from what I recall, the first instance wasn't aligned with the spec.

Contributor

@lhutton1 lhutton1 left a comment

Apologies, I had another look at this and noticed something I didn't spot previously.

@FranklandJack FranklandJack force-pushed the zero_points branch 4 times, most recently from 4b6a660 to bd7e1c3 on January 30, 2025 at 11:31
Contributor

@lhutton1 lhutton1 left a comment

LGTM! I added a couple of nitpick comments, though I'm not blocking on them (I'll be away for 10 days from tomorrow). I think it would be good to add a note in the commit message about the other changes here, just to highlight them to other users of TOSA:

  • transpose_conv2d: filter -> weight (see the migration sketch after this comment)
  • removal of the quantization_info attribute on the conv ops

The TOSA-v1.0 specification moves the "zero point" parameters of the
convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and
TRANSPOSE_CONV2D from attributes to inputs.

Make the zero points of the convolutions in the MLIR TOSA dialect inputs
and update any transformations, materializations and lit tests
appropriately.

Rename the "filter" argument of `tosa.transpose_conv2d` to weight to
align with the TOSA specification.

Remove the quantization_info attribute on the convolution operations.
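For downstream users of the dialect, the rename and the attribute removal translate into accessor changes along these lines (a hypothetical snippet, not part of the patch; the accessors match the ones used in the diff):

```cpp
#include "mlir/Dialect/Tosa/IR/TosaOps.h"

// Before this patch the weight operand was reached via op.getFilter() and the
// zero points via the quantization_info attribute; now both are operands.
void inspectTransposeConv(mlir::tosa::TransposeConv2DOp op) {
  mlir::Value weight = op.getWeight();    // renamed from `filter`
  mlir::Value inputZp = op.getInputZp();  // zero points are operands now
  mlir::Value weightZp = op.getWeightZp();
  (void)weight;
  (void)inputZp;
  (void)weightZp;
}
```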
Contributor

@lhutton1 lhutton1 left a comment

LGTM! Thanks for all the changes

@FranklandJack FranklandJack merged commit cc72042 into llvm:main Feb 3, 2025
8 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
The TOSA-v1.0 specification moves the "zero point" parameters of the
convolution operators CONV2D, CONV3D, DEPTHWISE_CONV2D, and
TRANSPOSE_CONV2D from attributes to inputs.

Make the zero points of the convolutions in the MLIR TOSA dialect inputs
and update any transformations, materializations and lit tests
appropriately.

Rename the "filter" argument of `tosa.transpose_conv2d` to weight to
align with the TOSA specification.

Remove the quantization_info attribute on the convolution operations.

Co-authored-by: TatWai Chong <[email protected]>