Skip to content

Commit 008fc7b

Browse files
committed
[mlir][tosa] Add verifier checks for Gather
This adds verifier checks for the gather op to make sure the shapes of inputs and output are consistent with respect to spec. Signed-off-by: Tai Ly <[email protected]> Change-Id: I16685bceef25f428669c5412d897b6918a424119
1 parent feaa5aa commit 008fc7b

File tree

2 files changed

+78
-2
lines changed

2 files changed

+78
-2
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2262,8 +2262,52 @@ LogicalResult tosa::GatherOp::inferReturnTypeComponents(
22622262
}
22632263

22642264
LogicalResult tosa::GatherOp::verify() {
2265-
return verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2266-
/* outType = */ getOutput().getType());
2265+
if (verifySameElementTypes(*this, /* inType = */ getValues().getType(),
2266+
/* outType = */ getOutput().getType())
2267+
.failed()) {
2268+
return failure();
2269+
}
2270+
2271+
const ShapeAdaptor valuesShape(getValues().getType());
2272+
const ShapeAdaptor indicesShape(getIndices().getType());
2273+
const ShapeAdaptor outputShape(getOutput().getType());
2274+
2275+
int64_t N = ShapedType::kDynamic;
2276+
int64_t W = ShapedType::kDynamic;
2277+
int64_t C = ShapedType::kDynamic;
2278+
2279+
if (valuesShape.hasRank()) {
2280+
N = valuesShape.getDimSize(0);
2281+
C = valuesShape.getDimSize(2);
2282+
}
2283+
if (indicesShape.hasRank()) {
2284+
const int64_t indicesN = indicesShape.getDimSize(0);
2285+
W = indicesShape.getDimSize(1);
2286+
if (N == ShapedType::kDynamic)
2287+
N = indicesN;
2288+
else if (indicesN != ShapedType::kDynamic && N != indicesN)
2289+
return emitOpError() << "requires indices dimension 0 to have size " << N
2290+
<< ", got " << indicesN;
2291+
}
2292+
if (outputShape.hasRank()) {
2293+
const int64_t outputN = outputShape.getDimSize(0);
2294+
const int64_t outputW = outputShape.getDimSize(1);
2295+
const int64_t outputC = outputShape.getDimSize(2);
2296+
if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2297+
N != outputN)
2298+
return emitOpError() << "requires output dimension 0 to have size " << N
2299+
<< ", got " << outputN;
2300+
2301+
if (W != ShapedType::kDynamic && outputW != ShapedType::kDynamic &&
2302+
W != outputW)
2303+
return emitOpError() << "requires output dimension 1 to have size " << W
2304+
<< ", got " << outputW;
2305+
if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2306+
C != outputC)
2307+
return emitOpError() << "requires output dimension 2 to have size " << C
2308+
<< ", got " << outputC;
2309+
}
2310+
return success();
22672311
}
22682312

22692313
LogicalResult tosa::ResizeOp::inferReturnTypeComponents(

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -358,3 +358,35 @@ func.func @test_concat_axis_sum_error(%arg0: tensor<1x2xf32>, %arg1: tensor<2x?x
358358
%0 = tosa.concat %arg0, %arg1 {axis = 0 : i32} : (tensor<1x2xf32>, tensor<2x?xf32>) -> tensor<2x?xf32>
359359
return %0 : tensor<2x?xf32>
360360
}
361+
362+
// -----
363+
// CHECK-LABEL: @test_gather_invalid_indices_N
364+
func.func @test_gather_invalid_indices_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<12x26xi32>) -> tensor<13x26x3xf32> {
365+
// expected-error@+1 {{'tosa.gather' op requires indices dimension 0 to have size 13, got 12}}
366+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<12x26xi32>) -> tensor<13x26x3xf32>
367+
return %0 : tensor<13x26x3xf32>
368+
}
369+
370+
// -----
371+
// CHECK-LABEL: test_gather_invalid_out_N
372+
func.func @test_gather_invalid_out_N(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<12x26x3xf32> {
373+
// expected-error@+1 {{'tosa.gather' op requires output dimension 0 to have size 13, got 12}}
374+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<12x26x3xf32>
375+
return %0 : tensor<12x26x3xf32>
376+
}
377+
378+
// -----
379+
// CHECK-LABEL: test_gather_invalid_out_W
380+
func.func @test_gather_invalid_out_W(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x28x3xf32> {
381+
// expected-error@+1 {{'tosa.gather' op requires output dimension 1 to have size 26, got 28}}
382+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x28x3xf32>
383+
return %0 : tensor<13x28x3xf32>
384+
}
385+
386+
// -----
387+
// CHECK-LABEL: test_gather_invalid_out_C
388+
func.func @test_gather_invalid_out_C(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> tensor<13x26x8xf32> {
389+
// expected-error@+1 {{'tosa.gather' op requires output dimension 2 to have size 3, got 8}}
390+
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<13x26x8xf32>
391+
return %0 : tensor<13x26x8xf32>
392+
}

0 commit comments

Comments
 (0)