Skip to content

Conversation

@Tai78641
Copy link
Contributor

This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.

This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.

Signed-off-by: Tai Ly <[email protected]>
Change-Id: I16685bceef25f428669c5412d897b6918a424119
@llvmbot
Copy link
Member

llvmbot commented Apr 24, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Tai Ly (Tai78641)

Changes

This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.


Full diff: https://github.com/llvm/llvm-project/pull/137204.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+46-2)
  • (modified) mlir/test/Dialect/Tosa/verifier.mlir (+32)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 751ae785bda6f..22aca774a403d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -2262,8 +2262,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
 }
 
 LogicalResult tosa::GatherOp::verify() {
-  return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
-                                /* outType = */ getOutput().getType());
+  if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
+                             /* outType = */ getOutput().getType())
+          .failed()) {
+    return failure();
+  }
+
+  const ShapeAdaptor valuesShape(getValues().getType());
+  const ShapeAdaptor indicesShape(getIndices().getType());
+  const ShapeAdaptor outputShape(getOutput().getType());
+
+  int64_t N = ShapedType::kDynamic;
+  int64_t W = ShapedType::kDynamic;
+  int64_t C = ShapedType::kDynamic;
+
+  if (valuesShape.hasRank()) {
+    N = valuesShape.getDimSize(0);
+    C = valuesShape.getDimSize(2);
+  }
+  if (indicesShape.hasRank()) {
+    const int64_t indicesN = indicesShape.getDimSize(0);
+    W = indicesShape.getDimSize(1);
+    if (N == ShapedType::kDynamic)
+      N = indicesN;
+    else if (indicesN != ShapedType::kDynamic && N != indicesN)
+      return emitOpError() << "requires indices dimension 0 to have size " << N
+                           << ", got " << indicesN;
+  }
+  if (outputShape.hasRank()) {
+    const int64_t outputN = outputShape.getDimSize(0);
+    const int64_t outputW = outputShape.getDimSize(1);
+    const int64_t outputC = outputShape.getDimSize(2);
+    if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
+        N != outputN)
+      return emitOpError() << "requires output dimension 0 to have size " << N
+                           << ", got " << outputN;
+
+    if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
+        W != outputW)
+      return emitOpError() << "requires output dimension 1 to have size " << W
+                           << ", got " << outputW;
+    if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
+        C != outputC)
+      return emitOpError() << "requires output dimension 2 to have size " << C
+                           << ", got " << outputC;
+  }
+  return success();
 }
 
 LogicalResult tosa::ResizeOp::inferReturnTypeComponents(
diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir
index 262e6d4265ea6..2b78773e4ed7f 100644
--- a/mlir/test/Dialect/Tosa/verifier.mlir
+++ b/mlir/test/Dialect/Tosa/verifier.mlir
@@ -358,3 +358,35 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
   %0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
   return %0 : tensor<2x?xf32>
 }
+
+// -----
+// CHECK-LABEL: @test_gather_invalid_indices_N
+func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<12x26xi32>) -> tensor<13x26x3xf32>
+  return %0 : tensor<13x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_N
+func.func @test_gather_invalid_out_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<12x26x3xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires output dimension 0 to have size 13, got 12}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<12x26x3xf32>
+  return %0 : tensor<12x26x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_W
+func.func @test_gather_invalid_out_W(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x28x3xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires output dimension 1 to have size 26, got 28}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x28x3xf32>
+  return %0 : tensor<13x28x3xf32>
+}
+
+// -----
+// CHECK-LABEL: test_gather_invalid_out_C
+func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x8xf32> {
+  // expected-error@+1 {{'tosa.gather' op requires output dimension 2 to have size 3, got 8}}
+  %0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
+  return %0 : tensor<13x26x8xf32>
+}

Copy link
Contributor

@lhutton1 lhutton1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@lhutton1 lhutton1 merged commit b4e2592 into llvm:main Apr 25, 2025
7 checks passed
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This adds verifier checks for the gather op
to make sure the shapes of inputs and output
are consistent with respect to spec.

---------

Signed-off-by: Tai Ly <[email protected]>
Co-authored-by: Luke Hutton <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants