diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index bebfaa8c1ea82..65efc88e9c403 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -698,6 +698,14 @@ struct RemoveEmptyShapeOperandsPattern : public OpRewritePattern { auto newOperands = llvm::filter_to_vector<8>(op->getOperands(), isPotentiallyNonEmptyShape); + // Replace the op with empty shape constant if all operants are reduced to + // be empty. + if (newOperands.empty()) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes().front(), rewriter.getIndexTensorAttr({})); + return success(); + } + // Reduce op to equivalent without empty shape operands. if (newOperands.size() < op.getNumOperands()) { rewriter.replaceOpWithNewOp(op, op->getResultTypes(), newOperands, diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index 5b98a7790debf..cf439c9c1b854 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1,4 +1,5 @@ // RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence" %s | FileCheck %s +// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize="test-convergence top-down=0" %s | FileCheck %s // CHECK-LABEL: func @f func.func @f(%arg0: tensor<2x3x4xf32>) -> tensor<3xindex> { @@ -134,6 +135,21 @@ func.func @all_but_one_empty(%arg0 : !shape.shape) -> !shape.shape { // ----- +// All operands are known empty shapes. +// CHECK-LABEL: @all_empty +// CHECK-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor) +func.func @all_empty(%arg0: tensor, %arg1: tensor) -> tensor<0xindex> { + // CHECK: %[[CST:.*]] = shape.const_shape [] : tensor<0xindex> + // CHECK: return %[[CST]] : tensor<0xindex> + %1 = shape.shape_of %arg0 : tensor -> tensor<0xindex> + %2 = shape.shape_of %arg1 : tensor -> tensor<0xindex> + %3 = shape.const_shape [] : tensor<0xindex> + %4 = shape.broadcast %1, %2, %3 : tensor<0xindex>, tensor<0xindex>, tensor<0xindex> -> tensor<0xindex> + return %4 : tensor<0xindex> +} + +// ----- + // Partial folding. // CHECK-LABEL: @partial_folding // CHECK-SAME: (%[[ARG:.*]]: !shape.shape)