
Conversation

@Hsiangkai
Contributor

In TOSA 1.0, the local_bound attribute controls the compliance error bound. When local_bound is true, direct dot-product calculation precision is required. However, this information is lost when converting TOSA to Linalg. This patch propagates the local_bound attribute from tosa.conv2d, tosa.conv3d, and tosa.depthwise_conv2d to the corresponding named Linalg operations.

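For illustration, a minimal before/after sketch drawn from the new test cases in this patch (the %init out-operand name is illustrative; in the actual lowering it is the broadcast bias tensor):

  // TOSA input: a conv2d requesting a local error bound.
  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = true, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>

  // After TosaToLinalgNamed: the attribute is carried on the named linalg op.
  %1 = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, local_bound = true, strides = dense<1> : tensor<2xi64>} ins(%input, %weights : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%init : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>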
@llvmbot
Member

llvmbot commented Jul 3, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Hsiangkai Wang (Hsiangkai)

Changes

In TOSA 1.0, the local_bound attribute controls the compliance error bound. When local_bound is true, direct dot-product calculation precision is required. However, this information is lost when converting TOSA to Linalg. This patch propagates the local_bound attribute from tosa.conv2d, tosa.conv3d, and tosa.depthwise_conv2d to the corresponding named Linalg operations.


Full diff: https://github.com/llvm/llvm-project/pull/146883.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (+45-31)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (+66)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index c460a8bb2f4b2..b6ddce3f785fe 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -398,6 +398,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     Value broadcastBias =
         linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
 
+    bool localBound = false;
+    if (auto localBoundAttr =
+            op->template getAttrOfType<BoolAttr>("local_bound"))
+      localBound = localBoundAttr.getValue();
+
     if (hasZp) {
       auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
       auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
@@ -405,29 +410,31 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
 
-      Value conv =
-          rewriter
-              .create<LinalgConvQOp>(
-                  loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
-                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
-              ->getResult(0);
+      auto conv = rewriter.create<LinalgConvQOp>(
+          loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
+          ValueRange{broadcastBias}, strideAttr, dilationAttr);
+
+      if (localBound)
+        conv->setAttr("local_bound", rewriter.getBoolAttr(true));
 
-      rewriter.replaceOp(op, conv);
+      rewriter.replaceOp(op, conv->getResult(0));
       return success();
     }
 
-    Value conv = rewriter
-                     .create<LinalgConvOp>(
-                         loc, accTy, ValueRange{input, weight},
-                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
-                     ->getResult(0);
+    auto conv = rewriter.create<LinalgConvOp>(
+        loc, accTy, ValueRange{input, weight}, ValueRange{broadcastBias},
+        strideAttr, dilationAttr);
+    Value convVal = conv.getResult(0);
+
+    if (localBound)
+      conv->setAttr("local_bound", rewriter.getBoolAttr(true));
 
     // We may need to truncate back to the result type if the accumulator was
     // wider than the result.
     if (resultTy != accTy)
-      conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);
+      convVal = rewriter.create<tosa::CastOp>(loc, resultTy, convVal);
 
-    rewriter.replaceOp(op, conv);
+    rewriter.replaceOp(op, convVal);
     return success();
   }
 };
@@ -551,26 +558,32 @@ class DepthwiseConvConverter
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
     indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
 
+    bool localBound = false;
+    if (auto localBoundAttr = op->getAttrOfType<BoolAttr>("local_bound"))
+      localBound = localBoundAttr.getValue();
+
     if (hasNullZps) {
-      Value conv = rewriter
-                       .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
-                           loc, linalgConvTy, ValueRange{input, weight},
-                           ValueRange{zeroTensor}, strideAttr, dilationAttr)
-                       .getResult(0);
+      auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
+          loc, linalgConvTy, ValueRange{input, weight}, ValueRange{zeroTensor},
+          strideAttr, dilationAttr);
+      Value convVal = conv.getResult(0);
+
+      if (localBound)
+        conv->setAttr("local_bound", rewriter.getBoolAttr(true));
 
       // We may need to truncate back to the result type if the accumulator was
       // wider than the result.
       if (accETy != resultETy)
-        conv = rewriter.create<tosa::CastOp>(
+        convVal = rewriter.create<tosa::CastOp>(
             loc,
-            RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
-                                  resultETy),
-            conv);
+            RankedTensorType::get(
+                cast<ShapedType>(convVal.getType()).getShape(), resultETy),
+            convVal);
 
       SmallVector<ReassociationExprs, 4> reassociationMap;
       createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
       Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
-          loc, resultTy, conv, reassociationMap);
+          loc, resultTy, convVal, reassociationMap);
 
       Value result =
           rewriter
@@ -596,16 +609,17 @@ class DepthwiseConvConverter
       IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
-      Value conv =
-          rewriter
-              .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
-                  loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
-                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
-              .getResult(0);
+      auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
+          loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
+          ValueRange{zeroTensor}, strideAttr, dilationAttr);
+
+      if (localBound)
+        conv->setAttr("local_bound", rewriter.getBoolAttr(true));
+
       SmallVector<ReassociationExprs, 4> reassociationMap;
       createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
       Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
-          loc, resultTy, conv, reassociationMap);
+          loc, resultTy, conv.getResult(0), reassociationMap);
       Value result = linalgIntBroadcastExtSIAdd(
           rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
       rewriter.replaceOp(op, result);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index a737a8a05bae6..0ee0b573ddf11 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1045,3 +1045,69 @@ func.func @max_pool2d_nan_ignore(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32
   %0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
   return %0: tensor<1x4x32x62xf32>
 }
+
+// -----
+
+// CHECK-LABEL: @conv2d_local_bound_true
+func.func @conv2d_local_bound_true(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () {
+  // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, local_bound = true, strides = dense<1> : tensor<2xi64>}
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = true, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_local_bound_false
+func.func @conv2d_local_bound_false(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () {
+  // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = false, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_local_bound_true
+func.func @conv3d_local_bound_true(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<1xf32>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK: linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>, local_bound = true, strides = dense<1> : tensor<3xi64>}
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = true, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_local_bound_false
+func.func @conv3d_local_bound_false(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<1xf32>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK: linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = false, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_local_bound_true
+func.func @depthwise_conv_local_bound_true(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<1xf32>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK: linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, local_bound = true, strides = dense<1> : tensor<2xi64>}
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, local_bound = true, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_local_bound_false
+func.func @depthwise_conv_local_bound_false(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<1xf32>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  // CHECK: linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, local_bound = false, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
+  return
+}

@Hsiangkai Hsiangkai requested a review from sjarus July 3, 2025 13:24
@lhutton1
Contributor

lhutton1 commented Jul 3, 2025

Just wanted to mention that RFFT2D and FFT2D also have the local_bound option; would they also need a similar update? (transpose_conv2d also has the option, but I didn't notice a linalg conversion for it.)

@Hsiangkai
Contributor Author

Just wanted to mention that RFFT2D and FFT2D also have the local_bound option; would they also need a similar update? (transpose_conv2d also has the option, but I didn't notice a linalg conversion for it.)

RFFT2D and FFT2D are lowered to linalg.generic, and I'm not sure how local_bound would be used on linalg.generic; that's why I didn't apply the modification to those two operations. TransposeConv2D is decomposed into a sequence of TOSA operations including Conv2D, so we can pass the attribute to that Conv2D in the decomposition (a rough sketch below). I will add it. Thanks for the reminder.
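For reference, a rough MLIR sketch of the idea, assuming the decomposition emits a regular tosa.conv2d over spatially reversed weights (the shapes, padding arithmetic, and exact op sequence are illustrative assumptions, not a verified trace of the actual TosaDecomposeTransposeConv pass):

  // A stride-1 transpose convolution requesting a local error bound.
  %0 = tosa.transpose_conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = true, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x8x8x3xf32>, tensor<5x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x10x5xf32>

  // Decomposition: reverse the kernel spatially, then run an ordinary conv2d
  // with pad = kernel - 1. The local_bound attribute would be forwarded here.
  %rev_h = tosa.reverse %weights {axis = 1 : i32} : (tensor<5x3x3x3xf32>) -> tensor<5x3x3x3xf32>
  %rev_hw = tosa.reverse %rev_h {axis = 2 : i32} : (tensor<5x3x3x3xf32>) -> tensor<5x3x3x3xf32>
  %1 = tosa.conv2d %input, %rev_hw, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = true, pad = array<i64: 2, 2, 2, 2>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<1x8x8x3xf32>, tensor<5x3x3x3xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x10x5xf32>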

@GeorgeARM
Contributor

Are we sure that we want to pass this down as a custom named attribute on linalg named operations?
I am quite skeptical about polluting linalg operations. @sjarus any thoughts?

@sjarus
Contributor

sjarus commented Jul 7, 2025

This PR seems premature: it attaches the semantics of a producing dialect's construct to the Linalg convolutions, but doesn't express those semantics to consumers of Linalg. I would recommend removing the reviewers to offer a level playing field, marking this WIP, and starting an RFC on Discourse.

Contributor

@sjarus sjarus left a comment

Suggest marking this as WIP and doing an RFC.

@Hsiangkai
Copy link
Contributor Author

I agree with @sjarus that the semantics are not clear to consumers of linalg in MLIR core. Abandoning this PR.

@Hsiangkai Hsiangkai closed this Jul 7, 2025
