Skip to content

Commit 6ec102a

Browse files
committed
[tosa] : Don't fold mul with zero lhs/rhs if resulting type is dynamic.
1 parent 4f6ae2a commit 6ec102a

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1120,13 +1120,14 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
11201120
}
11211121

11221122
if (rhsTy == resultTy) {
1123-
if (isSplatZero(resultETy, lhsAttr))
1123+
if (isSplatZero(resultETy, lhsAttr) && resultTy.hasStaticShape())
1124+
// constant values can only be resized if resulting type is static
11241125
return lhsAttr.resizeSplat(resultTy);
11251126
if (isSplatOne(resultETy, lhsAttr, shift))
11261127
return rhs;
11271128
}
11281129
if (lhsTy == resultTy) {
1129-
if (isSplatZero(resultETy, rhsAttr))
1130+
if (isSplatZero(resultETy, rhsAttr) && resultTy.hasStaticShape())
11301131
return rhsAttr.resizeSplat(resultTy);
11311132
if (isSplatOne(resultETy, rhsAttr, shift))
11321133
return lhs;

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,33 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
565565

566566
// -----
567567

568+
// CHECK-LABEL: @mul_zero_dynamic_nofold
569+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
570+
// CHECK: %[[ZERO:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
571+
// CHECK: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
572+
// CHECK: %[[MUL:.*]] = tosa.mul %[[ARG0]], %[[ZERO]], %[[SHIFT]] : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
573+
// CHECK: return %[[MUL]]
574+
func.func @mul_zero_dynamic_nofold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
575+
%0 = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
576+
%1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
577+
%2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
578+
return %2 : tensor<?x17xf32>
579+
}
580+
581+
// -----
582+
583+
// CHECK-LABEL: @mul_one_dynamic_fold
584+
// CHECK-SAME: %[[ARG0:.*]]: tensor<?x17xf32>) -> tensor<?x17xf32> {
585+
// CHECK: return %[[ARG0]]
586+
func.func @mul_one_dynamic_fold(%arg0: tensor<?x17xf32>) -> tensor<?x17xf32> {
587+
%0 = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1xf32>}> : () -> tensor<1x1xf32>
588+
%1 = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
589+
%2 = tosa.mul %arg0, %0, %1 : (tensor<?x17xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<?x17xf32>
590+
return %2 : tensor<?x17xf32>
591+
}
592+
593+
// -----
594+
568595
// CHECK-LABEL: @select_same_value
569596
func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
570597
%0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

0 commit comments

Comments
 (0)