diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index bf3810ff231da..8cf1fd0d4acb6 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1761,6 +1761,11 @@ LogicalResult tosa::ConcatOp::verify() { } } + const ShapeAdaptor outputShape(outType); + if (outputShape.hasRank() && outputShape.getRank() != firstInputRank) + return emitOpError("expect output rank to match inputs rank, got ") + << outputShape.getRank() << " vs " << firstInputRank; + // ERROR_IF(axis_sum != shape[axis]); int64_t axisSum = 0; for (const auto &input : inputList) { @@ -1772,7 +1777,7 @@ LogicalResult tosa::ConcatOp::verify() { } axisSum += inputShape.getDimSize(axis); } - const ShapeAdaptor outputShape(outType); + if (axisSum >= 0 && outputShape.hasRank() && !outputShape.isDynamicDim(axis) && axisSum != outputShape.getDimSize(axis)) diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 119991ca7b451..c9e03ca53a729 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -306,6 +306,14 @@ func.func @test_concat_input_rank_mismatch(%arg0: tensor<1x2x3xf32>, %arg1: tens // ----- +func.func @test_concat_input_output_rank_mismatch(%arg0: tensor<2x2xf32>, %arg1: tensor<2x1xf32>) -> tensor<2xf32> { + // expected-error@+1 {{'tosa.concat' op expect output rank to match inputs rank, got 1 vs 2}} + %0 = tosa.concat %arg0, %arg1 {axis = 1 : i32} : (tensor<2x2xf32>, tensor<2x1xf32>) -> tensor<2xf32> + return %0 : tensor<2xf32> +} + +// ----- + func.func @test_pad_invalid_padConst_rank(%arg0: tensor<13x21xf32>) { %0 = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> %1 = "tosa.const"() {values = dense<3.14> : tensor<2xf32>} : () -> tensor<2xf32>