[mlir][tosa][linalg] Pass local_bound attribute to named linalg ops #146883
Conversation
In TOSA 1.0, the local_bound attribute controls the compliance error bound. When local_bound is true, direct dot-product calculation precision is required. However, this information is lost when converting TOSA to linalg. This patch passes the local_bound attribute from tosa.conv2d, tosa.conv3d, and tosa.depthwise_conv2d through to the corresponding named linalg operations.
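For illustration, here is a before/after pair mirroring the conv2d test added in this patch (shapes and attribute values are copied verbatim from that test):

// Before: the TOSA op carries local_bound = true.
%0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp
       {acc_type = f32, local_bound = true, pad = array<i64: 0, 0, 0, 0>,
        stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>}
     : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>,
        tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>

// After the conversion, the attribute survives on the named linalg op
// (operands elided):
//   linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>,
//                             local_bound = true,
//                             strides = dense<1> : tensor<2xi64>} ...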
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Hsiangkai Wang (Hsiangkai)

Changes: as described above.

Full diff: https://github.com/llvm/llvm-project/pull/146883.diff (2 files affected)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index c460a8bb2f4b2..b6ddce3f785fe 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -398,6 +398,11 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
Value broadcastBias =
linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
+ bool localBound = false;
+ if (auto localBoundAttr =
+ op->template getAttrOfType<BoolAttr>("local_bound"))
+ localBound = localBoundAttr.getValue();
+
if (hasZp) {
auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
auto kZp = rewriter.getI32IntegerAttr(weightZpVal);
@@ -405,29 +410,31 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
- Value conv =
- rewriter
- .create<LinalgConvQOp>(
- loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
- ValueRange{broadcastBias}, strideAttr, dilationAttr)
- ->getResult(0);
+ auto conv = rewriter.create<LinalgConvQOp>(
+ loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
+ ValueRange{broadcastBias}, strideAttr, dilationAttr);
+
+ if (localBound)
+ conv->setAttr("local_bound", rewriter.getBoolAttr(true));
- rewriter.replaceOp(op, conv);
+ rewriter.replaceOp(op, conv->getResult(0));
return success();
}
- Value conv = rewriter
- .create<LinalgConvOp>(
- loc, accTy, ValueRange{input, weight},
- ValueRange{broadcastBias}, strideAttr, dilationAttr)
- ->getResult(0);
+ auto conv = rewriter.create<LinalgConvOp>(
+ loc, accTy, ValueRange{input, weight}, ValueRange{broadcastBias},
+ strideAttr, dilationAttr);
+ Value convVal = conv.getResult(0);
+
+ if (localBound)
+ conv->setAttr("local_bound", rewriter.getBoolAttr(true));
// We may need to truncate back to the result type if the accumulator was
// wider than the result.
if (resultTy != accTy)
- conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);
+ convVal = rewriter.create<tosa::CastOp>(loc, resultTy, convVal);
- rewriter.replaceOp(op, conv);
+ rewriter.replaceOp(op, convVal);
return success();
}
};
@@ -551,26 +558,32 @@ class DepthwiseConvConverter
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultRank));
+ bool localBound = false;
+ if (auto localBoundAttr = op->getAttrOfType<BoolAttr>("local_bound"))
+ localBound = localBoundAttr.getValue();
+
if (hasNullZps) {
- Value conv = rewriter
- .create<linalg::DepthwiseConv2DNhwcHwcmOp>(
- loc, linalgConvTy, ValueRange{input, weight},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
- .getResult(0);
+ auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmOp>(
+ loc, linalgConvTy, ValueRange{input, weight}, ValueRange{zeroTensor},
+ strideAttr, dilationAttr);
+ Value convVal = conv.getResult(0);
+
+ if (localBound)
+ conv->setAttr("local_bound", rewriter.getBoolAttr(true));
// We may need to truncate back to the result type if the accumulator was
// wider than the result.
if (accETy != resultETy)
- conv = rewriter.create<tosa::CastOp>(
+ convVal = rewriter.create<tosa::CastOp>(
loc,
- RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
- resultETy),
- conv);
+ RankedTensorType::get(
+ cast<ShapedType>(convVal.getType()).getShape(), resultETy),
+ convVal);
SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
- loc, resultTy, conv, reassociationMap);
+ loc, resultTy, convVal, reassociationMap);
Value result =
rewriter
@@ -596,16 +609,17 @@ class DepthwiseConvConverter
IntegerAttr wZp = rewriter.getI32IntegerAttr(weightZpVal);
auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
auto kZpVal = rewriter.create<arith::ConstantOp>(loc, wZp);
- Value conv =
- rewriter
- .create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
- loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
- ValueRange{zeroTensor}, strideAttr, dilationAttr)
- .getResult(0);
+ auto conv = rewriter.create<linalg::DepthwiseConv2DNhwcHwcmQOp>(
+ loc, linalgConvTy, ValueRange{input, weight, iZpVal, kZpVal},
+ ValueRange{zeroTensor}, strideAttr, dilationAttr);
+
+ if (localBound)
+ conv->setAttr("local_bound", rewriter.getBoolAttr(true));
+
SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
- loc, resultTy, conv, reassociationMap);
+ loc, resultTy, conv.getResult(0), reassociationMap);
Value result = linalgIntBroadcastExtSIAdd(
rewriter, loc, bias, convReshape, biasEmptyTensor, indexingMaps);
rewriter.replaceOp(op, result);
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index a737a8a05bae6..0ee0b573ddf11 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -1045,3 +1045,69 @@ func.func @max_pool2d_nan_ignore(%arg0: tensor<1x6x34x62xf32>) -> (tensor<1x4x32
%0 = tosa.max_pool2d %arg0 {pad = array<i64: 0, 0, 0, 0>, kernel = array<i64: 3, 3>, stride = array<i64: 1, 1>, nan_mode = "IGNORE"} : (tensor<1x6x34x62xf32>) -> tensor<1x4x32x62xf32>
return %0: tensor<1x4x32x62xf32>
}
+
+// -----
+
+// CHECK-LABEL: @conv2d_local_bound_true
+func.func @conv2d_local_bound_true(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () {
+ // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, local_bound = true, strides = dense<1> : tensor<2xi64>}
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = true, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @conv2d_local_bound_false
+func.func @conv2d_local_bound_false(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<1xf32>) -> () {
+ // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = false, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x45x40x28xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_local_bound_true
+func.func @conv3d_local_bound_true(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<1xf32>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>, local_bound = true, strides = dense<1> : tensor<3xi64>}
+ %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = true, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @conv3d_local_bound_false
+func.func @conv3d_local_bound_false(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<1xf32>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: linalg.conv_3d_ndhwc_dhwcf {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
+ %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, local_bound = false, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x47x45x43x28xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_local_bound_true
+func.func @depthwise_conv_local_bound_true(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<1xf32>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, local_bound = true, strides = dense<1> : tensor<2xi64>}
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, local_bound = true, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
+ return
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_local_bound_false
+func.func @depthwise_conv_local_bound_false(%arg0 : tensor<1x7x5x3xf32>, %arg1 : tensor<3x1x3x11xf32>, %arg2 : tensor<1xf32>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ // CHECK: linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, local_bound = false, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf32>, tensor<3x1x3x11xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x33xf32>
+ return
+}
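For reference, these tests drive the conversion through mlir-opt along these lines (a sketch; the exact RUN line sits in the test file's header, which this diff does not show, so the flags may differ):

mlir-opt --split-input-file \
  --pass-pipeline="builtin.module(func.func(tosa-to-linalg-named))" \
  mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir \
  | FileCheck mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir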
Just wanted to mention that RFFT2D and FFT2D also have the local_bound attribute.
RFFT2D and FFT2D are lowered to linalg.generic. I'm not sure how to propagate the attribute in that case.
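For illustration, a discardable attribute can legally be attached to a linalg.generic through its attrs clause; a minimal hypothetical sketch (not IR produced by any existing lowering, and nothing in linalg gives the attribute semantics):

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @generic_with_local_bound(%in: tensor<8x8xf32>,
                                    %out: tensor<8x8xf32>) -> tensor<8x8xf32> {
  // local_bound round-trips as a discardable attribute but is ignored by
  // all linalg transformations.
  %res = linalg.generic {indexing_maps = [#map, #map],
                         iterator_types = ["parallel", "parallel"]}
           ins(%in : tensor<8x8xf32>) outs(%out : tensor<8x8xf32>)
           attrs = {local_bound = true} {
  ^bb0(%a: f32, %b: f32):
    linalg.yield %a : f32
  } -> tensor<8x8xf32>
  return %res : tensor<8x8xf32>
}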
Are we sure that we want to pass it down as a custom named attribute to linalg named operations?
This PR seems premature: it alters Linalg convolutions with semantics from a construct of a producing dialect, but doesn't express those semantics to consumers of Linalg here. I would recommend removing the reviewers to offer a level playing field, marking this WIP, and starting an RFC on Discourse.
sjarus left a comment:
Suggest marking this WIP and doing an RFC.
Agreed with @sjarus that it's not clear for consumers of linalg in MLIR core. I'll abandon this PR.