@@ -12,6 +12,14 @@ func.func @fadd_scalar(%arg: f32) -> f32 {
1212
1313// -----
1414
15+ func.func @fadd_bf16_scalar (%arg: bf16 ) -> bf16 {
16+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
17+ %0 = spirv.FAdd %arg , %arg : bf16
18+ return %0 : bf16
19+ }
20+
21+ // -----
22+
1523//===----------------------------------------------------------------------===//
1624// spirv.FDiv
1725//===----------------------------------------------------------------------===//
@@ -24,6 +32,14 @@ func.func @fdiv_scalar(%arg: f32) -> f32 {
2432
2533// -----
2634
35+ func.func @fdiv_bf16_scalar (%arg: bf16 ) -> bf16 {
36+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
37+ %0 = spirv.FDiv %arg , %arg : bf16
38+ return %0 : bf16
39+ }
40+
41+ // -----
42+
2743//===----------------------------------------------------------------------===//
2844// spirv.FMod
2945//===----------------------------------------------------------------------===//
@@ -36,6 +52,14 @@ func.func @fmod_scalar(%arg: f32) -> f32 {
3652
3753// -----
3854
55+ func.func @fmod_bf16_scalar (%arg: bf16 ) -> bf16 {
56+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
57+ %0 = spirv.FMod %arg , %arg : bf16
58+ return %0 : bf16
59+ }
60+
61+ // -----
62+
3963//===----------------------------------------------------------------------===//
4064// spirv.FMul
4165//===----------------------------------------------------------------------===//
@@ -70,6 +94,14 @@ func.func @fmul_bf16(%arg: bf16) -> bf16 {
7094
7195// -----
7296
97+ func.func @fmul_bf16_vector (%arg: vector <4 xbf16 >) -> vector <4 xbf16 > {
98+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
99+ %0 = spirv.FMul %arg , %arg : vector <4 xbf16 >
100+ return %0 : vector <4 xbf16 >
101+ }
102+
103+ // -----
104+
73105func.func @fmul_tensor (%arg: tensor <4 xf32 >) -> tensor <4 xf32 > {
74106 // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
75107 %0 = spirv.FMul %arg , %arg : tensor <4 xf32 >
@@ -90,6 +122,14 @@ func.func @fnegate_scalar(%arg: f32) -> f32 {
90122
91123// -----
92124
125+ func.func @fnegate_bf16_scalar (%arg: bf16 ) -> bf16 {
126+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
127+ %0 = spirv.FNegate %arg : bf16
128+ return %0 : bf16
129+ }
130+
131+ // -----
132+
93133//===----------------------------------------------------------------------===//
94134// spirv.FRem
95135//===----------------------------------------------------------------------===//
@@ -102,6 +142,14 @@ func.func @frem_scalar(%arg: f32) -> f32 {
102142
103143// -----
104144
145+ func.func @frem_bf16_scalar (%arg: bf16 ) -> bf16 {
146+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
147+ %0 = spirv.FRem %arg , %arg : bf16
148+ return %0 : bf16
149+ }
150+
151+ // -----
152+
105153//===----------------------------------------------------------------------===//
106154// spirv.FSub
107155//===----------------------------------------------------------------------===//
@@ -114,6 +162,14 @@ func.func @fsub_scalar(%arg: f32) -> f32 {
114162
115163// -----
116164
165+ func.func @fsub_bf16_scalar (%arg: bf16 ) -> bf16 {
166+ // expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
167+ %0 = spirv.FSub %arg , %arg : bf16
168+ return %0 : bf16
169+ }
170+
171+ // -----
172+
117173//===----------------------------------------------------------------------===//
118174// spirv.IAdd
119175//===----------------------------------------------------------------------===//
@@ -489,3 +545,11 @@ func.func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<3
489545 %0 = spirv.VectorTimesScalar %vector , %scalar : (vector <4 xf32 >, f32 ) -> vector <3 xf32 >
490546 return %0 : vector <3 xf32 >
491547}
548+
549+ // -----
550+
551+ func.func @vector_bf16_times_scalar_bf16 (%vector: vector <4 xbf16 >, %scalar: bf16 ) -> vector <4 xbf16 > {
552+ // expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}}
553+ %0 = spirv.VectorTimesScalar %vector , %scalar : (vector <4 xbf16 >, bf16 ) -> vector <4 xbf16 >
554+ return %0 : vector <4 xbf16 >
555+ }
0 commit comments