diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index a85ff10aa0d73..a4217051f2dfa 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1484,7 +1484,24 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { return {}; } +static bool +mayRequireBroadcast(ValueTypeRange operandTypes) { + const auto isDynamic = [](Type ty) { + const auto shapedTy = llvm::dyn_cast(ty); + return !shapedTy || !shapedTy.hasStaticShape(); + }; + + return llvm::any_of(operandTypes, isDynamic) || + failed(verifyCompatibleShapes(operandTypes)); +} + OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { + // Select allows operand shapes to be broadcast to the output shape. For + // now, don't support folding when we cannot prove no broadcasting is + // involved. + if (mayRequireBroadcast(getOperandTypes())) + return {}; + if (getOnTrue() == getOnFalse()) return getOnTrue(); diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 5a40f3fa8572c..fc5ea7710e2c4 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -643,6 +643,48 @@ func.func @select_not_pred(%arg0: tensor<2x3xi1>, %arg1: tensor<2x3xi32>, %arg2: // ----- +// CHECK-LABEL: @select_broadcast_same_value_no_fold +func.func @select_broadcast_same_value_no_fold(%arg0: tensor<2x2xi1>, %arg1: tensor<1x1xf32>) -> tensor<2x2xf32> { + // CHECK: tosa.select %arg0, %arg1, %arg1 + %0 = tosa.select %arg0, %arg1, %arg1 : (tensor<2x2xi1>, tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: @select_broadcast_true_value_no_fold +func.func @select_broadcast_true_value_no_fold(%arg0: tensor<1x1xf32>, %arg1: tensor<2x2xf32>) -> tensor { + // CHECK: %[[CONST:.*]] = "tosa.const" + %0 = "tosa.const"() {values = dense<1> : tensor<2x2xi1>} : () -> tensor<2x2xi1> + // CHECK: tosa.select %[[CONST]], %arg0, %arg1 + %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<1x1xf32>, tensor<2x2xf32>) -> tensor + return %1 : tensor +} + +// ----- + +// CHECK-LABEL: @select_broadcast_false_value_no_fold +func.func @select_broadcast_false_value_no_fold(%arg0: tensor<2x2xf32>, %arg1: tensor<1x1xf32>) -> tensor<2x2xf32> { + // CHECK: %[[CONST:.*]] = "tosa.const" + %0 = "tosa.const"() {values = dense<0> : tensor<2x2xi1>} : () -> tensor<2x2xi1> + // CHECK: tosa.select %[[CONST]], %arg0, %arg1 + %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<1x1xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: @select_broadcast_false_value_dynamic_operand_no_fold +func.func @select_broadcast_false_value_dynamic_operand_no_fold(%arg0: tensor<2x?xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: %[[CONST:.*]] = "tosa.const" + %0 = "tosa.const"() {values = dense<0> : tensor<2x2xi1>} : () -> tensor<2x2xi1> + // CHECK: tosa.select %[[CONST]], %arg0, %arg1 + %1 = tosa.select %0, %arg0, %arg1 : (tensor<2x2xi1>, tensor<2x?xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %1 : tensor<2x2xf32> +} + +// ----- + // CHECK-LABEL: @reduce_all_fold func.func @reduce_all_fold(%arg0: tensor) -> tensor { // CHECK: return %arg0