@@ -242,4 +242,70 @@ gpu.module @test_distribution {
242242 xegpu.store_nd %8#2 , %2 [%0 , %1 ] : vector <128 x128 xf32 >, !xegpu.tensor_desc <128 x128 xf32 , #xegpu.layout <sg_layout = [8 , 8 ], sg_data = [16 , 16 ]>>
243243 gpu.return
244244 }
245+
246+ // CHECK-LABEL: @subgroup_id_range
247+ gpu.func @subgroup_id_range (%src: memref <256 x128 xf32 >, %src1: memref <128 x256 xf32 >, %src2: memref <128 x64 xf32 >) {
248+ %sg_id = gpu.subgroup_id : index
249+ %c0 = arith.constant 0 : index
250+ %c1 = arith.constant 1 : index
251+ %c2 = arith.constant 2 : index
252+ %c31 = arith.constant 31 : index
253+ %c3 = arith.constant 3 : index
254+ %cond1 = arith.cmpi sge , %sg_id , %c0 : index
255+ %cond2 = arith.cmpi slt , %sg_id , %c1 : index
256+ %cond = arith.andi %cond1 , %cond2 : i1
257+ scf.if %cond {
258+ // CHECK-NOT: index.sub
259+ %tdesc = xegpu.create_nd_tdesc %src : memref <256 x128 xf32 >
260+ -> !xegpu.tensor_desc <256 x128 xf32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [32 , 32 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
261+ %load = xegpu.load_nd %tdesc [0 , 0 ]
262+ : !xegpu.tensor_desc <256 x128 xf32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [32 , 32 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
263+ -> vector <256 x128 xf32 >
264+ } {sg_id_range = #xegpu.range <[0 , 32 ]>}
265+ %cond3 = arith.cmpi sge , %sg_id , %c2 : index
266+ %cond4 = arith.cmpi slt , %sg_id , %c31 : index
267+ %cond5 = arith.andi %cond3 , %cond4 : i1
268+ scf.if %cond5 {
269+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
270+ // CHECK: %[[C2:.*]] = arith.constant 2 : index
271+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
272+ %tdesc = xegpu.create_nd_tdesc %src2 : memref <128 x64 xf32 >
273+ -> !xegpu.tensor_desc <128 x64 xf32 , #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
274+ %load = xegpu.load_nd %tdesc [0 , 0 ]
275+ : !xegpu.tensor_desc <128 x64 xf32 , #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
276+ -> vector <128 x64 xf32 >
277+ %exp = math.exp %load {layout_result_0 = #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>} : vector <128 x64 xf32 >
278+ }{sg_id_range = #xegpu.range <[2 , 18 ]>}
279+ gpu.return
280+ }
281+
282+ // CHECK-LABEL: @subgroup_id_range_nested_if
283+ gpu.func @subgroup_id_range_nested_if (%src: memref <256 x128 xf32 >, %src1: memref <128 x64 xf32 >) {
284+ %sg_id = gpu.subgroup_id : index
285+ %c1 = arith.constant 1 : i1
286+ %c3 = arith.constant 3 : index
287+ %c32 = arith.constant 32 : index
288+ %tdesc = xegpu.create_nd_tdesc %src : memref <256 x128 xf32 >
289+ -> !xegpu.tensor_desc <256 x128 xf32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [32 , 32 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
290+ %load = xegpu.load_nd %tdesc [0 , 0 ]
291+ : !xegpu.tensor_desc <256 x128 xf32 , #xegpu.layout <sg_layout = [8 , 4 ], sg_data = [32 , 32 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
292+ -> vector <256 x128 xf32 >
293+ %cond1 = arith.cmpi sge , %sg_id , %c3 : index
294+ %cond2 = arith.cmpi slt , %sg_id , %c32 : index
295+ %cond = arith.andi %cond1 , %cond2 : i1
296+ scf.if %c1 {
297+ scf.if %cond {
298+ // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
299+ // CHECK: %[[C3:.*]] = arith.constant 3 : index
300+ // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
301+ %td = xegpu.create_nd_tdesc %src1 : memref <128 x64 xf32 >
302+ -> !xegpu.tensor_desc <128 x64 xf32 , #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
303+ %ld = xegpu.load_nd %td [0 , 0 ]
304+ : !xegpu.tensor_desc <128 x64 xf32 , #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>>
305+ -> vector <128 x64 xf32 >
306+ %exp = math.exp %ld {layout_result_0 = #xegpu.layout <sg_layout = [4 , 4 ], sg_data = [32 , 16 ], lane_layout = [8 , 4 ], lane_data = [1 , 1 ]>} : vector <128 x64 xf32 >
307+ }
308+ } {sg_id_range = #xegpu.range <[3 , 19 ]>}
309+ gpu.return
310+ }
245311}
0 commit comments