diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 5f445231b80fd..700258a1b6254 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -393,8 +393,12 @@ void arith::AddUIExtendedOp::getCanonicalizationPatterns( OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) { // subi(x,x) -> 0 - if (getOperand(0) == getOperand(1)) - return Builder(getContext()).getZeroAttr(getType()); + if (getOperand(0) == getOperand(1)) { + auto shapedType = dyn_cast(getType()); + // We can't generate a constant with a dynamic shaped tensor. + if (!shapedType || shapedType.hasStaticShape()) + return Builder(getContext()).getZeroAttr(getType()); + } // subi(x,0) -> x if (matchPattern(adaptor.getRhs(), m_Zero())) return getLhs(); diff --git a/mlir/test/Dialect/Arith/canonicalize.mlir b/mlir/test/Dialect/Arith/canonicalize.mlir index 69df83d42f543..f1e36c2707a8f 100644 --- a/mlir/test/Dialect/Arith/canonicalize.mlir +++ b/mlir/test/Dialect/Arith/canonicalize.mlir @@ -869,6 +869,27 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index { return %add2 : index } + +// CHECK-LABEL: @foldSubXX_tensor +// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32> +// CHECK: %[[sub:.+]] = arith.subi +// CHECK: return %[[c0]], %[[sub]] +func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor) -> (tensor<10xi32>, tensor) { + %static_sub = arith.subi %static, %static : tensor<10xi32> + %dyn_sub = arith.subi %dyn, %dyn : tensor + return %static_sub, %dyn_sub : tensor<10xi32>, tensor +} + +// CHECK-LABEL: @foldSubXX_vector +// CHECK-DAG: %[[c0:.+]] = arith.constant dense<0> : vector<8xi32> +// CHECK-DAG: %[[c0_scalable:.+]] = arith.constant dense<0> : vector<[4]xi32> +// CHECK: return %[[c0]], %[[c0_scalable]] +func.func @foldSubXX_vector(%static : vector<8xi32>, %dyn : vector<[4]xi32>) -> (vector<8xi32>, vector<[4]xi32>) { + %static_sub = arith.subi %static, %static : vector<8xi32> + %dyn_sub = arith.subi %dyn, %dyn : vector<[4]xi32> + return %static_sub, %dyn_sub : vector<8xi32>, vector<[4]xi32> +} + // CHECK-LABEL: @tripleAddSub0 // CHECK: %[[cres:.+]] = arith.constant 59 : index // CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index