Skip to content

Commit f2ba086

Browse files
committed
Add tests for errors for arithmetic ops
1 parent 4248d7c commit f2ba086

File tree

1 file changed

+64
-0
lines changed

1 file changed

+64
-0
lines changed

mlir/test/Dialect/SPIRV/IR/arithmetic-ops.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ func.func @fadd_scalar(%arg: f32) -> f32 {
1212

1313
// -----
1414

15+
func.func @fadd_bf16_scalar(%arg: bf16) -> bf16 {
16+
// expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
17+
%0 = spirv.FAdd %arg, %arg : bf16
18+
return %0 : bf16
19+
}
20+
21+
// -----
22+
1523
//===----------------------------------------------------------------------===//
1624
// spirv.FDiv
1725
//===----------------------------------------------------------------------===//
@@ -24,6 +32,14 @@ func.func @fdiv_scalar(%arg: f32) -> f32 {
2432

2533
// -----
2634

35+
func.func @fdiv_bf16_scalar(%arg: bf16) -> bf16 {
36+
// expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
37+
%0 = spirv.FDiv %arg, %arg : bf16
38+
return %0 : bf16
39+
}
40+
41+
// -----
42+
2743
//===----------------------------------------------------------------------===//
2844
// spirv.FMod
2945
//===----------------------------------------------------------------------===//
@@ -36,6 +52,14 @@ func.func @fmod_scalar(%arg: f32) -> f32 {
3652

3753
// -----
3854

55+
func.func @fmod_bf16_scalar(%arg: bf16) -> bf16 {
56+
// expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
57+
%0 = spirv.FMod %arg, %arg : bf16
58+
return %0 : bf16
59+
}
60+
61+
// -----
62+
3963
//===----------------------------------------------------------------------===//
4064
// spirv.FMul
4165
//===----------------------------------------------------------------------===//
@@ -70,6 +94,14 @@ func.func @fmul_bf16(%arg: bf16) -> bf16 {
7094

7195
// -----
7296

97+
func.func @fmul_bf16_vector(%arg: vector<4xbf16>) -> vector<4xbf16> {
98+
// expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
99+
%0 = spirv.FMul %arg, %arg : vector<4xbf16>
100+
return %0 : vector<4xbf16>
101+
}
102+
103+
// -----
104+
73105
func.func @fmul_tensor(%arg: tensor<4xf32>) -> tensor<4xf32> {
74106
// expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
75107
%0 = spirv.FMul %arg, %arg : tensor<4xf32>
@@ -90,6 +122,14 @@ func.func @fnegate_scalar(%arg: f32) -> f32 {
90122

91123
// -----
92124

125+
func.func @fnegate_bf16_scalar(%arg: bf16) -> bf16 {
126+
// expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
127+
%0 = spirv.FNegate %arg : bf16
128+
return %0 : bf16
129+
}
130+
131+
// -----
132+
93133
//===----------------------------------------------------------------------===//
94134
// spirv.FRem
95135
//===----------------------------------------------------------------------===//
@@ -102,6 +142,14 @@ func.func @frem_scalar(%arg: f32) -> f32 {
102142

103143
// -----
104144

145+
func.func @frem_bf16_scalar(%arg: bf16) -> bf16 {
146+
// expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
147+
%0 = spirv.FRem %arg, %arg : bf16
148+
return %0 : bf16
149+
}
150+
151+
// -----
152+
105153
//===----------------------------------------------------------------------===//
106154
// spirv.FSub
107155
//===----------------------------------------------------------------------===//
@@ -114,6 +162,14 @@ func.func @fsub_scalar(%arg: f32) -> f32 {
114162

115163
// -----
116164

165+
func.func @fsub_bf16_scalar(%arg: bf16) -> bf16 {
166+
// expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
167+
%0 = spirv.FSub %arg, %arg : bf16
168+
return %0 : bf16
169+
}
170+
171+
// -----
172+
117173
//===----------------------------------------------------------------------===//
118174
// spirv.IAdd
119175
//===----------------------------------------------------------------------===//
@@ -489,3 +545,11 @@ func.func @vector_times_scalar(%vector: vector<4xf32>, %scalar: f32) -> vector<3
489545
%0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xf32>, f32) -> vector<3xf32>
490546
return %0 : vector<3xf32>
491547
}
548+
549+
// -----
550+
551+
func.func @vector_bf16_times_scalar_bf16(%vector: vector<4xbf16>, %scalar: bf16) -> vector<4xbf16> {
552+
// expected-error @+1 {{op operand #0 must be vector of 16/32/64-bit float values of length 2/3/4}}
553+
%0 = spirv.VectorTimesScalar %vector, %scalar : (vector<4xbf16>, bf16) -> vector<4xbf16>
554+
return %0 : vector<4xbf16>
555+
}

0 commit comments

Comments
 (0)