@@ -768,26 +768,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
768768
769769// CHECK-LABEL: test_mul_type_mismatch
770770func.func @test_mul_type_mismatch (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf16 >) -> tensor <13 x21 x3 xf32 > {
771+ %shift = " tosa.const" () {value = dense <0 > : tensor <1 xi8 >} : () -> tensor <1 xi8 >
771772 // expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}}
772- %0 = tosa.mul %arg0 , %arg1 : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf16 >) -> tensor <13 x21 x3 xf32 >
773+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf16 >, tensor < 1 x i8 >) -> tensor <13 x21 x3 xf32 >
773774 return %0 : tensor <13 x21 x3 xf32 >
774775}
775776
776777// -----
777778
778779// CHECK-LABEL: test_mul_invalid_shift
779- func.func @test_mul_invalid_shift (%arg0: tensor <13 x 21 x 3 x i32 >, %arg1: tensor <13 x 1 x 3 x i32 >) -> tensor <13 x 21 x 3 x i32 > {
780- %shift = " tosa.const" () {value = dense <0.0 > : tensor <f32 >} : () -> tensor <f32 >
781- // expected-error@+1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<f32>' }}
782- %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x 21 x 3 x i32 >, tensor <13 x 1 x 3 x i32 >, tensor <f32 >) -> tensor <13 x 21 x 3 x i32 >
783- return %0 : tensor <13 x 21 x 3 x i32 >
780+ func.func @test_mul_invalid_shift (%arg0: tensor <13 x 21 x 3 x f32 >, %arg1: tensor <13 x 1 x 3 x f32 >) -> tensor <13 x 21 x 3 x f32 > {
781+ %shift = " tosa.const" () {value = dense <1 > : tensor <1 x i8 >} : () -> tensor <1 x i8 >
782+ // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type }}
783+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x 21 x 3 x f32 >, tensor <13 x 1 x 3 x f32 >, tensor <1 x i8 >) -> tensor <13 x 21 x 3 x f32 >
784+ return %0 : tensor <13 x 21 x 3 x f32 >
784785}
785786
786787// -----
787788
788789// CHECK-LABEL: test_mul_missing_shift
789790func.func @test_mul_missing_shift (%arg0: tensor <13 x21 x3 xi32 >, %arg1: tensor <13 x1 x3 xi32 >) -> tensor <13 x21 x3 xi32 > {
790- // this is ok because mul's shift operand is optional for now
791+ // expected-error@+1 {{'tosa. mul' op expected 3 operands, but found 2}}
791792 %0 = tosa.mul %arg0 , %arg1 : (tensor <13 x21 x3 xi32 >, tensor <13 x1 x3 xi32 >) -> tensor <13 x21 x3 xi32 >
792793 return %0 : tensor <13 x21 x3 xi32 >
793794}
@@ -1099,3 +1100,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
10991100 %0 = tosa.sub %arg0 , %arg1 : (tensor <1 x21 x3 xf32 >, tensor <13 x21 x3 xf32 >) -> tensor <1 x13 x21 x3 xf32 >
11001101 return %0 : tensor <1 x13 x21 x3 xf32 >
11011102}
1103+
1104+ // -----
1105+ // CHECK-LABEL: test_mul_non_scalar_shift_2d
1106+ func.func @test_mul_non_scalar_shift_2d (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1107+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 x1 xi8 >}> : () -> tensor <1 x1 xi8 >
1108+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
1109+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf32 >, tensor <1 x1 xi8 >) -> tensor <13 x21 x3 xf32 >
1110+ return %0 : tensor <13 x21 x3 xf32 >
1111+ }
1112+
1113+ // -----
1114+ // CHECK-LABEL: test_mul_non_scalar_shift_1d
1115+ func.func @test_mul_non_scalar_shift_1d (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1116+ %shift = " tosa.const" () <{value = dense <0 > : tensor <2 xi8 >}> : () -> tensor <2 xi8 >
1117+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
1118+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf32 >, tensor <2 xi8 >) -> tensor <13 x21 x3 xf32 >
1119+ return %0 : tensor <13 x21 x3 xf32 >
1120+ }
1121+
1122+ // -----
1123+ // CHECK-LABEL: test_mul_non_broadcast
1124+ func.func @test_mul_non_broadcast (%arg0: tensor <13 x21 x2 xf32 >, %arg1: tensor <3 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1125+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
1126+ // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
1127+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x2 xf32 >, tensor <3 x1 x3 xf32 >, tensor <1 xi8 >) -> tensor <13 x21 x3 xf32 >
1128+ return %0 : tensor <13 x21 x3 xf32 >
1129+ }
0 commit comments