@@ -750,26 +750,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>,
750750
751751// CHECK-LABEL: test_mul_type_mismatch
752752func.func @test_mul_type_mismatch (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf16 >) -> tensor <13 x21 x3 xf32 > {
753+ %shift = " tosa.const" () {value = dense <0 > : tensor <1 xi8 >} : () -> tensor <1 xi8 >
753754 // expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}}
754- %0 = tosa.mul %arg0 , %arg1 : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf16 >) -> tensor <13 x21 x3 xf32 >
755+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf16 >, tensor < 1 x i8 >) -> tensor <13 x21 x3 xf32 >
755756 return %0 : tensor <13 x21 x3 xf32 >
756757}
757758
758759// -----
759760
760761// CHECK-LABEL: test_mul_invalid_shift
761- func.func @test_mul_invalid_shift (%arg0: tensor <13 x 21 x 3 x i32 >, %arg1: tensor <13 x 1 x 3 x i32 >) -> tensor <13 x 21 x 3 x i32 > {
762- %shift = " tosa.const" () {value = dense <0.0 > : tensor <f32 >} : () -> tensor <f32 >
763- // expected-error@+1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor<f32>' }}
764- %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x 21 x 3 x i32 >, tensor <13 x 1 x 3 x i32 >, tensor <f32 >) -> tensor <13 x 21 x 3 x i32 >
765- return %0 : tensor <13 x 21 x 3 x i32 >
762+ func.func @test_mul_invalid_shift (%arg0: tensor <13 x 21 x 3 x f32 >, %arg1: tensor <13 x 1 x 3 x f32 >) -> tensor <13 x 21 x 3 x f32 > {
763+ %shift = " tosa.const" () {value = dense <1 > : tensor <1 x i8 >} : () -> tensor <1 x i8 >
764+ // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type }}
765+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x 21 x 3 x f32 >, tensor <13 x 1 x 3 x f32 >, tensor <1 x i8 >) -> tensor <13 x 21 x 3 x f32 >
766+ return %0 : tensor <13 x 21 x 3 x f32 >
766767}
767768
768769// -----
769770
770771// CHECK-LABEL: test_mul_missing_shift
771772func.func @test_mul_missing_shift (%arg0: tensor <13 x21 x3 xi32 >, %arg1: tensor <13 x1 x3 xi32 >) -> tensor <13 x21 x3 xi32 > {
772- // this is ok because mul's shift operand is optional for now
773+ // expected-error@+1 {{'tosa. mul' op expected 3 operands, but found 2}}
773774 %0 = tosa.mul %arg0 , %arg1 : (tensor <13 x21 x3 xi32 >, tensor <13 x1 x3 xi32 >) -> tensor <13 x21 x3 xi32 >
774775 return %0 : tensor <13 x21 x3 xi32 >
775776}
@@ -1081,3 +1082,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1:
10811082 %0 = tosa.sub %arg0 , %arg1 : (tensor <1 x21 x3 xf32 >, tensor <13 x21 x3 xf32 >) -> tensor <1 x13 x21 x3 xf32 >
10821083 return %0 : tensor <1 x13 x21 x3 xf32 >
10831084}
1085+
1086+ // -----
1087+ // CHECK-LABEL: test_mul_non_scalar_shift_2d
1088+ func.func @test_mul_non_scalar_shift_2d (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1089+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 x1 xi8 >}> : () -> tensor <1 x1 xi8 >
1090+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}}
1091+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf32 >, tensor <1 x1 xi8 >) -> tensor <13 x21 x3 xf32 >
1092+ return %0 : tensor <13 x21 x3 xf32 >
1093+ }
1094+
1095+ // -----
1096+ // CHECK-LABEL: test_mul_non_scalar_shift_1d
1097+ func.func @test_mul_non_scalar_shift_1d (%arg0: tensor <13 x21 x3 xf32 >, %arg1: tensor <13 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1098+ %shift = " tosa.const" () <{value = dense <0 > : tensor <2 xi8 >}> : () -> tensor <2 xi8 >
1099+ // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}}
1100+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x3 xf32 >, tensor <13 x1 x3 xf32 >, tensor <2 xi8 >) -> tensor <13 x21 x3 xf32 >
1101+ return %0 : tensor <13 x21 x3 xf32 >
1102+ }
1103+
1104+ // -----
1105+ // CHECK-LABEL: test_mul_non_broadcast
1106+ func.func @test_mul_non_broadcast (%arg0: tensor <13 x21 x2 xf32 >, %arg1: tensor <3 x1 x3 xf32 >) -> tensor <13 x21 x3 xf32 > {
1107+ %shift = " tosa.const" () <{value = dense <0 > : tensor <1 xi8 >}> : () -> tensor <1 xi8 >
1108+ // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}}
1109+ %0 = tosa.mul %arg0 , %arg1 , %shift : (tensor <13 x21 x2 xf32 >, tensor <3 x1 x3 xf32 >, tensor <1 xi8 >) -> tensor <13 x21 x3 xf32 >
1110+ return %0 : tensor <13 x21 x3 xf32 >
1111+ }
0 commit comments