Skip to content

Conversation

@stefankoncarevic
Copy link
Contributor

This patch adds two new ops: linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp, and uses them to convert tosa group conv2d Ops.

  • Added linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp.
  • Updated the conversion process to use these new ops for tosa group conv2d operations.

@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tosa

Author: None (stefankoncarevic)

Changes

This patch adds two new ops: linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp, and uses them to convert tosa group conv2d Ops.

  • Added linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp.
  • Updated the conversion process to use these new ops for tosa group conv2d operations.

Patch is 26.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108192.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+237)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+12)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+103-42)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+13)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+3-2)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+54)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+16)
  • (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 8cb698096ef5b7..011c4858d6521b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -3410,6 +3410,243 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nhwgc_gfhwc
+  cpp_class_name: Conv2DNhwgcGfhwcOp
+  doc: |-
+    Performs 2-D grouped convolution.
+
+    Layout:
+      * Input: NHWGC.
+      * Kernel: GFHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s9, s11, s3, s7, s10)>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s5, s9, s11)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s2, s6)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nhwgc_gfhwc_q
+  cpp_class_name: Conv2DNhwgcGfhwcQOp
+  doc: |-
+    Performs 2-D grouped convolution with zero point offsets.
+
+    Layout:
+      * Input: NHWGC.
+      * Kernel: GFHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s9, s11, s3, s7, s10)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s5, s9, s11)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s2, s6)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_ngchw_gfchw_q
   cpp_class_name: Conv2DNgchwGfchwQOp
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..d4697f0afbf466 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -133,6 +133,18 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder<
                              pad, stride, dilation);
   }]>;
 
+// Handles grouped convolution
+def Tosa_ConvOpGroupQuantBuilder : OpBuilder<
+  (ins "::mlir::Type":$outputType, "::mlir::Value":$input,
+       "::mlir::Value":$weight, "::mlir::Value":$bias,
+       "::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride,
+       "::mlir::DenseI64ArrayAttr":$dilation, "::mlir::IntegerAttr":$group),
+  [{
+    buildConvOpWithQuantInfo($_builder, $_state, outputType,
+                             input, weight, bias,
+                             pad, stride, dilation, group);
+  }]>;
+
 // Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
 def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
   (ins "::mlir::Type":$outputType, "::mlir::Value":$input,
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ab6daa39708d13..0b67019fd0c7bb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -108,6 +108,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
+    OptionalAttr<I64Attr>:$group,
     OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
@@ -116,7 +117,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
     Tosa_Tensor4D:$output
   );
 
-  let builders = [Tosa_ConvOpQuantInfoBuilder];
+  let builders = [Tosa_ConvOpQuantInfoBuilder, Tosa_ConvOpGroupQuantBuilder];
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 77c3d2e8757910..898bed4a895864 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -236,6 +236,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
   LogicalResult
   matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
+    bool isConv2DOp = isa<tosa::Conv2DOp>(op);
     Location loc = op->getLoc();
     Value input = op->getOperand(0);
     Value weight = op->getOperand(1);
@@ -253,6 +254,24 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
     bool isQuantized = op.getQuantizationInfo().has_value();
+    int64_t group = 1;
+
+    if (auto convop = dyn_cast<tosa::Conv2DOp>(&op)) {
+      if (convop->getGroup().has_value())
+        group = convop->getGroup().value();
+    }
+
+    if (group > 1 && isConv2DOp &&
+        !std::is_same<LinalgConvOp, linalg::Conv2DNhwgcGfhwcOp>::value &&
+        !std::is_same<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>::value)
+      return rewriter.notifyMatchFailure(
+          op, "tosa.conv ops should map to grouped convolution ops");
+
+    if (group == 1 && isConv2DOp &&
+        !std::is_same<LinalgConvOp, linalg::Conv2DNhwcFhwcOp>::value &&
+        !std::is_same<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>::value)
+      return rewriter.notifyMatchFailure(
+          op, "tosa.conv ops should map to non-grouped convolution ops");
 
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
@@ -274,8 +293,6 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
         strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
         inputSizeDims, kernelSizeDims, rewriter);
 
-    auto weightShape = weightTy.getShape();
-
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
     if (isQuantized) {
@@ -302,15 +319,64 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     pad.resize(pad.size() + 2, 0);
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
-    if (4 == inputTy.getRank()) {
-      // For 2D convolutions, we need to check if the target convolution op
-      // wants a HWCF kernel layout.
-      bool wantHwcf =
-          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
-                      : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
-      if (wantHwcf) {
-        // Transpose the kernel to match dimension ordering of the linalg
-        // convolution operation.
+    auto weightShape = weightTy.getShape();
+    SmallVector<int64_t> weightPerm;
+
+    auto resultShape = resultTy.getShape();
+    auto newResultTy = resultTy;
+
+    if (isConv2DOp && group > 1) {
+      // Map 4D-tensors to 5D tensors
+      auto inputShape = cast<ShapedType>(input.getType()).getShape();
+      SmallVector<int64_t, 5> newInputShape = {inputShape[0], inputShape[1],
+                                               inputShape[2], group,
+                                               inputShape[3] / group};
+
+      SmallVector<int64_t, 5> newWeightShape = {group, weightShape[0] / group,
+                                                weightShape[1], weightShape[2],
+                                                weightShape[3]};
+      input = rewriter.create<tosa::ReshapeOp>(
+          loc, RankedTensorType::get(newInputShape, inputETy), input,
+          rewriter.getDenseI64ArrayAttr(newInputShape));
+      weight = rewriter.create<tosa::ReshapeOp>(
+          loc, RankedTensorType::get(newWeightShape, weightTy.getElementType()),
+          weight, rewriter.getDenseI64ArrayAttr(newWeightShape));
+    } else {
+
+      if (4 == inputTy.getRank()) {
+        // For 2D convolutions, we need to check if the target convolution op
+        // wants a HWCF kernel layout.
+        bool wantHwcf =
+            isQuantized
+                ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+        if (wantHwcf) {
+          // Transpose the kernel to match dimension ordering of the linalg
+          // convolution operation.
+          // TODO(suderman): See if this can be efficiently folded - check
+          // whether the input is used anywhere else, if not fold the constant.
+          SmallVector<int64_t> weightPerm;
+          for (int i = 1; i < resultTy.getRank(); i++)
+            weightPerm.push_back(i);
+          weightPerm.push_back(0);
+
+          SmallVector<int64_t> newWeightShape;
+          for (auto dim : weightPerm)
+            newWeightShape.push_back(weightShape[dim]);
+          auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+          Value weightPermValue =
+              rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+          Type newWeightTy =
+              RankedTensorType::get(newWeightShape, weightTy.getElementType());
+          weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                      weightPermValue);
+        }
+      }
+
+      // For Conv3D transpose the kernel to match dimension ordering of the
+      // linalg convolution operation. Conv2D has a 1-1 mapping in linalg so
+      // better to map directly and then transpose later if desired.
+      if (5 == inputTy.getRank()) {
         // TODO(suderman): See if this can be efficiently folded - check whether
         // the input is used anywhere else, if not fold the constant.
         SmallVector<int64_t> weightPerm;
@@ -331,27 +397,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       }
     }
 
-    // For Conv3D transpose the kernel to match dimension ordering of the linalg
-    // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
-    // map directly and then transpose later if desired.
-    if (5 == inputTy.getRank()) {
-      // TODO(suderman): See if this can be efficiently folded - check whether
-      // the input is used anywhere else, if not fold the constant.
-      SmallVector<int64_t> weightPerm;
-      for (int i = 1; i < resultTy.getRank(); i++)
-        weightPerm.push_back(i);
-      weightPerm.push_back(0);
-
-      SmallVector<int64_t> newWeightShape;
-      for (auto dim : weightPerm)
-        newWeightShape.push_back(weightShape[dim]);
-      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
-      Value weightPermValue =
-          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
-      Type newWeightTy =
-          RankedTensorType::get(newWeightShape, weightTy.getElementType());
-      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                  weightPermValue);
+    if (isConv2DOp && group > 1) {
+      SmallVector<int64_t, 5> newResultShape{resultShape[0], resultShape[1],
+                                             resultShape[2], group,
+                                             resultShape[3] / group};
+      newResultTy = RankedTensorType::get(newResultShape, resultETy);
     }
 
     // Extract the attributes for convolution.
@@ -368,6 +418,13 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
+    if (isConv2DOp && group > 1) {
+      broadcastBias = rewriter.create<tosa::ReshapeOp>(
+          loc, RankedTensorType::get(newResultTy.getShape(), resultETy),
+          broadcastBias, rewriter.getDenseI64ArrayAttr(newResultTy.getShape()));
+    }
+
+    Value conv;
     if (isQuantized) {
       auto quantizationInfo = *op.getQuantizationInfo();
       auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
@@ -376,22 +433,25 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
 
-      Value conv =
+      conv =
           rewriter
               .create<LinalgConvQOp>(
-                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
+                  loc, newResultTy, ValueRange{input, weight, iZpVal, kZpVal},
                   ValueRange{broadcastBias}, strideAttr, dilationAttr)
               ->getResult(0);
-
-      rewriter.replaceOp(op, conv);
-      return success();
+    } else {
+      conv = rewriter
+                 .create<LinalgConvOp>(
+                     loc, newResultTy, ValueRange{input, weight},
+                     ValueRange{broadcastBias}, strideAttr, dilationAttr)
+                 ->getResult(0);
     }
 
-    Value conv = rewriter
-                     .create<LinalgConvOp>(
-                         loc, resultTy, ValueRange{input, weight},
-                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
-                     ->getResult(0);
+    if (isConv2DOp && group > 1) {
+      conv = rewriter.create<tosa::ReshapeOp>(
+          loc, RankedTensorType::get(resultShape, resultETy), conv,
+          rewriter.getDenseI64ArrayAttr(resultShape));
+    }
 
     rewriter.replaceOp(op, conv);
     return success();
@@ -1074,6 +1134,7 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
   }
   patterns->add<
       // clang-format off
+      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwgcGfhwcOp, linalg::Conv2DNhwgcGfhwcQOp>,
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d93db1b237f316..cedfa8a5afd110 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -383,6 +383,19 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   }
 }
 
+// Handles grouped convolution
+static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+                                     Type outputType, Value input, Value weight,...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Sep 11, 2024

@llvm/pr-subscribers-mlir

Author: None (stefankoncarevic)

Changes

This patch adds two new ops: linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp, and uses them to convert tosa group conv2d Ops.

  • Added linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp.
  • Updated the conversion process to use these new ops for tosa group conv2d operations.

Patch is 26.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/108192.diff

9 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+237)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td (+12)
  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+2-1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+103-42)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+13)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+3-2)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+54)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+16)
  • (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+1-1)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 8cb698096ef5b7..011c4858d6521b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -3410,6 +3410,243 @@ structured_op: !LinalgStructuredOpConfig
                 - !ScalarExpression
                   scalar_arg: K
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nhwgc_gfhwc
+  cpp_class_name: Conv2DNhwgcGfhwcOp
+  doc: |-
+    Performs 2-D grouped convolution.
+
+    Layout:
+      * Input: NHWGC.
+      * Kernel: GFHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s9, s11, s3, s7, s10)>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s5, s9, s11)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s2, s6)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: I
+            - !ScalarExpression
+              scalar_fn:
+                kind: type
+                fn_name: cast_signed
+                type_var: U
+                operands:
+                - !ScalarExpression
+                  scalar_arg: K
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: conv_2d_nhwgc_gfhwc_q
+  cpp_class_name: Conv2DNhwgcGfhwcQOp
+  doc: |-
+    Performs 2-D grouped convolution with zero point offsets.
+
+    Layout:
+      * Input: NHWGC.
+      * Kernel: GFHWC.
+
+    Numeric casting is performed on the operands to the inner multiply, promoting
+    them to the same data type as the accumulator/output. This includes the zero
+    point offsets common to quantized operations.
+  implements:
+  - LinalgConvolutionOpInterface
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1 * s2 + s3 * s4, s5 * s6 + s7 * s8, s9, s10)>
+  - !LinalgOperandDefConfig
+    name: K
+    kind: input_tensor
+    type_var: T2
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s9, s11, s3, s7, s10)>
+  - !LinalgOperandDefConfig
+    name: IZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: KZp
+    kind: scalar
+    type_var: I32
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: U
+    shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11] ->
+      (s0, s1, s5, s9, s11)>
+  - !LinalgOperandDefConfig
+    name: strides
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s2, s6)>
+    default_indices:
+    - 1
+    - 1
+  - !LinalgOperandDefConfig
+    name: dilations
+    kind: index_attr
+    index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9, s10, s11]
+      -> (s4, s8)>
+    default_indices:
+    - 1
+    - 1
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1 * s2 + d5 * s4, d2 * s6 + d6 * s8, d3, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d3, d4, d5, d6, d7)>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> ()>
+    - affine_map<(d0, d1, d2, d3, d4, d5, d6, d7)[s0, s1, s2, s3, s4, s5, s6, s7,
+      s8, s9, s10, s11] -> (d0, d1, d2, d3, d4)>
+  iterator_types:
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - parallel
+  - reduction
+  - reduction
+  - reduction
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: add
+        operands:
+        - !ScalarExpression
+          scalar_arg: O
+        - !ScalarExpression
+          scalar_fn:
+            kind: binary
+            fn_name: mul
+            operands:
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: I
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: IZp
+            - !ScalarExpression
+              scalar_fn:
+                kind: binary
+                fn_name: sub
+                operands:
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: K
+                - !ScalarExpression
+                  scalar_fn:
+                    kind: type
+                    fn_name: cast_signed
+                    type_var: U
+                    operands:
+                    - !ScalarExpression
+                      scalar_arg: KZp
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: conv_2d_ngchw_gfchw_q
   cpp_class_name: Conv2DNgchwGfchwQOp
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
index 1412c7a2615d20..d4697f0afbf466 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOpBase.td
@@ -133,6 +133,18 @@ def Tosa_ConvOpQuantInfoBuilder : OpBuilder<
                              pad, stride, dilation);
   }]>;
 
+// Handles grouped convolution
+def Tosa_ConvOpGroupQuantBuilder : OpBuilder<
+  (ins "::mlir::Type":$outputType, "::mlir::Value":$input,
+       "::mlir::Value":$weight, "::mlir::Value":$bias,
+       "::mlir::DenseI64ArrayAttr":$pad, "::mlir::DenseI64ArrayAttr":$stride,
+       "::mlir::DenseI64ArrayAttr":$dilation, "::mlir::IntegerAttr":$group),
+  [{
+    buildConvOpWithQuantInfo($_builder, $_state, outputType,
+                             input, weight, bias,
+                             pad, stride, dilation, group);
+  }]>;
+
 // Handles tosa.transpose_conv2d which has an outpad and output shape attribute.
 def Tosa_TransConvOpQuantInfoBuilder : OpBuilder<
   (ins "::mlir::Type":$outputType, "::mlir::Value":$input,
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index ab6daa39708d13..0b67019fd0c7bb 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -108,6 +108,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
     Tosa_IntArrayAttr4:$pad,
     Tosa_IntArrayAttr2:$stride,
     Tosa_IntArrayAttr2:$dilation,
+    OptionalAttr<I64Attr>:$group,
     OptionalAttr<Tosa_ConvOpQuantizationAttr>:$quantization_info,
     DefaultValuedOptionalAttr<BoolAttr, "false">:$local_bound
   );
@@ -116,7 +117,7 @@ def Tosa_Conv2DOp : Tosa_InferShapedTypeOp<"conv2d"> {
     Tosa_Tensor4D:$output
   );
 
-  let builders = [Tosa_ConvOpQuantInfoBuilder];
+  let builders = [Tosa_ConvOpQuantInfoBuilder, Tosa_ConvOpGroupQuantBuilder];
   let hasVerifier = 1;
 }
 
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index 77c3d2e8757910..898bed4a895864 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -236,6 +236,7 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
   LogicalResult
   matchAndRewrite(TosaConvOp op, typename TosaConvOp::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
+    bool isConv2DOp = isa<tosa::Conv2DOp>(op);
     Location loc = op->getLoc();
     Value input = op->getOperand(0);
     Value weight = op->getOperand(1);
@@ -253,6 +254,24 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
     bool isQuantized = op.getQuantizationInfo().has_value();
+    int64_t group = 1;
+
+    if (auto convop = dyn_cast<tosa::Conv2DOp>(&op)) {
+      if (convop->getGroup().has_value())
+        group = convop->getGroup().value();
+    }
+
+    if (group > 1 && isConv2DOp &&
+        !std::is_same<LinalgConvOp, linalg::Conv2DNhwgcGfhwcOp>::value &&
+        !std::is_same<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>::value)
+      return rewriter.notifyMatchFailure(
+          op, "tosa.conv ops should map to grouped convolution ops");
+
+    if (group == 1 && isConv2DOp &&
+        !std::is_same<LinalgConvOp, linalg::Conv2DNhwcFhwcOp>::value &&
+        !std::is_same<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>::value)
+      return rewriter.notifyMatchFailure(
+          op, "tosa.conv ops should map to non-grouped convolution ops");
 
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
@@ -274,8 +293,6 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
         strideTosaAttr.asArrayRef(), dilationTosaAttr.asArrayRef(),
         inputSizeDims, kernelSizeDims, rewriter);
 
-    auto weightShape = weightTy.getShape();
-
     // Apply padding as necessary.
     TypedAttr zeroAttr = rewriter.getZeroAttr(inputETy);
     if (isQuantized) {
@@ -302,15 +319,64 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     pad.resize(pad.size() + 2, 0);
     input = applyPad(loc, input, pad, zeroAttr, rewriter);
 
-    if (4 == inputTy.getRank()) {
-      // For 2D convolutions, we need to check if the target convolution op
-      // wants a HWCF kernel layout.
-      bool wantHwcf =
-          isQuantized ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
-                      : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
-      if (wantHwcf) {
-        // Transpose the kernel to match dimension ordering of the linalg
-        // convolution operation.
+    auto weightShape = weightTy.getShape();
+    SmallVector<int64_t> weightPerm;
+
+    auto resultShape = resultTy.getShape();
+    auto newResultTy = resultTy;
+
+    if (isConv2DOp && group > 1) {
+      // Map 4D-tensors to 5D tensors
+      auto inputShape = cast<ShapedType>(input.getType()).getShape();
+      SmallVector<int64_t, 5> newInputShape = {inputShape[0], inputShape[1],
+                                               inputShape[2], group,
+                                               inputShape[3] / group};
+
+      SmallVector<int64_t, 5> newWeightShape = {group, weightShape[0] / group,
+                                                weightShape[1], weightShape[2],
+                                                weightShape[3]};
+      input = rewriter.create<tosa::ReshapeOp>(
+          loc, RankedTensorType::get(newInputShape, inputETy), input,
+          rewriter.getDenseI64ArrayAttr(newInputShape));
+      weight = rewriter.create<tosa::ReshapeOp>(
+          loc, RankedTensorType::get(newWeightShape, weightTy.getElementType()),
+          weight, rewriter.getDenseI64ArrayAttr(newWeightShape));
+    } else {
+
+      if (4 == inputTy.getRank()) {
+        // For 2D convolutions, we need to check if the target convolution op
+        // wants a HWCF kernel layout.
+        bool wantHwcf =
+            isQuantized
+                ? std::is_same_v<LinalgConvQOp, linalg::Conv2DNhwcHwcfQOp>
+                : std::is_same_v<LinalgConvOp, linalg::Conv2DNhwcHwcfOp>;
+        if (wantHwcf) {
+          // Transpose the kernel to match dimension ordering of the linalg
+          // convolution operation.
+          // TODO(suderman): See if this can be efficiently folded - check
+          // whether the input is used anywhere else, if not fold the constant.
+          SmallVector<int64_t> weightPerm;
+          for (int i = 1; i < resultTy.getRank(); i++)
+            weightPerm.push_back(i);
+          weightPerm.push_back(0);
+
+          SmallVector<int64_t> newWeightShape;
+          for (auto dim : weightPerm)
+            newWeightShape.push_back(weightShape[dim]);
+          auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
+          Value weightPermValue =
+              rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
+          Type newWeightTy =
+              RankedTensorType::get(newWeightShape, weightTy.getElementType());
+          weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
+                                                      weightPermValue);
+        }
+      }
+
+      // For Conv3D transpose the kernel to match dimension ordering of the
+      // linalg convolution operation. Conv2D has a 1-1 mapping in linalg so
+      // better to map directly and then transpose later if desired.
+      if (5 == inputTy.getRank()) {
         // TODO(suderman): See if this can be efficiently folded - check whether
         // the input is used anywhere else, if not fold the constant.
         SmallVector<int64_t> weightPerm;
@@ -331,27 +397,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       }
     }
 
-    // For Conv3D transpose the kernel to match dimension ordering of the linalg
-    // convolution operation. Conv2D has a 1-1 mapping in linalg so better to
-    // map directly and then transpose later if desired.
-    if (5 == inputTy.getRank()) {
-      // TODO(suderman): See if this can be efficiently folded - check whether
-      // the input is used anywhere else, if not fold the constant.
-      SmallVector<int64_t> weightPerm;
-      for (int i = 1; i < resultTy.getRank(); i++)
-        weightPerm.push_back(i);
-      weightPerm.push_back(0);
-
-      SmallVector<int64_t> newWeightShape;
-      for (auto dim : weightPerm)
-        newWeightShape.push_back(weightShape[dim]);
-      auto weightPermAttr = rewriter.getI64TensorAttr(weightPerm);
-      Value weightPermValue =
-          rewriter.create<arith::ConstantOp>(loc, weightPermAttr);
-      Type newWeightTy =
-          RankedTensorType::get(newWeightShape, weightTy.getElementType());
-      weight = rewriter.create<tosa::TransposeOp>(loc, newWeightTy, weight,
-                                                  weightPermValue);
+    if (isConv2DOp && group > 1) {
+      SmallVector<int64_t, 5> newResultShape{resultShape[0], resultShape[1],
+                                             resultShape[2], group,
+                                             resultShape[3] / group};
+      newResultTy = RankedTensorType::get(newResultShape, resultETy);
     }
 
     // Extract the attributes for convolution.
@@ -368,6 +418,13 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     Value broadcastBias =
         linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
 
+    if (isConv2DOp && group > 1) {
+      broadcastBias = rewriter.create<tosa::ReshapeOp>(
+          loc, RankedTensorType::get(newResultTy.getShape(), resultETy),
+          broadcastBias, rewriter.getDenseI64ArrayAttr(newResultTy.getShape()));
+    }
+
+    Value conv;
     if (isQuantized) {
       auto quantizationInfo = *op.getQuantizationInfo();
       auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
@@ -376,22 +433,25 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
 
-      Value conv =
+      conv =
           rewriter
               .create<LinalgConvQOp>(
-                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
+                  loc, newResultTy, ValueRange{input, weight, iZpVal, kZpVal},
                   ValueRange{broadcastBias}, strideAttr, dilationAttr)
               ->getResult(0);
-
-      rewriter.replaceOp(op, conv);
-      return success();
+    } else {
+      conv = rewriter
+                 .create<LinalgConvOp>(
+                     loc, newResultTy, ValueRange{input, weight},
+                     ValueRange{broadcastBias}, strideAttr, dilationAttr)
+                 ->getResult(0);
     }
 
-    Value conv = rewriter
-                     .create<LinalgConvOp>(
-                         loc, resultTy, ValueRange{input, weight},
-                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
-                     ->getResult(0);
+    if (isConv2DOp && group > 1) {
+      conv = rewriter.create<tosa::ReshapeOp>(
+          loc, RankedTensorType::get(resultShape, resultETy), conv,
+          rewriter.getDenseI64ArrayAttr(resultShape));
+    }
 
     rewriter.replaceOp(op, conv);
     return success();
@@ -1074,6 +1134,7 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns(
   }
   patterns->add<
       // clang-format off
+      ConvConverter<tosa::Conv2DOp, linalg::Conv2DNhwgcGfhwcOp, linalg::Conv2DNhwgcGfhwcQOp>,
       ConvConverter<tosa::Conv3DOp, linalg::Conv3DNdhwcDhwcfOp, linalg::Conv3DNdhwcDhwcfQOp>,
       DepthwiseConvConverter,
       MatMulConverter,
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index d93db1b237f316..cedfa8a5afd110 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -383,6 +383,19 @@ static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
   }
 }
 
+// Handles grouped convolution
+static void buildConvOpWithQuantInfo(OpBuilder &builder, OperationState &result,
+                                     Type outputType, Value input, Value weight,...
[truncated]

@github-actions
Copy link

github-actions bot commented Sep 11, 2024

✅ With the latest revision this PR passed the Python code formatter.

@stefankoncarevic stefankoncarevic force-pushed the group-conv2d branch 2 times, most recently from 9cf14bf to 7ee8a07 Compare September 11, 2024 11:59
@GeorgeARM
Copy link
Contributor

GeorgeARM commented Sep 11, 2024

It is great to see this patch!
From what I gather so far the TOSA specification does not define anything around grouped convolutions in the Conv2d operator itself. There could be ongoing work in this space; @eric-k256 and @sjarus will be able to provide more information on this. If we are happy to deviate from the spec temporarily with an optional attribute I can proceed with reviewing this patch.

In the meantime may I suggest to split the linalg grouped conv structured operations to a separate patch and keep the TOSA dialect extension and legalisation in this one. This way we can merge the linalg part faster?

@sjarus
Copy link
Contributor

sjarus commented Sep 13, 2024

Hi there, sorry for the delay and thanks for this contribution! As @GeorgeARM mentioned, TOSA has a specification that the dialect implements. This is because the spec is also used to design hardware and backend software stacks that depend on the existence of a stable IR construct.

We'd be happy to help you/your organization contribute to the specification with this proposal and it'll then work its way into the dialect. However, this will take more time to add, than the time to make a dialect change that has no specification impact (e.g. update a verifier)

Until then, to unblock you is it possible to address this as two separate things - the change to the dialect (with its spec relationship) and the TosaToLinalg pass. Is it an option to refactor the latter to pattern match multiple tosa.conv2d instances that constitute a grouped convolution ?

Copy link
Contributor

@eric-k256 eric-k256 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on what @sjarus said. This would be a notable deviation from the TOSA spec.

Marking this as requesting changes to note the request to split the linalg change from the TOSA change.

Tosa_IntArrayAttr4:$pad,
Tosa_IntArrayAttr2:$stride,
Tosa_IntArrayAttr2:$dilation,
OptionalAttr<I64Attr>:$group,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As @GeorgeARM pointed out, this doesn't match what is in the TOSA specification. The dialect should be a representation of the specification.
We recently took a look at grouped convolution, and most of the use cases we saw could be handled by creating multiple TOSA Conv2d ops.

@stefankoncarevic stefankoncarevic changed the title [mlir][tosa] Convert group tosa::Conv2DOp to linalg conv [mlir][linalg] Add Grouped Convolution Ops: conv_2d_nhwgc_gfhwc and conv_2d_nhwgc_gfhwc_q Sep 19, 2024
@stefankoncarevic
Copy link
Contributor Author

Hi everyone,
I've made the requested changes by separating the Linalg grouped convolution operations into this PR. This will allow us to proceed with upstreaming just this part.
Regarding the TOSA dialect changes, after discussions within our team, we've decided not to upstream that portion as it diverges too much from the current specification. We'll focus on the Linalg operations for now.
Thank you for your feedback and support!

@eric-k256 eric-k256 self-requested a review September 19, 2024 22:03
@eric-k256 eric-k256 dismissed their stale review September 19, 2024 22:05

This isn't dependent on TOSA now, it looks good to me, but someone closer to linalg would be better to approve.

@eric-k256
Copy link
Contributor

Thanks for the update. We're always looking for feedback and future changes we should be making to TOSA, feel free to reach out for a discussion on the topic.
I've removed my change request, one of the other reviewers would be better placed to approve the change.

@krzysz00
Copy link
Contributor

Ping for people who know linalg

@bjacob
Copy link
Contributor

bjacob commented Oct 25, 2024

FYI @rsuderman, @MaheshRavishankar, @hanhanW.


// -----

// CHECK-LABEL: func @conv_2d_nhwgc_gfhwc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test case for static shapes, like below test case? It is not easy to see if the indexing maps are defined properly or not. Having a static shape test case could capture the failure if there is something wrong.

@stefankoncarevic
Copy link
Contributor Author

Hi @krzysz00, @manupak can you merge PR, I don't have rights.

@rengolin
Copy link
Member

Just a comment, this has the same problems as #113953.

I'm not sure what the difference is here and why people are agreeing with this one and not the other one.

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we talk about the more general pattern here instead of adding dozens of variations on the same theme?

@stellaraccident
Copy link
Contributor

Shouldn't we talk about the more general pattern here instead of adding dozens of variations on the same theme?

Yes, I'm having trouble with the firehose of individual conv op additions, and responding to individual threads isn't working. Let's boost this to the forum and list the component patches to discuss together.

FTR - these seem to be very basic variations of conv needed to support common use cases and I am sympathetic to the need to have a practical solution for them in short order. I'm also cringing at the addition of technical debt this adds because the project overall doesn't have a plan for how to manage the expansion. By raising this to the forum, hopefully we can quickly converge on a practical path to land these now and start the conversation on the ultimate plan. We just did that for matmul and it wasn't too bad -- so I'm hopeful about these too.

@rengolin
Copy link
Member

By raising this to the forum, hopefully we can quickly converge on a practical path to land these now and start the conversation on the ultimate plan. We just did that for matmul and it wasn't too bad -- so I'm hopeful about these too.

https://discourse.llvm.org/t/rfc-op-explosion-in-linalg/82863

Copy link
Member

@rengolin rengolin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the block so that we have the discussion in the forums. But let's not add more stuff, please.

`conv_2d_nhwgc_gfhwc` and `conv_2d_nhwgc_gfhwc_q`.
These operations perform 2-D grouped convolutions with and
without zero point offsets, respectively. The input layout
is NHWGC, and the kernel layout is GFHWC. These additions enhance
support for grouped convolution operations in MLIR.
@stefankoncarevic
Copy link
Contributor Author

Hi @krzysz00, @manupak or someone else can merge this PR, I don't have rights.

@MaheshRavishankar MaheshRavishankar merged commit 39358f8 into llvm:main Nov 8, 2024
8 checks passed
Groverkss pushed a commit to iree-org/llvm-project that referenced this pull request Nov 15, 2024
…onv_2d_nhwgc_gfhwc_q (llvm#108192)

This patch adds two new ops: linalg::Conv2DNhwgcGfhwcOp and
linalg::Conv2DNhwgcGfhwcQOp, and uses them to convert tosa group conv2d
Ops.
- Added linalg::Conv2DNhwgcGfhwcOp and linalg::Conv2DNhwgcGfhwcQOp.
- Updated the conversion process to use these new ops for tosa group
conv2d operations.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.