@@ -864,3 +864,75 @@ func.func @test_variable_write_shape_mismatch(%arg0: tensor<2x4x8xf32>) -> () {
864864 tosa.variable_write @stored_var , %arg0 : tensor <2 x4 x8 xf32 >
865865 return
866866}
867+
868+ // -----
869+
870+ // CHECK-LABEL: @scatter_invalid_indices_N
871+ func.func @scatter_invalid_indices_N (%arg0 : tensor <2 x4 x5 xi32 >, %arg1 : tensor <3 x2 xi32 >, %arg2 : tensor <2 x2 x5 xi32 >) {
872+ // expected-error@+1 {{'tosa.scatter' op requires indices dimension 0 to have size 2, got 3}}
873+ %1 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <2 x4 x5 xi32 >, tensor <3 x2 xi32 >, tensor <2 x2 x5 xi32 >) -> tensor <2 x4 x5 xi32 >
874+ return
875+ }
876+
877+ // -----
878+
879+ // CHECK-LABEL: @scatter_invalid_input_N
880+ func.func @scatter_invalid_input_N (%arg0 : tensor <?x4 x5 xi32 >, %arg1 : tensor <2 x2 xi32 >, %arg2 : tensor <3 x2 x5 xi32 >) {
881+ // expected-error@+1 {{'tosa.scatter' op requires input dimension 0 to have size 2, got 3}}
882+ %2 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <?x4 x5 xi32 >, tensor <2 x2 xi32 >, tensor <3 x2 x5 xi32 >) -> tensor <2 x4 x5 xi32 >
883+ return
884+ }
885+
886+ // -----
887+
888+ // CHECK-LABEL: @scatter_invalid_out_N
889+ func.func @scatter_invalid_out_N (%arg0 : tensor <?x4 x5 xi32 >, %arg1 : tensor <?x2 xi32 >, %arg2 : tensor <2 x2 x5 xi32 >) {
890+ // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 0 to have size 2, got 3}}
891+ %2 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <?x4 x5 xi32 >, tensor <?x2 xi32 >, tensor <2 x2 x5 xi32 >) -> tensor <3 x4 x5 xi32 >
892+ return
893+ }
894+
895+ // -----
896+
897+ // CHECK-LABEL: @scatter_invalid_out_K
898+ func.func @scatter_invalid_out_K (%arg0 : tensor <?x4 x5 xi32 >, %arg1 : tensor <?x2 xi32 >, %arg2 : tensor <2 x2 x5 xi32 >) {
899+ // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 1 to have size 4, got 3}}
900+ %2 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <?x4 x5 xi32 >, tensor <?x2 xi32 >, tensor <2 x2 x5 xi32 >) -> tensor <2 x3 x5 xi32 >
901+ return
902+ }
903+
904+ // -----
905+
906+ // CHECK-LABEL: @scatter_invalid_input_W
907+ func.func @scatter_invalid_input_W (%arg0 : tensor <?x4 x5 xi32 >, %arg1 : tensor <?x2 xi32 >, %arg2 : tensor <2 x3 x5 xi32 >) {
908+ // expected-error@+1 {{'tosa.scatter' op requires input dimension 1 to have size 2, got 3}}
909+ %2 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <?x4 x5 xi32 >, tensor <?x2 xi32 >, tensor <2 x3 x5 xi32 >) -> tensor <2 x4 x5 xi32 >
910+ return
911+ }
912+
913+ // -----
914+
915+ // CHECK-LABEL: @scatter_invalid_input_C
916+ func.func @scatter_invalid_input_C (%arg0 : tensor <?x4 x5 xi32 >, %arg1 : tensor <?x2 xi32 >, %arg2 : tensor <2 x2 x6 xi32 >) {
917+ // expected-error@+1 {{'tosa.scatter' op requires input dimension 2 to have size 5, got 6}}
918+ %2 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <?x4 x5 xi32 >, tensor <?x2 xi32 >, tensor <2 x2 x6 xi32 >) -> tensor <2 x4 x5 xi32 >
919+ return
920+ }
921+
922+ // -----
923+
924+ // CHECK-LABEL: @scatter_invalid_out_C
925+ func.func @scatter_invalid_out_C (%arg0 : tensor <?x4 x5 xi32 >, %arg1 : tensor <?x2 xi32 >, %arg2 : tensor <2 x2 x5 xi32 >) {
926+ // expected-error@+1 {{'tosa.scatter' op requires values_out dimension 2 to have size 5, got 6}}
927+ %2 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <?x4 x5 xi32 >, tensor <?x2 xi32 >, tensor <2 x2 x5 xi32 >) -> tensor <2 x4 x6 xi32 >
928+ return
929+ }
930+
931+ // -----
932+
933+ // CHECK-LABEL: @scatter_invalid_K_W
934+ func.func @scatter_invalid_K_W (%arg0 : tensor <2 x4 x5 xi32 >, %arg1 : tensor <2 x6 xi32 >, %arg2 : tensor <2 x6 x5 xi32 >) {
935+ // expected-error@+1 {{'tosa.scatter' op requires dimensions K >= W, got K=4 and W=6}}
936+ %2 = tosa.scatter %arg0 , %arg1 , %arg2 : (tensor <2 x4 x5 xi32 >, tensor <2 x6 xi32 >, tensor <2 x6 x5 xi32 >) -> tensor <2 x4 x5 xi32 >
937+ return
938+ }
0 commit comments