@@ -296,5 +296,88 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
296296 gpu.return
297297 }
298298
// Three warp-specialized regions, each guarded by an sg_id range check and
// annotated with xegpu.sg_id_range. The distribution pass keys off the
// attribute: a range starting at 0 needs no sg_id adjustment, while a range
// starting at k must subtract k from the subgroup id (index.sub).
// CHECK-LABEL: @warp_specialized
gpu.func @warp_specialized(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
  %sg_id = gpu.subgroup_id : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  // Upper bound matches the [2, 32) range declared by the third region's
  // xegpu.sg_id_range attribute (was 31, inconsistent with the attribute).
  %c32 = arith.constant 32 : index
  // Region 1: sg_id in [0, 1). Range starts at 0, so no index.sub expected.
  %cond1 = arith.cmpi sge, %sg_id, %c0 : index
  %cond2 = arith.cmpi slt, %sg_id, %c1 : index
  %cond = arith.andi %cond1, %cond2 : i1
  scf.if %cond {
    // CHECK-NOT: index.sub
    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
    %load = xegpu.load_nd %tdesc
      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
      -> vector<256x128xf32>
  } {xegpu.sg_id_range = array<i32: 0, 1>}
  // Region 2: sg_id in [1, 2). Range starts at 1, so sg_id is rebased by 1.
  %cond3 = arith.cmpi sge, %sg_id, %c1 : index
  %cond4 = arith.cmpi slt, %sg_id, %c2 : index
  %cond5 = arith.andi %cond3, %cond4 : i1
  scf.if %cond5 {
    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
    // CHECK: %[[C1:.*]] = arith.constant 1 : index
    // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C1]]
    %tdesc_a = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
    %load_a = xegpu.load_nd %tdesc_a
      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
      -> vector<256x128xf32>
    %tdesc_b = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x256xf32>
      -> !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
    %load_b = xegpu.load_nd %tdesc_b
      : !xegpu.tensor_desc<128x256xf32, #xegpu.layout<sg_layout = [4, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>>
      -> vector<128x256xf32>
    %dpas = xegpu.dpas %load_a, %load_b {layout_result_0 = #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], lane_layout = [4, 8], lane_data = [1, 1]>} : vector<256x128xf32>, vector<128x256xf32> -> vector<256x256xf32>
  } {xegpu.sg_id_range = array<i32: 1, 2>}
  // Region 3: sg_id in [2, 32). Range starts at 2, so sg_id is rebased by 2.
  %cond6 = arith.cmpi sge, %sg_id, %c2 : index
  %cond7 = arith.cmpi slt, %sg_id, %c32 : index
  %cond8 = arith.andi %cond6, %cond7 : i1
  scf.if %cond8 {
    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
    // CHECK: %[[C2:.*]] = arith.constant 2 : index
    // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
    %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
      -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
    %load = xegpu.load_nd %tdesc
      : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
      -> vector<128x64xf32>
    %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
  } {xegpu.sg_id_range = array<i32: 2, 32>}
  gpu.return
}
299353
// Verifies that an xegpu.sg_id_range attribute attached to an OUTER scf.if
// still drives the sg_id rebase (index.sub by the range's lower bound, 3)
// for ops nested inside an inner scf.if.
// CHECK-LABEL: @subgroup_id_range_nested_if
gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
  %sg_id = gpu.subgroup_id : index
  %always = arith.constant 1 : i1
  %lb = arith.constant 3 : index
  %ub = arith.constant 32 : index
  // Unguarded load outside any range-annotated region.
  %desc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
    -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
  %val = xegpu.load_nd %desc
    : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
    -> vector<256x128xf32>
  %ge_lb = arith.cmpi sge, %sg_id, %lb : index
  %lt_ub = arith.cmpi slt, %sg_id, %ub : index
  %in_range = arith.andi %ge_lb, %lt_ub : i1
  scf.if %always {
    scf.if %in_range {
      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
      // CHECK: %[[C3:.*]] = arith.constant 3 : index
      // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
      %td = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32>
        -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
      %ld = xegpu.load_nd %td
        : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
        -> vector<128x64xf32>
      %exp = math.exp %ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
    }
  } {xegpu.sg_id_range = array<i32: 3, 8>}
  gpu.return
}
300383}
0 commit comments