@@ -327,5 +327,70 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
327327 xegpu.store_nd %d , %1 : vector <256 xf32 >, !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [16 ], sg_data = [16 ]>>
328328 gpu.return
329329 }
330- }
331330
331+ // CHECK-LABEL: @subgroup_id_range
332+ gpu.func @subgroup_id_range (%src: memref <256 x128 xf32 >, %src1: memref <128 x256 xf32 >, %src2: memref <128 x64 xf32 >) { // Test input: two sibling scf.if ops, each tagged with an sg_id_range attribute and guarding xegpu ops that carry sg_layout/sg_data layouts. NOTE(review): %src1 is never referenced in this function.
333+ %sg_id = gpu.subgroup_id : index // value compared against the constant bounds below
334+ %c0 = arith.constant 0 : index
335+ %c1 = arith.constant 1 : index
336+ %c2 = arith.constant 2 : index
337+ %c31 = arith.constant 31 : index
338+ %c3 = arith.constant 3 : index // NOTE(review): %c3 is not referenced anywhere in this function
339+ %cond1 = arith.cmpi sge , %sg_id , %c0 : index
340+ %cond2 = arith.cmpi slt , %sg_id , %c1 : index
341+ %cond = arith.andi %cond1 , %cond2 : i1 // true when 0 <= sg_id < 1
342+ scf.if %cond { // first region: its attribute range starts at 0, and the directive below expects no index.sub
343+ // CHECK-NOT: index.sub
344+ %tdesc = xegpu.create_nd_tdesc %src [0 , 0 ] : memref <256 x128 xf32 >
345+ -> !xegpu.tensor_desc <256 x128 xf32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [32 , 32 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
346+ %load = xegpu.load_nd %tdesc
347+ : !xegpu.tensor_desc <256 x128 xf32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [32 , 32 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
348+ -> vector <256 x128 xf32 >
349+ } {sg_id_range = #xegpu.range <[0 , 32 ]>} // 32 - 0 equals the 8 * 4 entries of sg_layout above. NOTE(review): the guard admits only sg_id 0 while the attribute says [0, 32] - confirm the mismatch is intentional
350+ %cond3 = arith.cmpi sge , %sg_id , %c2 : index
351+ %cond4 = arith.cmpi slt , %sg_id , %c31 : index
352+ %cond5 = arith.andi %cond3 , %cond4 : i1 // true when 2 <= sg_id < 31
353+ scf.if %cond5 { // second region: its attribute range starts at 2, and the directives below expect an index.sub whose second operand is the constant 2
354+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
355+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
356+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
357+ %tdesc = xegpu.create_nd_tdesc %src2 [0 , 0 ] : memref <128 x64 xf32 >
358+ -> !xegpu.tensor_desc <128 x64 xf32 , #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
359+ %load = xegpu.load_nd %tdesc
360+ : !xegpu.tensor_desc <128 x64 xf32 , #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
361+ -> vector <128 x64 xf32 >
362+ %exp = math.exp %load {layout_result_0 = #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>} : vector <128 x64 xf32 >
363+ }{sg_id_range = #xegpu.range <[2 , 18 ]>} // 18 - 2 equals the 4 * 4 entries of sg_layout above. NOTE(review): the guard upper bound (31) differs from the attribute upper bound (18) - confirm intentional
364+ gpu.return
365+ }
366+
367+ // CHECK-LABEL: @subgroup_id_range_nested_if
368+ gpu.func @subgroup_id_range_nested_if (%src: memref <256 x128 xf32 >, %src1: memref <128 x64 xf32 >) { // Test input: the sg_id_range attribute sits on an outer scf.if while the xegpu ops under test live in an scf.if nested inside it
369+ %sg_id = gpu.subgroup_id : index
370+ %c1 = arith.constant 1 : i1 // i1 constant 1; used as the condition of the outer scf.if
371+ %c3 = arith.constant 3 : index
372+ %c32 = arith.constant 32 : index
373+ %tdesc = xegpu.create_nd_tdesc %src [0 , 0 ] : memref <256 x128 xf32 >
374+ -> !xegpu.tensor_desc <256 x128 xf32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [32 , 32 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
375+ %load = xegpu.load_nd %tdesc
376+ : !xegpu.tensor_desc <256 x128 xf32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [32 , 32 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
377+ -> vector <256 x128 xf32 > // this descriptor/load pair is outside both scf.if ops; NOTE(review): %load is not used afterwards
378+ %cond1 = arith.cmpi sge , %sg_id , %c3 : index
379+ %cond2 = arith.cmpi slt , %sg_id , %c32 : index
380+ %cond = arith.andi %cond1 , %cond2 : i1 // true when 3 <= sg_id < 32
381+ scf.if %c1 { // outer if: carries the sg_id_range attribute on its closing brace
382+ scf.if %cond { // inner if: has no sg_id_range attribute of its own; the directives below still expect an index.sub by the constant 3
383+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
384+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
385+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
386+ %td = xegpu.create_nd_tdesc %src1 [0 , 0 ] : memref <128 x64 xf32 >
387+ -> !xegpu.tensor_desc <128 x64 xf32 , #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
388+ %ld = xegpu.load_nd %td
389+ : !xegpu.tensor_desc <128 x64 xf32 , #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
390+ -> vector <128 x64 xf32 >
391+ %exp = math.exp %ld {layout_result_0 = #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>} : vector <128 x64 xf32 >
392+ }
393+ } {sg_id_range = #xegpu.range <[3 , 19 ]>} // 19 - 3 equals the 4 * 4 entries of the inner ops' sg_layout. NOTE(review): the guard upper bound (32) differs from the attribute upper bound (19) - confirm intentional
394+ gpu.return
395+ }
396+ }
0 commit comments