Skip to content

Commit c0fee39

Browse files
committed
Add more error checking tests and select/constant tests
1 parent f2ba086 commit c0fee39

File tree

6 files changed

+79
-0
lines changed

6 files changed

+79
-0
lines changed

mlir/test/Dialect/SPIRV/IR/atomic-ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,3 +272,11 @@ func.func @atomic_fadd(%ptr : !spirv.ptr<f32, StorageBuffer>, %value : f32) -> f
272272
%0 = spirv.EXT.AtomicFAdd <Device> <Acquire|Release> %ptr, %value : !spirv.ptr<f32, StorageBuffer>
273273
return %0 : f32
274274
}
275+
276+
// -----
277+
278+
func.func @atomic_bf16_fadd(%ptr : !spirv.ptr<bf16, StorageBuffer>, %value : bf16) -> bf16 {
279+
// expected-error @+1 {{op operand #1 must be 16/32/64-bit float, but got 'bf16'}}
280+
%0 = spirv.EXT.AtomicFAdd <Device> <None> %ptr, %value : !spirv.ptr<bf16, StorageBuffer>
281+
return %0 : bf16
282+
}

mlir/test/Dialect/SPIRV/IR/composite-ops.mlir

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,13 @@ func.func @composite_construct_vector(%arg0: f32, %arg1: f32, %arg2 : f32) -> ve
1111
return %0: vector<3xf32>
1212
}
1313

14+
// CHECK-LABEL: func @composite_construct_bf16_vector
15+
func.func @composite_construct_bf16_vector(%arg0: bf16, %arg1: bf16, %arg2 : bf16) -> vector<3xbf16> {
16+
// CHECK: spirv.CompositeConstruct {{%.*}}, {{%.*}}, {{%.*}} : (bf16, bf16, bf16) -> vector<3xbf16>
17+
%0 = spirv.CompositeConstruct %arg0, %arg1, %arg2 : (bf16, bf16, bf16) -> vector<3xbf16>
18+
return %0: vector<3xbf16>
19+
}
20+
1421
// CHECK-LABEL: func @composite_construct_struct
1522
func.func @composite_construct_struct(%arg0: vector<3xf32>, %arg1: !spirv.array<4xf32>, %arg2 : !spirv.struct<(f32)>) -> !spirv.struct<(vector<3xf32>, !spirv.array<4xf32>, !spirv.struct<(f32)>)> {
1623
// CHECK: spirv.CompositeConstruct

mlir/test/Dialect/SPIRV/IR/gl-ops.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,14 @@ func.func @exp(%arg0 : i32) -> () {
5050

5151
// -----
5252

53+
func.func @exp_bf16(%arg0 : bf16) -> () {
54+
// expected-error @+1 {{op operand #0 must be 16/32-bit float or vector of 16/32-bit float values of length 2/3/4}}
55+
%2 = spirv.GL.Exp %arg0 : bf16
56+
return
57+
}
58+
59+
// -----
60+
5361
//===----------------------------------------------------------------------===//
5462
// spirv.GL.{F|S|U}{Max|Min}
5563
//===----------------------------------------------------------------------===//
@@ -92,6 +100,15 @@ func.func @iminmax(%arg0: i32, %arg1: i32) {
92100

93101
// -----
94102

103+
func.func @fmaxminbf16vec(%arg0 : vector<3xbf16>, %arg1 : vector<3xbf16>) {
104+
// expected-error @+1 {{operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
105+
%1 = spirv.GL.FMax %arg0, %arg1 : vector<3xbf16>
106+
%2 = spirv.GL.FMin %arg0, %arg1 : vector<3xbf16>
107+
return
108+
}
109+
110+
// -----
111+
95112
//===----------------------------------------------------------------------===//
96113
// spirv.GL.InverseSqrt
97114
//===----------------------------------------------------------------------===//

mlir/test/Dialect/SPIRV/IR/logical-ops.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,14 @@ func.func @select_op_float(%arg0: i1) -> () {
201201
return
202202
}
203203

204+
func.func @select_op_bfloat16(%arg0: i1) -> () {
205+
%0 = spirv.Constant 2.0 : bf16
206+
%1 = spirv.Constant 3.0 : bf16
207+
// CHECK: spirv.Select {{%.*}}, {{%.*}}, {{%.*}} : i1, bf16
208+
%2 = spirv.Select %arg0, %0, %1 : i1, bf16
209+
return
210+
}
211+
204212
func.func @select_op_ptr(%arg0: i1) -> () {
205213
%0 = spirv.Variable : !spirv.ptr<f32, Function>
206214
%1 = spirv.Variable : !spirv.ptr<f32, Function>

mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,14 @@ func.func @group_non_uniform_fmul_clustered_reduce(%val: vector<2xf32>) -> vecto
184184

185185
// -----
186186

187+
func.func @group_non_uniform_bf16_fmul_reduce(%val: bf16) -> bf16 {
188+
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
189+
%0 = spirv.GroupNonUniformFMul <Workgroup> <Reduce> %val : bf16 -> bf16
190+
return %0: bf16
191+
}
192+
193+
// -----
194+
187195
//===----------------------------------------------------------------------===//
188196
// spirv.GroupNonUniformFMax
189197
//===----------------------------------------------------------------------===//
@@ -197,6 +205,14 @@ func.func @group_non_uniform_fmax_reduce(%val: f32) -> f32 {
197205

198206
// -----
199207

208+
func.func @group_non_uniform_bf16_fmax_reduce(%val: bf16) -> bf16 {
209+
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'bf16'}}
210+
%0 = spirv.GroupNonUniformFMax <Workgroup> <Reduce> %val : bf16 -> bf16
211+
return %0: bf16
212+
}
213+
214+
// -----
215+
200216
//===----------------------------------------------------------------------===//
201217
// spirv.GroupNonUniformFMin
202218
//===----------------------------------------------------------------------===//

mlir/test/Target/SPIRV/logical-ops.mlir

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,3 +108,26 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
108108
spirv.Return
109109
}
110110
}
111+
112+
// -----
113+
114+
// Test select works with bf16 scalar and vectors.
115+
116+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
117+
spirv.SpecConstant @condition_scalar = true
118+
spirv.func @select_bf16() -> () "None" {
119+
%0 = spirv.Constant 4.0 : bf16
120+
%1 = spirv.Constant 5.0 : bf16
121+
%2 = spirv.mlir.referenceof @condition_scalar : i1
122+
// CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : i1, bf16
123+
%3 = spirv.Select %2, %0, %1 : i1, bf16
124+
%4 = spirv.Constant dense<[2.0, 3.0, 4.0, 5.0]> : vector<4xbf16>
125+
%5 = spirv.Constant dense<[6.0, 7.0, 8.0, 9.0]> : vector<4xbf16>
126+
// CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : i1, vector<4xbf16>
127+
%6 = spirv.Select %2, %4, %5 : i1, vector<4xbf16>
128+
%7 = spirv.Constant dense<[true, true, true, true]> : vector<4xi1>
129+
// CHECK: spirv.Select {{.*}}, {{.*}}, {{.*}} : vector<4xi1>, vector<4xbf16>
130+
%8 = spirv.Select %7, %4, %5 : vector<4xi1>, vector<4xbf16>
131+
spirv.Return
132+
}
133+
}

0 commit comments

Comments
 (0)