Skip to content

Commit 0651ac6

Browse files
committed
[tosa] Add verifier checks for Scatter
This adds verifier checks for the scatter op to make sure the shapes of inputs and output are consistent with respect to spec. Signed-off-by: Tai Ly <[email protected]> Change-Id: I59531fa63e2d1dbd2865e0ef9b08b76991915c9a
1 parent 50f9b8a commit 0651ac6

File tree

9 files changed

+168
-29
lines changed

9 files changed

+168
-29
lines changed

mlir/lib/Dialect/Tosa/IR/TosaOps.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2692,6 +2692,73 @@ LogicalResult tosa::ScatterOp::verify() {
26922692
.failed()) {
26932693
return failure();
26942694
}
2695+
2696+
const ShapeAdaptor valuesInShape(getValuesIn().getType());
2697+
const ShapeAdaptor indicesShape(getIndices().getType());
2698+
const ShapeAdaptor inputShape(getInput().getType());
2699+
const ShapeAdaptor outputShape(getValuesOut().getType());
2700+
2701+
int64_t N = ShapedType::kDynamic;
2702+
int64_t K = ShapedType::kDynamic;
2703+
int64_t W = ShapedType::kDynamic;
2704+
int64_t C = ShapedType::kDynamic;
2705+
if (valuesInShape.hasRank()) {
2706+
N = valuesInShape.getDimSize(0);
2707+
K = valuesInShape.getDimSize(1);
2708+
C = valuesInShape.getDimSize(2);
2709+
}
2710+
if (indicesShape.hasRank()) {
2711+
const int64_t indicesN = indicesShape.getDimSize(0);
2712+
W = indicesShape.getDimSize(1);
2713+
if (N == ShapedType::kDynamic)
2714+
N = indicesN;
2715+
else if (indicesN != ShapedType::kDynamic && N != indicesN)
2716+
return emitOpError() << "requires indices dimension 0 to have size " << N
2717+
<< ", got " << indicesN;
2718+
}
2719+
if (inputShape.hasRank()) {
2720+
const int64_t inputN = inputShape.getDimSize(0);
2721+
const int64_t inputW = inputShape.getDimSize(1);
2722+
const int64_t inputC = inputShape.getDimSize(2);
2723+
if (N == ShapedType::kDynamic)
2724+
N = inputN;
2725+
else if (inputN != ShapedType::kDynamic && N != inputN)
2726+
return emitOpError() << "requires input dimension 0 to have size " << N
2727+
<< ", got " << inputN;
2728+
if (W == ShapedType::kDynamic)
2729+
W = inputW;
2730+
else if (inputW != ShapedType::kDynamic && W != inputW)
2731+
return emitOpError() << "requires input dimension 1 to have size " << W
2732+
<< ", got " << inputW;
2733+
2734+
if (C == ShapedType::kDynamic)
2735+
C = inputC;
2736+
else if (inputC != ShapedType::kDynamic && C != inputC)
2737+
return emitOpError() << "requires input dimension 2 to have size " << C
2738+
<< ", got " << inputC;
2739+
}
2740+
if (outputShape.hasRank()) {
2741+
const int64_t outputN = outputShape.getDimSize(0);
2742+
const int64_t outputK = outputShape.getDimSize(1);
2743+
const int64_t outputC = outputShape.getDimSize(2);
2744+
if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
2745+
N != outputN)
2746+
return emitOpError() << "requires values_out dimension 0 to have size "
2747+
<< N << ", got " << outputN;
2748+
if (K == ShapedType::kDynamic)
2749+
K = outputK;
2750+
else if (outputK != ShapedType::kDynamic && K != outputK)
2751+
return emitOpError() << "requires values_out dimension 1 to have size "
2752+
<< K << ", got " << outputK;
2753+
if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
2754+
C != outputC)
2755+
return emitOpError() << "requires values_out dimension 2 to have size "
2756+
<< C << ", got " << outputC;
2757+
}
2758+
if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
2759+
return emitOpError() << "requires dimensions K >= W, got K=" << K
2760+
<< " and W=" << W;
2761+
26952762
return success();
26962763
}
26972764

mlir/test/Dialect/Tosa/availability.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -583,11 +583,11 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
583583

584584
// -----
585585
// CHECK-LABEL: scatter
586-
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
586+
func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
587587
// CHECK: profiles: [ [pro_int, pro_fp] ]
588588
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
589-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
590-
return %0 : tensor<13x21x3xf32>
589+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
590+
return %0 : tensor<13x28x3xf32>
591591
}
592592

593593
// -----

mlir/test/Dialect/Tosa/invalid_extension.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
243243
}
244244

245245
// -----
246-
func.func @test_scatter(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> {
246+
func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
247247
// expected-error@+1 {{'tosa.scatter' op illegal: requires [bf16] but not enabled in target}}
248-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16>
249-
return %0 : tensor<13x21x3xbf16>
248+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x26x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16>
249+
return %0 : tensor<13x26x3xbf16>
250250
}
251251

252252
// -----

mlir/test/Dialect/Tosa/level_check.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1080,10 +1080,10 @@ func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %a
10801080

10811081
// -----
10821082

1083-
func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> {
1083+
func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> {
10841084
// expected-error@+1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
1085-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32>
1086-
return %0 : tensor<13x210000000x3xf32>
1085+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x260000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32>
1086+
return %0 : tensor<13x260000000x3xf32>
10871087
}
10881088

10891089
// -----

mlir/test/Dialect/Tosa/ops.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -714,9 +714,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
714714

715715
// -----
716716
// CHECK-LABEL: scatter
717-
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
718-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
719-
return %0 : tensor<13x21x3xf32>
717+
func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
718+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
719+
return %0 : tensor<13x52x3xf32>
720720
}
721721

722722
// -----
@@ -728,8 +728,8 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso
728728

729729
// -----
730730
// CHECK-LABEL: scatter_unranked_indices
731-
func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
732-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
731+
func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
732+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
733733
return %0 : tensor<13x21x3xf32>
734734
}
735735

@@ -1010,9 +1010,9 @@ func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26
10101010

10111011
// -----
10121012
// CHECK-LABEL: scatter_f8E5M2
1013-
func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
1014-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
1015-
return %0 : tensor<13x21x3xf8E5M2>
1013+
func.func @test_scatter_f8E5M2(%arg0: tensor<13x52x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> {
1014+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2>
1015+
return %0 : tensor<13x52x3xf8E5M2>
10161016
}
10171017

10181018
// -----
@@ -1155,7 +1155,7 @@ func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<1
11551155

11561156
// -----
11571157
// CHECK-LABEL: scatter_f8E4M3FN
1158-
func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
1159-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
1160-
return %0 : tensor<13x21x3xf8E4M3FN>
1158+
func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> {
1159+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
1160+
return %0 : tensor<13x29x3xf8E4M3FN>
11611161
}

mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -310,10 +310,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
310310
}
311311

312312
// -----
313-
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
313+
func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
314314
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_fp] but not enabled in target}}
315-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
316-
return %0 : tensor<13x21x3xf32>
315+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
316+
return %0 : tensor<13x28x3xf32>
317317
}
318318

319319
// -----

mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,10 +242,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>) ->
242242
}
243243

244244
// -----
245-
func.func @test_scatter(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x21x3xi32> {
245+
func.func @test_scatter(%arg0: tensor<13x27x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x27x3xi32> {
246246
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_int] but not enabled in target}}
247-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x21x3xi32>
248-
return %0 : tensor<13x21x3xi32>
247+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x27x3xi32>
248+
return %0 : tensor<13x27x3xi32>
249249
}
250250

251251
// -----

mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -656,9 +656,9 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
656656
// -----
657657

658658
// CHECK-LABEL: @scatter_static
659-
func.func @scatter_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
660-
// CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32>
661-
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
659+
func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
660+
// CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
661+
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
662662
return
663663
}
664664

mlir/test/Dialect/Tosa/verifier.mlir

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -864,3 +864,75 @@ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
864864
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
865865
return
866866
}
867+
868+
// -----
869+
870+
// CHECK-LABEL: @scatter_invalid_indices_N
871+
func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
872+
// expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
873+
%1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
874+
return
875+
}
876+
877+
// -----
878+
879+
// CHECK-LABEL: @scatter_invalid_input_N
880+
func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
881+
// expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
882+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
883+
return
884+
}
885+
886+
// -----
887+
888+
// CHECK-LABEL: @scatter_invalid_out_N
889+
func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
890+
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
891+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
892+
return
893+
}
894+
895+
// -----
896+
897+
// CHECK-LABEL: @scatter_invalid_out_K
898+
func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
899+
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
900+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
901+
return
902+
}
903+
904+
// -----
905+
906+
// CHECK-LABEL: @scatter_invalid_input_W
907+
func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
908+
// expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
909+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
910+
return
911+
}
912+
913+
// -----
914+
915+
// CHECK-LABEL: @scatter_invalid_input_C
916+
func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
917+
// expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
918+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
919+
return
920+
}
921+
922+
// -----
923+
924+
// CHECK-LABEL: @scatter_invalid_out_C
925+
func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
926+
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
927+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
928+
return
929+
}
930+
931+
// -----
932+
933+
// CHECK-LABEL: @scatter_invalid_K_W
934+
func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
935+
// expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
936+
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
937+
return
938+
}

0 commit comments

Comments
 (0)