@@ -206,7 +206,6 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
206206 gpu.return
207207 }
208208
209-
210209 gpu.func @test_scf_while_and_condition (%arg0: memref <1024 xf32 >, %arg1: memref <1024 xf32 >) {
211210 %c1_i32 = arith.constant 1 : i32
212211 %c10_i32 = arith.constant 10 : i32
@@ -232,5 +231,63 @@ gpu.func @test_dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
232231 gpu.return
233232 }
234233
// Checks that a wg-level scf.if yielding vector<256xf32> with
// sg_layout = [16], sg_data = [16] is distributed to a per-subgroup
// scf.if yielding vector<16xf32>, with the loads/yields in both
// branches rewritten to the 16xf32 subgroup shape.
gpu.func @test_scf_if(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
  %c10 = arith.constant 10 : index
  %id = gpu.subgroup_id : index

  %0 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
  %1 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>

  %4 = arith.cmpi eq, %id, %c10 : index
  // CHECK-LABEL: scf.if
  //  CHECK-SAME: (vector<16xf32>)
  %5 = scf.if %4 -> (vector<256xf32>) {
    // CHECK-LABEL: xegpu.load_nd
    //  CHECK-SAME: !xegpu.tensor_desc<16xf32> -> vector<16xf32>
    %2 = xegpu.load_nd %0 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
    // CHECK-LABEL: scf.yield
    //  CHECK-SAME: vector<16xf32>
    scf.yield %2 : vector<256xf32>
  } else {
    // CHECK-LABEL: xegpu.load_nd
    //  CHECK-SAME: !xegpu.tensor_desc<16xf32> -> vector<16xf32>
    %3 = xegpu.load_nd %1 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>
    // CHECK-LABEL: scf.yield
    //  CHECK-SAME: vector<16xf32>
    scf.yield %3 : vector<256xf32>
  } {layout_result_0 = #xegpu.layout<sg_layout = [16], sg_data = [16]>}
  xegpu.store_nd %5, %0 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
  gpu.return
}
262+
// Checks that a wg-level scf.if yielding a tensor_desc<256xf32> with
// sg_layout = [16], sg_data = [16] is distributed to a per-subgroup
// scf.if yielding !xegpu.tensor_desc<16xf32>, with the create_nd_tdesc
// ops in both branches rewritten to the subgroup descriptor shape.
// NOTE: unlike @test_scf_if, this scf.if carries no layout_result_0
// attribute — the layout is taken from the yielded tensor_desc type.
gpu.func @test_scf_if_tensor_desc(%arg0: memref<1024xf32>, %arg1: memref<1024xf32>) {
  %c10 = arith.constant 10 : index
  %id = gpu.subgroup_id : index

  %t = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
  %d = xegpu.load_nd %t : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>> -> vector<256xf32>

  %0 = arith.cmpi eq, %id, %c10 : index
  // CHECK-LABEL: scf.if
  //  CHECK-SAME: !xegpu.tensor_desc<16xf32>
  %1 = scf.if %0 -> (!xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>) {
    // CHECK-LABEL: xegpu.create_nd_tdesc
    //  CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
    %2 = xegpu.create_nd_tdesc %arg0[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
    // CHECK-LABEL: scf.yield
    //  CHECK-SAME: !xegpu.tensor_desc<16xf32>
    scf.yield %2 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
  } else {
    // CHECK-LABEL: xegpu.create_nd_tdesc
    //  CHECK-SAME: memref<1024xf32> -> !xegpu.tensor_desc<16xf32>
    %3 = xegpu.create_nd_tdesc %arg1[0] : memref<1024xf32> -> !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
    // CHECK-LABEL: scf.yield
    //  CHECK-SAME: !xegpu.tensor_desc<16xf32>
    scf.yield %3 : !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
  }
  xegpu.store_nd %d, %1 : vector<256xf32>, !xegpu.tensor_desc<256xf32, #xegpu.layout<sg_layout = [16], sg_data = [16]>>
  gpu.return
}
291+
235292
236293}
0 commit comments