@@ -729,26 +729,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
729729
730730// CHECK-LABEL: test_mul_type_mismatch
731731func.func @test_mul_type_mismatch (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf16 >) -> tensor <13 x21 x3 xf32 > {
732+ %shift = " tosa.const" () {value = dense <0 > : tensor <1 xi8 >} : () -> tensor <1 xi8 >
732733 // expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}}
733- %0 = tosa.mul %arg0 , %arg1 : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf16 >) -> tensor <13 x21 x3 xf32 >
734+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf16 >, tensor < 1 x i8 >) -> tensor <13 x21 x3 xf32 >
734735 return %0 : tensor <13 x21 x3 xf32 >
735736}
736737
737738// -----
738739
739740// CHECK-LABEL: test_mul_invalid_shift
740- func.func @test_mul_invalid_shift (%arg0: tensor <13 x 21 x 3 x i32 >, %arg1: tensor <13 x 1 x 3 x i32 >) -> tensor <13 x 21 x 3 x i32 > {
741- %shift = " tosa.const" () {value = dense <0.0 > : tensor <f32 >} : () -> tensor <f32 >
742- // expected-error@+1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<f32>' }}
743- %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x 21 x 3 x i32 >, tensor <13 x 1 x 3 x i32 >, tensor <f32 >) -> tensor <13 x 21 x 3 x i32 >
744- return %0 : tensor <13 x 21 x 3 x i32 >
741+ func.func @test_mul_invalid_shift (%arg0: tensor <13 x 21 x 3 x f32 >, %arg1: tensor <13 x 1 x 3 x f32 >) -> tensor <13 x 21 x 3 x f32 > {
742+ %shift = " tosa.const" () {value = dense <1 > : tensor <1 x i8 >} : () -> tensor <1 x i8 >
743+ // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type }}
744+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x 21 x 3 x f32 >, tensor <13 x 1 x 3 x f32 >, tensor <1 x i8 >) -> tensor <13 x 21 x 3 x f32 >
745+ return %0 : tensor <13 x 21 x 3 x f32 >
745746}
746747
747748// -----
748749
749750// CHECK-LABEL: test_mul_missing_shift
750751func.func @test_mul_missing_shift (%arg0: tensor <13 x21 x3 xi32 >, %arg1: tensor <13 x1 x3 xi32 >) -> tensor <13 x21 x3 xi32 > {
751- // this is ok because mul's shift operand is optional for now
752+ // expected-error@+1 {{'tosa. mul' op expected 3 operands, but found 2}}
752753 %0 = tosa.mul %arg0 , %arg1 : (tensor <13 x21 x3 xi32 >, tensor <13 x1 x3 xi32 >) -> tensor <13 x21 x3 xi32 >
753754 return %0 : tensor <13 x21 x3 xi32 >
754755}
@@ -1060,3 +1061,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
10601061 %0 = tosa.sub %arg0 , %arg1 : (tensor <1 x21 x3 xf32 >, tensor <13 x21 x3 xf32 >) -> tensor <1 x13 x21 x3 xf32 >
10611062 return %0 : tensor <1 x13 x21 x3 xf32 >
10621063}
1064+
1065+ // -----
1066+ // CHECK-LABEL: test_mul_non_scalar_shift_2d
1067+ func.func @test_mul_non_scalar_shift_2d (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1068+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 x1 xi8 >}> : () -> tensor <1 x1 xi8 >
1069+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
1070+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf32 >, tensor <1 x1 xi8 >) -> tensor <13 x21 x3 xf32 >
1071+ return %0 : tensor <13 x21 x3 xf32 >
1072+ }
1073+
1074+ // -----
1075+ // CHECK-LABEL: test_mul_non_scalar_shift_1d
1076+ func.func @test_mul_non_scalar_shift_1d (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1077+ %shift = " tosa.const" () <{value = dense <0 > : tensor <2 xi8 >}> : () -> tensor <2 xi8 >
1078+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
1079+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf32 >, tensor <2 xi8 >) -> tensor <13 x21 x3 xf32 >
1080+ return %0 : tensor <13 x21 x3 xf32 >
1081+ }
1082+
1083+ // -----
1084+ // CHECK-LABEL: test_mul_non_broadcast
1085+ func.func @test_mul_non_broadcast (%arg0: tensor <13 x21 x2 xf32 >, %arg1: tensor <3 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1086+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
1087+ // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
1088+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x2 xf32 >, tensor <3 x1 x3 xf32 >, tensor <1 xi8 >) -> tensor <13 x21 x3 xf32 >
1089+ return %0 : tensor <13 x21 x3 xf32 >
1090+ }
0 commit comments