Skip to content

Commit 1ca7b30

Browse files
committed
add unit tests for scf.if
1 parent 689bb05 commit 1ca7b30

File tree

1 file changed

+58
-1
lines changed

1 file changed

+58
-1
lines changed

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir

Lines changed: 58 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,6 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
206206
gpu.return
207207
}
208208

209-
210209
gpu.func @test_scf_while_and_condition(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
211210
%c1_i32 = arith.constant 1 : i32
212211
%c10_i32 = arith.constant 10 : i32
@@ -232,5 +231,63 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
232231
gpu.return
233232
}
234233

234+
gpu.func @test_scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
235+
%c10 = arith.constant 10 : index
236+
%id = gpu.subgroup_id : index
237+
238+
%0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
239+
%1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
240+
241+
%4 = arith.cmpi eq, %id, %c10 : index
242+
// CHECK-LABEL: scf.if
243+
// CHECK-SAME: (vector<16xf32>)
244+
%5 = scf.if %4 -> (vector<256xf32>) {
245+
// CHECK-LABEL: xegpu.load_nd
246+
// CHECK-SAME: !xegpu.tensor_desc<16xf32> -> vector<16xf32>
247+
%2 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
248+
// CHECK-LABEL: scf.yield
249+
// CHECK-SAME: vector<16xf32>
250+
scf.yield %2 : vector<256xf32>
251+
} else {
252+
// CHECK-LABEL: xegpu.load_nd
253+
// CHECK-SAME: !xegpu.tensor_desc<16xf32> -> vector<16xf32>
254+
%3 = xegpu.load_nd %1 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
255+
// CHECK-LABEL: scf.yield
256+
// CHECK-SAME: vector<16xf32>
257+
scf.yield %3 : vector<256xf32>
258+
} {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [16]>}
259+
xegpu.store_nd %5, %0 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
260+
gpu.return
261+
}
262+
263+
gpu.func @test_scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
264+
%c10 = arith.constant 10 : index
265+
%id = gpu.subgroup_id : index
266+
267+
%t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
268+
%d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
269+
270+
%0 = arith.cmpi eq, %id, %c10 : index
271+
// CHECK-LABEL: scf.if
272+
// CHECK-SAME: !xegpu.tensor_desc<16xf32>
273+
%1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>) {
274+
// CHECK-LABEL: xegpu.create_nd_tdesc
275+
// CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
276+
%2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
277+
// CHECK-LABEL: scf.yield
278+
// CHECK-SAME: !xegpu.tensor_desc<16xf32>
279+
scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
280+
} else {
281+
// CHECK-LABEL: xegpu.create_nd_tdesc
282+
// CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
283+
%3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
284+
// CHECK-LABEL: scf.yield
285+
// CHECK-SAME: !xegpu.tensor_desc<16xf32>
286+
scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
287+
}
288+
xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
289+
gpu.return
290+
}
291+
235292

236293
}

0 commit comments

Comments
 (0)