diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index d9840e3923c4f..d31d3ef4bd7ef 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -1749,6 +1749,10 @@ void ReduceOp::print(OpAsmPrinter &p) { LogicalResult ReduceOp::verify() { ArrayRef dimensionsRef = getDimensions(); + if (getNumDpsInits() != getNumDpsInputs()) { + return emitOpError() << "requires same number of input and init operands"; + } + for (int64_t i = 1; i < getNumDpsInputs(); ++i) { if (llvm::cast(getInputs()[i].getType()).getShape() != llvm::cast(getInputs()[0].getType()).getShape()) { diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index e3b6958cfa881..bf502eab79878 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -728,6 +728,25 @@ func.func @reduce_reduced_input_init_rank_mismatch(%input: tensor<16x32x64xf32>, func.return %reduce : tensor<16x64xf32> } + +// ----- + +func.func @reduce_mismatched_inputs_outputs( + %input1: tensor<16x32x64xf32>, + %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xf32>) -> (tensor<16x64xf32>) { + // expected-error @+1{{'linalg.reduce' op requires same number of input and init operands}} + %reduce = linalg.reduce + ins(%input1, %input2 : tensor<16x32x64xf32>, tensor<16x32x64xf32>) + outs(%init1 : tensor<16x64xf32>) + dimensions = [1] + (%in: f32, %in2: f32, %out: f32) { + %0 = arith.mulf %in, %in2: f32 + %1 = arith.addf %in, %out: f32 + linalg.yield %1: f32 + } + func.return %reduce : tensor<16x64xf32> +} + // ----- func.func @reduce_wrong_number_of_block_arguments(