diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index a22e6b7aa9791..f707770970e5f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -2692,6 +2692,73 @@ LogicalResult tosa::ScatterOp::verify() { .failed()) { return failure(); } + + const ShapeAdaptor valuesInShape(getValuesIn().getType()); + const ShapeAdaptor indicesShape(getIndices().getType()); + const ShapeAdaptor inputShape(getInput().getType()); + const ShapeAdaptor outputShape(getValuesOut().getType()); + + int64_t N = ShapedType::kDynamic; + int64_t K = ShapedType::kDynamic; + int64_t W = ShapedType::kDynamic; + int64_t C = ShapedType::kDynamic; + if (valuesInShape.hasRank()) { + N = valuesInShape.getDimSize(0); + K = valuesInShape.getDimSize(1); + C = valuesInShape.getDimSize(2); + } + if (indicesShape.hasRank()) { + const int64_t indicesN = indicesShape.getDimSize(0); + W = indicesShape.getDimSize(1); + if (N == ShapedType::kDynamic) + N = indicesN; + else if (indicesN != ShapedType::kDynamic && N != indicesN) + return emitOpError() << "requires indices dimension 0 to have size " << N + << ", got " << indicesN; + } + if (inputShape.hasRank()) { + const int64_t inputN = inputShape.getDimSize(0); + const int64_t inputW = inputShape.getDimSize(1); + const int64_t inputC = inputShape.getDimSize(2); + if (N == ShapedType::kDynamic) + N = inputN; + else if (inputN != ShapedType::kDynamic && N != inputN) + return emitOpError() << "requires input dimension 0 to have size " << N + << ", got " << inputN; + if (W == ShapedType::kDynamic) + W = inputW; + else if (inputW != ShapedType::kDynamic && W != inputW) + return emitOpError() << "requires input dimension 1 to have size " << W + << ", got " << inputW; + + if (C == ShapedType::kDynamic) + C = inputC; + else if (inputC != ShapedType::kDynamic && C != inputC) + return emitOpError() << "requires input dimension 2 to have size " << C + << ", got " << inputC; + } + if (outputShape.hasRank()) { + const int64_t outputN = outputShape.getDimSize(0); + const int64_t outputK = outputShape.getDimSize(1); + const int64_t outputC = outputShape.getDimSize(2); + if (N != ShapedType::kDynamic && outputN != ShapedType::kDynamic && + N != outputN) + return emitOpError() << "requires values_out dimension 0 to have size " + << N << ", got " << outputN; + if (K == ShapedType::kDynamic) + K = outputK; + else if (outputK != ShapedType::kDynamic && K != outputK) + return emitOpError() << "requires values_out dimension 1 to have size " + << K << ", got " << outputK; + if (C != ShapedType::kDynamic && outputC != ShapedType::kDynamic && + C != outputC) + return emitOpError() << "requires values_out dimension 2 to have size " + << C << ", got " << outputC; + } + if (K != ShapedType::kDynamic && W != ShapedType::kDynamic && !(K >= W)) + return emitOpError() << "requires dimensions K >= W, got K=" << K + << " and W=" << W; + return success(); } diff --git a/mlir/test/Dialect/Tosa/availability.mlir b/mlir/test/Dialect/Tosa/availability.mlir index 75126a11ac504..0176fc2883518 100644 --- a/mlir/test/Dialect/Tosa/availability.mlir +++ b/mlir/test/Dialect/Tosa/availability.mlir @@ -583,11 +583,11 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> // ----- // CHECK-LABEL: scatter -func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> { +func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> { // CHECK: profiles: [ [pro_int, pro_fp] ] // CHECK: extensions: [ [fp8e4m3, fp8e5m2, bf16] ] - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32> - return %0 : tensor<13x21x3xf32> + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32> + return %0 : tensor<13x28x3xf32> } // ----- diff --git a/mlir/test/Dialect/Tosa/invalid_extension.mlir b/mlir/test/Dialect/Tosa/invalid_extension.mlir index 2364985442e43..5630c33639d86 100644 --- a/mlir/test/Dialect/Tosa/invalid_extension.mlir +++ b/mlir/test/Dialect/Tosa/invalid_extension.mlir @@ -243,10 +243,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>) -> } // ----- -func.func @test_scatter(%arg0: tensor<13x21x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> { +func.func @test_scatter(%arg0: tensor<13x26x3xbf16>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> { // expected-error@+1 {{'tosa.scatter' op illegal: requires [bf16] but not enabled in target}} - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x21x3xbf16> - return %0 : tensor<13x21x3xbf16> + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x26x3xbf16>, tensor<13x26xi32>, tensor<13x26x3xbf16>) -> tensor<13x26x3xbf16> + return %0 : tensor<13x26x3xbf16> } // ----- diff --git a/mlir/test/Dialect/Tosa/level_check.mlir b/mlir/test/Dialect/Tosa/level_check.mlir index 223bf3b635e18..0dddf26fb1f85 100644 --- a/mlir/test/Dialect/Tosa/level_check.mlir +++ b/mlir/test/Dialect/Tosa/level_check.mlir @@ -1080,10 +1080,10 @@ func.func @test_gather_tensor_size_invalid(%arg0: tensor<268435456x21x3xf32>, %a // ----- -func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x210000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> { +func.func @test_scatter_tensor_size_invalid(%arg0: tensor<13x260000000x3xf32>, %arg1: tensor<13x260000000xi32>, %arg2: tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> { // expected-error@+1 {{'tosa.scatter' op failed level check: operand tensor size (in bytes) <= (1 << MAX_LOG2_SIZE - 1)}} - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x210000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x210000000x3xf32> - return %0 : tensor<13x210000000x3xf32> + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x260000000x3xf32>, tensor<13x260000000xi32>, tensor<13x260000000x3xf32>) -> tensor<13x260000000x3xf32> + return %0 : tensor<13x260000000x3xf32> } // ----- diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 767fa833dedd4..1ac82400843ed 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -714,9 +714,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> // ----- // CHECK-LABEL: scatter -func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32> - return %0 : tensor<13x21x3xf32> +func.func @test_scatter(%arg0: tensor<13x52x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x52x3xf32> { + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x52x3xf32> + return %0 : tensor<13x52x3xf32> } // ----- @@ -728,8 +728,8 @@ func.func @test_gather_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tenso // ----- // CHECK-LABEL: scatter_unranked_indices -func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32> +func.func @test_scatter_unranked_indices(%arg0: tensor<13x21x3xf32>, %arg1: tensor<*xi32>, %arg2: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<*xi32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } @@ -1010,9 +1010,9 @@ func.func @test_gather_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26 // ----- // CHECK-LABEL: scatter_f8E5M2 -func.func @test_scatter_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> { - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> - return %0 : tensor<13x21x3xf8E5M2> +func.func @test_scatter_f8E5M2(%arg0: tensor<13x52x3xf8E5M2>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> { + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x52x3xf8E5M2>, tensor<13x26xi32>, tensor<13x26x3xf8E5M2>) -> tensor<13x52x3xf8E5M2> + return %0 : tensor<13x52x3xf8E5M2> } // ----- @@ -1155,7 +1155,7 @@ func.func @test_gather_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<1 // ----- // CHECK-LABEL: scatter_f8E4M3FN -func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> { - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> - return %0 : tensor<13x21x3xf8E4M3FN> +func.func @test_scatter_f8E4M3FN(%arg0: tensor<13x29x3xf8E4M3FN>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> { + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x29x3xf8E4M3FN>, tensor<13x26xi32>, tensor<13x26x3xf8E4M3FN>) -> tensor<13x29x3xf8E4M3FN> + return %0 : tensor<13x29x3xf8E4M3FN> } diff --git a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir index 72669c62c95ca..fad4859351251 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_fp_unsupported.mlir @@ -310,10 +310,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>) -> } // ----- -func.func @test_scatter(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x21x3xf32> { +func.func @test_scatter(%arg0: tensor<13x28x3xf32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xf32>) -> tensor<13x28x3xf32> { // expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_fp] but not enabled in target}} - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x21x3xf32> - return %0 : tensor<13x21x3xf32> + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x28x3xf32>, tensor<13x26xi32>, tensor<13x26x3xf32>) -> tensor<13x28x3xf32> + return %0 : tensor<13x28x3xf32> } // ----- diff --git a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir index e98b906377b22..9438179622aad 100644 --- a/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir +++ b/mlir/test/Dialect/Tosa/profile_pro_int_unsupported.mlir @@ -242,10 +242,10 @@ func.func @test_gather(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>) -> } // ----- -func.func @test_scatter(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x21x3xi32> { +func.func @test_scatter(%arg0: tensor<13x27x3xi32>, %arg1: tensor<13x26xi32>, %arg2: tensor<13x26x3xi32>) -> tensor<13x27x3xi32> { // expected-error@+1 {{'tosa.scatter' op illegal: requires [pro_int] but not enabled in target}} - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x21x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x21x3xi32> - return %0 : tensor<13x21x3xi32> + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<13x27x3xi32>, tensor<13x26xi32>, tensor<13x26x3xi32>) -> tensor<13x27x3xi32> + return %0 : tensor<13x27x3xi32> } // ----- diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 1ad1e6c76c294..591a3f0acf65d 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -656,9 +656,9 @@ func.func @gather_minimum_info(%arg0 : tensor<3x?x5xi32>, %arg1 : tensor, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) { - // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x4x5xi32> - %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x4x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor +func.func @scatter_static(%arg0 : tensor<3x8x5xi32>, %arg1 : tensor<3x6xi32>, %arg2 : tensor<3x6x5xi32>) { + // CHECK: tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor<3x8x5xi32> + %0 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<3x8x5xi32>, tensor<3x6xi32>, tensor<3x6x5xi32>) -> tensor return } diff --git a/mlir/test/Dialect/Tosa/verifier.mlir b/mlir/test/Dialect/Tosa/verifier.mlir index 990e0d954f54e..b3052369b055e 100644 --- a/mlir/test/Dialect/Tosa/verifier.mlir +++ b/mlir/test/Dialect/Tosa/verifier.mlir @@ -864,3 +864,75 @@ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () { tosa.variable_write @stored_var, %arg0 : tensor<2x4x8xf32> return } + +// ----- + +// CHECK-LABEL: @scatter_invalid_indices_N +func.func @scatter_invalid_indices_N(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<3x2xi32>, %arg2 : tensor<2x2x5xi32>) { + // expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}} + %1 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<3x2xi32>, tensor<2x2x5xi32>) -> tensor<2x4x5xi32> + return +} + +// ----- + +// CHECK-LABEL: @scatter_invalid_input_N +func.func @scatter_invalid_input_N(%arg0 : tensor, %arg1 : tensor<2x2xi32>, %arg2 : tensor<3x2x5xi32>) { + // expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}} + %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor<2x2xi32>, tensor<3x2x5xi32>) -> tensor<2x4x5xi32> + return +} + +// ----- + +// CHECK-LABEL: @scatter_invalid_out_N +func.func @scatter_invalid_out_N(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor<2x2x5xi32>) { + // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}} + %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x5xi32>) -> tensor<3x4x5xi32> + return +} + +// ----- + +// CHECK-LABEL: @scatter_invalid_out_K +func.func @scatter_invalid_out_K(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor<2x2x5xi32>) { + // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}} + %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x5xi32>) -> tensor<2x3x5xi32> + return +} + +// ----- + +// CHECK-LABEL: @scatter_invalid_input_W +func.func @scatter_invalid_input_W(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor<2x3x5xi32>) { + // expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}} + %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x3x5xi32>) -> tensor<2x4x5xi32> + return +} + +// ----- + +// CHECK-LABEL: @scatter_invalid_input_C +func.func @scatter_invalid_input_C(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor<2x2x6xi32>) { + // expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}} + %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x6xi32>) -> tensor<2x4x5xi32> + return +} + +// ----- + +// CHECK-LABEL: @scatter_invalid_out_C +func.func @scatter_invalid_out_C(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor<2x2x5xi32>) { + // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}} + %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor, tensor, tensor<2x2x5xi32>) -> tensor<2x4x6xi32> + return +} + +// ----- + +// CHECK-LABEL: @scatter_invalid_K_W +func.func @scatter_invalid_K_W(%arg0 : tensor<2x4x5xi32>, %arg1 : tensor<2x6xi32>, %arg2 : tensor<2x6x5xi32>) { + // expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}} + %2 = tosa.scatter %arg0, %arg1, %arg2 : (tensor<2x4x5xi32>, tensor<2x6xi32>, tensor<2x6x5xi32>) -> tensor<2x4x5xi32> + return +}