Skip to content

Commit 160428e

Browse files
committed
fix: no fold for mul for zero const if output shape is dynamic
1 parent f713706 commit 160428e

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -718,7 +718,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
718718
auto lhsTy = llvm::dyn_cast<RankedTensorType>(lhs.getType());
719719
auto rhsTy = llvm::dyn_cast<RankedTensorType>(rhs.getType());
720720
auto resultTy = llvm::dyn_cast<RankedTensorType>(getType());
721-
if (!lhsTy || !rhsTy || !resultTy)
721+
if (!lhsTy || !rhsTy || !resultTy || !resultTy.hasStaticShape())
722722
return {};
723723

724724
auto resultETy = resultTy.getElementType();

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,16 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
301301
return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
302302
}
303303

304+
// CHECK-LABEL: @mul_zero_broadcast_dynamic_result
305+
func.func @mul_zero_broadcast_dynamic_result(%arg0: tensor<?x3xf32>) -> (tensor<?x3xf32>, tensor<?x3xf32>) {
306+
// CHECK: tosa.mul
307+
// CHECK: tosa.mul
308+
%zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
309+
%1 = tosa.mul %arg0, %zeros {shift = 0 : i8} : (tensor<?x3xf32>, tensor<1x1xf32>) -> tensor<?x3xf32>
310+
%2 = tosa.mul %zeros, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<?x3xf32>) -> tensor<?x3xf32>
311+
return %1, %2 : tensor<?x3xf32>, tensor<?x3xf32>
312+
}
313+
304314
// CHECK-LABEL: @select_same_value
305315
func.func @select_same_value(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> {
306316
%0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>

0 commit comments

Comments
 (0)