@@ -327,5 +327,70 @@ gpu.func @dpas_no_sg_data(%a: memref<24x32xf32>, %b: memref<32x24xf32>) {
327
327
xegpu.store_nd %d , %1 : vector <256 xf32 >, !xegpu.tensor_desc <256 xf32 , #xegpu.layout <sg_layout = [16 ], sg_data = [16 ]>>
328
328
gpu.return
329
329
}
330
- }
331
330
331
// Two sibling scf.if regions, each tagged with its own `sg_id_range`
// attribute. The FileCheck expectations sit inside the region they describe:
// the region whose range starts at 0 expects no `index.sub`, and the region
// whose range starts at 2 expects an `index.sub` against the constant 2.
// CHECK-LABEL: @subgroup_id_range
gpu.func @subgroup_id_range(%src: memref<256x128xf32>, %src1: memref<128x256xf32>, %src2: memref<128x64xf32>) {
  %sg_id = gpu.subgroup_id : index
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %c31 = arith.constant 31 : index
  // Guard of the first region: 0 <= sg_id < 1.
  %cond1 = arith.cmpi sge, %sg_id, %c0 : index
  %cond2 = arith.cmpi slt, %sg_id, %c1 : index
  %cond = arith.andi %cond1, %cond2 : i1
  scf.if %cond {
    // The range attribute on this region starts at 0, so no subtraction of a
    // range start is expected here.
    // NOTE(review): FileCheck enforces a negative pattern only between the
    // previous positive match and the next one; confirm that span really
    // covers the ops produced for this region.
    // CHECK-NOT: index.sub
    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
    %load = xegpu.load_nd %tdesc
      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
      -> vector<256x128xf32>
  } {sg_id_range = #xegpu.range<[0, 32]>}
  // Guard of the second region: 2 <= sg_id < 31.
  %cond3 = arith.cmpi sge, %sg_id, %c2 : index
  %cond4 = arith.cmpi slt, %sg_id, %c31 : index
  %cond5 = arith.andi %cond3, %cond4 : i1
  scf.if %cond5 {
    // The range attribute on this region starts at 2; the expected output
    // contains a subgroup id, a constant 2 and an `index.sub` using that
    // constant as its second operand.
    // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
    // CHECK: %[[C2:.*]] = arith.constant 2 : index
    // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C2]]
    %tdesc = xegpu.create_nd_tdesc %src2[0, 0] : memref<128x64xf32>
      -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
    %load = xegpu.load_nd %tdesc
      : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
      -> vector<128x64xf32>
    %exp = math.exp %load {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
  } {sg_id_range = #xegpu.range<[2, 18]>}
  gpu.return
}
366
+
367
// Nested-region variant: the `sg_id_range` attribute is attached to the outer
// scf.if, while the layout-carrying ops under test live in the inner scf.if.
// The expected output contains an `index.sub` against the constant 3.
// CHECK-LABEL: @subgroup_id_range_nested_if
gpu.func @subgroup_id_range_nested_if(%src: memref<256x128xf32>, %src1: memref<128x64xf32>) {
  %sgid = gpu.subgroup_id : index
  %true = arith.constant 1 : i1
  %lo = arith.constant 3 : index
  %hi = arith.constant 32 : index
  // Descriptor and load placed outside of any range-tagged region.
  %outer_desc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
    -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
  %outer_ld = xegpu.load_nd %outer_desc
    : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32], lane_layout = [8, 4], lane_data = [1, 1]>>
    -> vector<256x128xf32>
  // Inner guard: 3 <= sgid < 32.
  %ge_lo = arith.cmpi sge, %sgid, %lo : index
  %lt_hi = arith.cmpi slt, %sgid, %hi : index
  %in_range = arith.andi %ge_lo, %lt_hi : i1
  // The outer guard is the constant-true i1; it carries the range attribute.
  scf.if %true {
    scf.if %in_range {
      // CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
      // CHECK: %[[C3:.*]] = arith.constant 3 : index
      // CHECK: %[[SUB:.*]] = index.sub %{{.*}}, %[[C3]]
      %inner_desc = xegpu.create_nd_tdesc %src1[0, 0] : memref<128x64xf32>
        -> !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
      %inner_ld = xegpu.load_nd %inner_desc
        : !xegpu.tensor_desc<128x64xf32, #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>>
        -> vector<128x64xf32>
      %e = math.exp %inner_ld {layout_result_0 = #xegpu.layout<sg_layout = [4, 4], sg_data = [32, 16], lane_layout = [8, 4], lane_data = [1, 1]>} : vector<128x64xf32>
    }
  } {sg_id_range = #xegpu.range<[3, 19]>}
  gpu.return
}
396
+ }
0 commit comments