Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2692,6 +2692,73 @@ LogicalResult tosa::ScatterOp::verify() {
.failed()) {
return failure();
}

const ShapeAdaptor valuesInShape(getValuesIn().getType());
const ShapeAdaptor indicesShape(getIndices().getType());
const ShapeAdaptor inputShape(getInput().getType());
const ShapeAdaptor outputShape(getValuesOut().getType());

int64_t N = ShapedType::kDynamic;
int64_t K = ShapedType::kDynamic;
int64_t W = ShapedType::kDynamic;
int64_t C = ShapedType::kDynamic;
if (valuesInShape.hasRank()) {
N = valuesInShape.getDimSize(0);
K = valuesInShape.getDimSize(1);
C = valuesInShape.getDimSize(2);
}
if (indicesShape.hasRank()) {
const int64_t indicesN = indicesShape.getDimSize(0);
W = indicesShape.getDimSize(1);
if (N == ShapedType::kDynamic)
N = indicesN;
else if (indicesN != ShapedType::kDynamic && N != indicesN)
return emitOpError() << "requires indices dimension 0 to have size " << N
<< ", got " << indicesN;
}
if (inputShape.hasRank()) {
const int64_t inputN = inputShape.getDimSize(0);
const int64_t inputW = inputShape.getDimSize(1);
const int64_t inputC = inputShape.getDimSize(2);
if (N == ShapedType::kDynamic)
N = inputN;
else if (inputN != ShapedType::kDynamic && N != inputN)
return emitOpError() << "requires input dimension 0 to have size " << N
<< ", got " << inputN;
if (W == ShapedType::kDynamic)
W = inputW;
else if (inputW != ShapedType::kDynamic && W != inputW)
return emitOpError() << "requires input dimension 1 to have size " << W
<< ", got " << inputW;

if (C == ShapedType::kDynamic)
C = inputC;
else if (inputC != ShapedType::kDynamic && C != inputC)
return emitOpError() << "requires input dimension 2 to have size " << C
<< ", got " << inputC;
}
if (outputShape.hasRank()) {
const int64_t outputN = outputShape.getDimSize(0);
const int64_t outputK = outputShape.getDimSize(1);
const int64_t outputC = outputShape.getDimSize(2);
if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic &&
N != outputN)
return emitOpError() << "requires values_out dimension 0 to have size "
<< N << ", got " << outputN;
if (K == ShapedType::kDynamic)
K = outputK;
else if (outputK != ShapedType::kDynamic && K != outputK)
return emitOpError() << "requires values_out dimension 1 to have size "
<< K << ", got " << outputK;
if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic &&
C != outputC)
return emitOpError() << "requires values_out dimension 2 to have size "
<< C << ", got " << outputC;
}
if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W))
return emitOpError() << "requires dimensions K >= W, got K=" << K
<< " and W=" << W;

return success();
}

Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/availability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -583,11 +583,11 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->

// -----
// CHECK-LABEL: scatter
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
// CHECK: profiles: [ [pro_int, pro_fp] ]
// CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ]
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
return %0 : tensor<13x28x3xf32>
}

// -----
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/invalid_extension.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -243,10 +243,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) ->
}

// -----
func.func @test_scatter(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> {
func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> {
// expected-error@+1 {{'tosa.scatter' op illegal: requires [bf16] but not enabled in target}}
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16>
return %0 : tensor<13x21x3xbf16>
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x26x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16>
return %0 : tensor<13x26x3xbf16>
}

// -----
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/level_check.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1080,10 +1080,10 @@ func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %a

// -----

func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> {
func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> {
// expected-error@+1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}}
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32>
return %0 : tensor<13x210000000x3xf32>
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x260000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32>
return %0 : tensor<13x260000000x3xf32>
}

// -----
Expand Down
22 changes: 11 additions & 11 deletions mlir/test/Dialect/Tosa/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -714,9 +714,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->

// -----
// CHECK-LABEL: scatter
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> {
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32>
return %0 : tensor<13x52x3xf32>
}

// -----
Expand All @@ -728,8 +728,8 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso

// -----
// CHECK-LABEL: scatter_unranked_indices
func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> {
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
}

Expand Down Expand Up @@ -1010,9 +1010,9 @@ func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26

// -----
// CHECK-LABEL: scatter_f8E5M2
func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> {
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2>
return %0 : tensor<13x21x3xf8E5M2>
func.func @test_scatter_f8E5M2(%arg0: tensor<13x52x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> {
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2>
return %0 : tensor<13x52x3xf8E5M2>
}

// -----
Expand Down Expand Up @@ -1155,7 +1155,7 @@ func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<1

// -----
// CHECK-LABEL: scatter_f8E4M3FN
func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> {
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN>
return %0 : tensor<13x21x3xf8E4M3FN>
func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> {
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN>
return %0 : tensor<13x29x3xf8E4M3FN>
}
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -310,10 +310,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) ->
}

// -----
func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> {
func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> {
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_fp] but not enabled in target}}
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32>
return %0 : tensor<13x21x3xf32>
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32>
return %0 : tensor<13x28x3xf32>
}

// -----
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>) ->
}

// -----
func.func @test_scatter(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x21x3xi32> {
func.func @test_scatter(%arg0: tensor<13x27x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x27x3xi32> {
// expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_int] but not enabled in target}}
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x21x3xi32>
return %0 : tensor<13x21x3xi32>
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x27x3xi32>
return %0 : tensor<13x27x3xi32>
}

// -----
Expand Down
6 changes: 3 additions & 3 deletions mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -656,9 +656,9 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor<?x6xi32
// -----

// CHECK-LABEL: @scatter_static
func.func @scatter_static(%arg0 : tensor<3x4x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
// CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32>
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) {
// CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32>
%0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<?x?x?xi32>
return
}

Expand Down
72 changes: 72 additions & 0 deletions mlir/test/Dialect/Tosa/verifier.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -864,3 +864,75 @@ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32>
return
}

// -----

// CHECK-LABEL: @scatter_invalid_indices_N
func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
%1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32>
return
}

// -----

// CHECK-LABEL: @scatter_invalid_input_N
func.func @scatter_invalid_input_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32>
return
}

// -----

// CHECK-LABEL: @scatter_invalid_out_N
func.func @scatter_invalid_out_N(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<3x4x5xi32>
return
}

// -----

// CHECK-LABEL: @scatter_invalid_out_K
func.func @scatter_invalid_out_K(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x3x5xi32>
return
}

// -----

// CHECK-LABEL: @scatter_invalid_input_W
func.func @scatter_invalid_input_W(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x3x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x3x5xi32>) -> tensor<2x4x5xi32>
return
}

// -----

// CHECK-LABEL: @scatter_invalid_input_C
func.func @scatter_invalid_input_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x6xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x6xi32>) -> tensor<2x4x5xi32>
return
}

// -----

// CHECK-LABEL: @scatter_invalid_out_C
func.func @scatter_invalid_out_C(%arg0 : tensor<?x4x5xi32>, %arg1 : tensor<?x2xi32>, %arg2 : tensor<2x2x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<?x4x5xi32>, tensor<?x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x6xi32>
return
}

// -----

// CHECK-LABEL: @scatter_invalid_K_W
func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) {
// expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
%2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32>
return
}
Loading