@@ -730,26 +730,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
730730
731731// CHECK-LABEL: test_mul_type_mismatch
732732func.func @test_mul_type_mismatch (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf16 >) -> tensor <13 x21 x3 xf32 > {
733+ %shift = " tosa.const" () {value = dense <0 > : tensor <1 xi8 >} : () -> tensor <1 xi8 >
733734 // expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}}
734- %0 = tosa.mul %arg0 , %arg1 : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf16 >) -> tensor <13 x21 x3 xf32 >
735+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf16 >, tensor < 1 x i8 >) -> tensor <13 x21 x3 xf32 >
735736 return %0 : tensor <13 x21 x3 xf32 >
736737}
737738
738739// -----
739740
740741// CHECK-LABEL: test_mul_invalid_shift
741- func.func @test_mul_invalid_shift (%arg0: tensor <13 x 21 x 3 x i32 >, %arg1: tensor <13 x 1 x 3 x i32 >) -> tensor <13 x 21 x 3 x i32 > {
742- %shift = " tosa.const" () {value = dense <0.0 > : tensor <f32 >} : () -> tensor <f32 >
743- // expected-error@+1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<f32>' }}
744- %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x 21 x 3 x i32 >, tensor <13 x 1 x 3 x i32 >, tensor <f32 >) -> tensor <13 x 21 x 3 x i32 >
745- return %0 : tensor <13 x 21 x 3 x i32 >
742+ func.func @test_mul_invalid_shift (%arg0: tensor <13 x 21 x 3 x f32 >, %arg1: tensor <13 x 1 x 3 x f32 >) -> tensor <13 x 21 x 3 x f32 > {
743+ %shift = " tosa.const" () {value = dense <1 > : tensor <1 x i8 >} : () -> tensor <1 x i8 >
744+ // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type }}
745+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x 21 x 3 x f32 >, tensor <13 x 1 x 3 x f32 >, tensor <1 x i8 >) -> tensor <13 x 21 x 3 x f32 >
746+ return %0 : tensor <13 x 21 x 3 x f32 >
746747}
747748
748749// -----
749750
750751// CHECK-LABEL: test_mul_missing_shift
751752func.func @test_mul_missing_shift (%arg0: tensor <13 x21 x3 xi32 >, %arg1: tensor <13 x1 x3 xi32 >) -> tensor <13 x21 x3 xi32 > {
752- // this is ok because mul's shift operand is optional for now
753+ // expected-error@+1 {{'tosa. mul' op expected 3 operands, but found 2}}
753754 %0 = tosa.mul %arg0 , %arg1 : (tensor <13 x21 x3 xi32 >, tensor <13 x1 x3 xi32 >) -> tensor <13 x21 x3 xi32 >
754755 return %0 : tensor <13 x21 x3 xi32 >
755756}
@@ -1061,3 +1062,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
10611062 %0 = tosa.sub %arg0 , %arg1 : (tensor <1 x21 x3 xf32 >, tensor <13 x21 x3 xf32 >) -> tensor <1 x13 x21 x3 xf32 >
10621063 return %0 : tensor <1 x13 x21 x3 xf32 >
10631064}
1065+
1066+ // -----
1067+ // CHECK-LABEL: test_mul_non_scalar_shift_2d
1068+ func.func @test_mul_non_scalar_shift_2d (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1069+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 x1 xi8 >}> : () -> tensor <1 x1 xi8 >
1070+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
1071+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf32 >, tensor <1 x1 xi8 >) -> tensor <13 x21 x3 xf32 >
1072+ return %0 : tensor <13 x21 x3 xf32 >
1073+ }
1074+
1075+ // -----
1076+ // CHECK-LABEL: test_mul_non_scalar_shift_1d
1077+ func.func @test_mul_non_scalar_shift_1d (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1078+ %shift = " tosa.const" () <{value = dense <0 > : tensor <2 xi8 >}> : () -> tensor <2 xi8 >
1079+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
1080+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf32 >, tensor <2 xi8 >) -> tensor <13 x21 x3 xf32 >
1081+ return %0 : tensor <13 x21 x3 xf32 >
1082+ }
1083+
1084+ // -----
1085+ // CHECK-LABEL: test_mul_non_broadcast
1086+ func.func @test_mul_non_broadcast (%arg0: tensor <13 x21 x2 xf32 >, %arg1: tensor <3 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1087+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
1088+ // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
1089+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x2 xf32 >, tensor <3 x1 x3 xf32 >, tensor <1 xi8 >) -> tensor <13 x21 x3 xf32 >
1090+ return %0 : tensor <13 x21 x3 xf32 >
1091+ }
0 commit comments