
Commit e5ce030

[mlir][xegpu] Improve XeGPU op verification logic for SIMT flavor and update tests. (#127920)

This PR adds the changes required for XeGPU ops to support SIMT distribution:
1. Adds verification logic for the SIMT flavor of the load_nd, store_nd, dpas, load_gather, and store_scatter ops.
2. Adds test cases covering the SIMT versions of these ops along with their VC counterparts.

Co-authored-by: Artem Kroviakov <[email protected]>
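The verification added here boils down to a shape rule: in SIMT mode, the per-work-item (distributed) vector shape must equal the tensor descriptor shape divided element-wise by the `sg_map` work-item layout. A minimal sketch of that rule in Python — the function name and error handling are illustrative assumptions, not the actual verifier code:

```python
def distributed_shape(tdesc_shape, wi_layout, wi_data):
    """Shape of the fragment owned by one work-item under an sg_map.

    Each descriptor dimension is split evenly across wi_layout work-items;
    every work-item keeps at least wi_data elements in that dimension.
    """
    assert len(tdesc_shape) == len(wi_layout) == len(wi_data)
    out = []
    for dim, layout, data in zip(tdesc_shape, wi_layout, wi_data):
        # The descriptor dimension must be divisible by the number of
        # work-items laid out along it.
        assert dim % layout == 0, "sg_map does not evenly divide the tensor"
        out.append(max(dim // layout, data))
    return out

# load_nd example from the docs below: tensor_desc<8x16xf32> with
# wi_layout = [1, 16], wi_data = [1, 1] -> each lane loads vector<8x1xf32>
print(distributed_shape([8, 16], [1, 16], [1, 1]))  # -> [8, 1]
```

This matches the `vector<8x1xf32>` / `vector<8x1xf16>` result and input types that appear in the SIMT examples throughout the diff.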
1 parent eabe2eb commit e5ce030

File tree

7 files changed: +1109 -535 lines


mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 113 additions & 15 deletions
@@ -80,6 +80,9 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     information e.g., memref<?x?xf16>, the strides information has to be explicitly
     passed via the "strides" and "const_strides" argument.
 
+    In SIMT mode, the tensor descriptor is augmented with `SGMapAttr`, which describes
+    the mapping of the tensor descriptor to the work items.
+
     Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
     ```mlir
     %0 = memref.alloc() : memref<1024x1024xf32>
@@ -103,6 +106,15 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     %c1 = arith.constant 1 : index
     %1 = xegpu.create_nd_tdesc %0[%c0, %c0], [%h, %w], [%w, %c1]: ui64 -> TensorDesc<8x16xf32>
     ```
+
+    Example 4 (SIMT mode):
+    ```mlir
+    %0 = memref.alloc() : memref<1024x1024xf32>
+    %c0 = arith.constant 0 : index
+    %c1 = arith.constant 8 : index
+    %1 = xegpu.create_nd_tdesc %0[%c0, %c0] : memref<1024x1024xf32>
+      -> !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    ```
   }];
 
   let arguments = (ins
@@ -294,14 +306,25 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
     fp32 or fp64. It implies that vnni and transpose cannot exist at the
     same time.
 
-    Example:
+    In SIMT mode, LoadNdOp expects the tensor descriptor to be augmented with `SGMapAttr`,
+    which describes the mapping of the tensor to the work items. In this case, the result
+    vector represents the data to be loaded by each work-item.
+
+    Example 1:
     ```mlir
     xegpu.load_nd %1 {transpose = [1, 0],
                       l1_hint = #xegpu.cache_hint<cached>,
                       l2_hint = #xegpu.cache_hint<uncached>,
                       l3_hint = #xegpu.cache_hint<streaming>}
             : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>
     ```
+    Example 2 (SIMT mode):
+    ```mlir
+    xegpu.load_nd %1 {l1_hint = #xegpu.cache_hint<cached>,
+                      l2_hint = #xegpu.cache_hint<uncached>}
+            : !xegpu.tensor_desc<8x16xf32,
+                #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>> -> vector<8x1xf32>
+    ```
 
   }];
@@ -341,13 +364,25 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
     of cache: L1, L2, and L3. If the hardware does not have a corresponding cache,
     the corresponding cache hint attribute will be masked.
 
-    Example:
+    In SIMT mode, StoreNdOp expects the tensor descriptor to be augmented with `SGMapAttr`,
+    which describes the mapping of the tensor to the work items. In this case, the input
+    vector represents the data to be stored by each work-item.
+
+    Example 1:
     ```mlir
     xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                            l2_hint = #xegpu.cache_hint<write_back>,
                            l3_hint = #xegpu.cache_hint<write_through>}
           : vector<8x16xf16>, !xegpu.tensor_desc<8x16xf16>
     ```
+    Example 2 (SIMT mode):
+    ```mlir
+    xegpu.store_nd %3, %2 {l1_hint = #xegpu.cache_hint<uncached>,
+                           l2_hint = #xegpu.cache_hint<write_back>,
+                           l3_hint = #xegpu.cache_hint<write_through>}
+          : vector<8x1xf16>, !xegpu.tensor_desc<8x16xf16,
+              #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    ```
 
   }];
@@ -380,10 +415,15 @@ def XeGPU_UpdateNdOffsetOp : XeGPU_Op<"update_nd_offset",
     The offsets are relative offsets to the current position in the number
     of elements. The result is a TensorDesc of the same type as the input.
 
-    example:
+    Example 1:
     ```
     %2 = xegpu.update_nd_offset %1, [0, 16]: !xegpu.tensor_desc<8x16xf32>
     ```
+    Example 2 (SIMT mode):
+    ```
+    %2 = xegpu.update_nd_offset %1, [0, 16]:
+      !xegpu.tensor_desc<8x16xf32, #xegpu.sg_map<wi_layout = [1, 16], wi_data = [1, 1]>>
+    ```
   }];
 
   let arguments = (ins
@@ -441,14 +481,19 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
     match the dimension of offsets. It may also have a second dimension corresponding to
     the chunk_size if the chunk size is larger than 1.
 
-    Example 1. It assumes subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
+    In SIMT mode, similar to `create_nd_tdesc`, the resulting tensor descriptor is augmented
+    with `SGMapAttr`, which describes the mapping of the tensor descriptor to the work items.
+    In this case, the first dimension of the tensor descriptor represents the work-items, and
+    the second dimension represents the chunk size.
+
+    Example 1: It assumes the subgroup size is 4, and accesses a[0], a[16], a[32], a[64]
     ```mlir
     %a = memref.alloc() : memref<1024xf32>
     %0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
     %1 = xegpu.create_tdesc %a, %0: memref<1024xf32>, vector<4xindex> -> TensorDesc<4xf32>
     ```
 
-    Example 2. It assumes subgroup size is 4, and each workitem access 8 elements.
+    Example 2: It assumes the subgroup size is 4, and each work-item accesses 8 elements.
     It will access a total of 32 data elements: a[0:7], a[16:23], a[32:39], a[64:71]
     ```mlir
     %0 = memref.alloc() : memref<1024xf32>
@@ -457,14 +502,23 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
       -> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr<chunk_size = 8>>
     ```
 
-    Example 3. It is similar to Example 2, but there is some overlaps among workitems.
+    Example 3: It is similar to Example 2, but there is some overlap among work-items.
     It accesses: a[0:7], a[4:11], a[8:15], a[12:19]
     ```mlir
     %0 = memref.alloc() : memref<1024xf32>
     %off = arith.constant dense<[0, 4, 8, 12]> : vector<4xindex>
     %1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
       -> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr<chunk_size = 8>>
     ```
+
+    Example 4 (SIMT mode):
+    ```mlir
+    %0 = memref.alloc() : memref<1024xf32>
+    %off = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
+    %1 = xegpu.create_tdesc %0, %off : memref<1024xf32>, vector<4xindex>
+      -> TensorDesc<4x8xf32, #xegpu.scattered_tdesc_attr<chunk_size = 8>,
+        #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>
+    ```
   }];
 
   let arguments = (ins XeGPU_BaseAddrType: $source,
@@ -569,6 +623,11 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
     The mask operand masks out memory access so that it is safe to pass out-of-boundary
     addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
 
+    In SIMT mode, LoadGatherOp expects the tensor descriptor to be augmented with `SGMapAttr`,
+    which describes the mapping of the tensor to the work items. In this case, the result vector
+    represents the data to be loaded by each work-item. Each work-item receives a `chunk_size`
+    number of elements.
+
     Example 1:
     ```mlir
     %2 = xegpu.load %1, %0 {l1_hint = #xegpu.cache_hint<cached>,
@@ -587,6 +646,16 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
           : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
             vector<16xi1> -> vector<8x16xf32>
     ```
+    Example 3 (SIMT mode):
+    ```mlir
+    %2 = xegpu.load %1, %0 {transpose,
+                            l1_hint = #xegpu.cache_hint<cached>,
+                            l2_hint = #xegpu.cache_hint<uncached>,
+                            l3_hint = #xegpu.cache_hint<uncached>}
+          : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>,
+              #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>,
+            vector<16xi1> -> vector<8x1xf32>
+    ```
 
   }];
 
@@ -608,8 +677,8 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
       return getElementTypeOrSelf(type);
     }
 
-    Type getValueType() {
-      return getValue().getType();
+    VectorType getValueType() {
+      return llvm::dyn_cast<VectorType>(getValue().getType());
     }
 
     Type getMaskType() {
@@ -635,22 +704,36 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
     has transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is
     introduced on purpose, making sure users are aware of this implicit transformation.
 
+    In SIMT mode, StoreScatterOp expects the tensor descriptor to be augmented with `SGMapAttr`,
+    which describes the mapping of the tensor to the work items. In this case, the input vector
+    represents the data to be stored by each work-item. Each work-item stores a `chunk_size`
+    number of elements.
+
     Example 1:
     ```mlir
-    %3 = xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
+    xegpu.store %0, %1, %2 {l1_hint = #xegpu.cache_hint<uncached>,
                             l2_hint = #xegpu.cache_hint<write_back>,
                             l3_hint = #xegpu.cache_hint<write_through>}
           : vector<16xf32>, !xegpu.tensor_desc<16xf32, #xegpu.scattered_tdesc_attr<>>, vector<16xi1>
     ```
 
     Example 2:
     ```mlir
-    %3 = xegpu.store %0, %1, %2 {transpose,
+    xegpu.store %0, %1, %2 {transpose,
                             l1_hint = #xegpu.cache_hint<uncached>,
                             l2_hint = #xegpu.cache_hint<write_back>,
                             l3_hint = #xegpu.cache_hint<write_through>}
           : vector<8x16xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
     ```
+    Example 3 (SIMT mode):
+    ```mlir
+    xegpu.store %0, %1, %2 {transpose,
+                            l1_hint = #xegpu.cache_hint<uncached>,
+                            l2_hint = #xegpu.cache_hint<write_back>,
+                            l3_hint = #xegpu.cache_hint<write_through>}
+          : vector<8x1xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>,
+              #xegpu.sg_map<wi_layout = [16, 1], wi_data = [1, 1]>>, vector<16xi1>
+    ```
 
   }];
 
@@ -668,8 +751,8 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
      return getTensorDesc().getType();
    }
 
-    Type getValueType() {
-      return getValue().getType();
+    VectorType getValueType() {
+      return llvm::dyn_cast<VectorType>(getValue().getType());
    }
 
    Type getMaskType() {
@@ -695,11 +778,19 @@ def XeGPU_UpdateOffsetOp: XeGPU_Op<"update_offset",
     update the offset per work-item, so its offsets contain values representing
     shifts for each work-item.
 
-    Example:
+    Example 1:
     ```mlir
     %off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
     %2 = xegpu.update_offset %1, %off :
-      !xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr<>>, vector<4xindex>
+      !xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr<chunk_size=2>>, vector<4xindex>
+    ```
+
+    Example 2 (SIMT mode):
+    ```mlir
+    %off = arith.constant dense<[32, 32, 32, 32]> : vector<4xindex>
+    %2 = xegpu.update_offset %1, %off :
+      !xegpu.tensor_desc<4x2xf32, #xegpu.scattered_tdesc_attr<chunk_size=2>,
+        #xegpu.sg_map<wi_layout = [4, 1], wi_data = [1, 1]>>, vector<4xindex>
     ```
   }];

@@ -749,6 +840,10 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
     factor, which is computed as `32/bit_width_of_elem_type`. Thus, `B: vector<16x16xf16>`
     can be represented as `B: vector<8x16x2xf16>`.
 
+    In SIMT mode, DpasOp expects the attributes `sg_map_a`, `sg_map_b`, and `sg_map_c`,
+    which describe the data fragment owned by each work-item w.r.t. the tensor
+    descriptor these data are loaded from.
+
     Note: on PVC, the hardware can perform load with VNNI transformation when data
     element type is 16-bit or lower precision, taking 2 or 4 elements from
     the first dimension and inserting them into the newly added innermost dimension.
@@ -757,7 +852,10 @@ def XeGPU_DpasOp : XeGPU_Op<"dpas", [Pure, AllElementTypesMatch<["lhs", "rhs"]>]
   let arguments = (ins
     XeGPU_DpasOpType : $lhs,
     XeGPU_DpasOpType : $rhs,
-    Optional<XeGPU_Vector2DType>: $acc);
+    Optional<XeGPU_Vector2DType>: $acc,
+    OptionalAttr<XeGPU_SGMapAttr>:$sg_map_a,
+    OptionalAttr<XeGPU_SGMapAttr>:$sg_map_b,
+    OptionalAttr<XeGPU_SGMapAttr>:$sg_map_c);
   let results = (outs XeGPU_Vector2DType: $result);
 
   let extraClassDeclaration = [{

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 6 additions & 1 deletion
@@ -103,7 +103,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
     CArg<"xegpu::MemorySpace", "xegpu::MemorySpace::Global">:$memory_space,
     CArg<"mlir::Attribute", "mlir::Attribute()">:$sg_map)>
   ];
-
+
   let extraClassDeclaration = [{
     using TensorType::clone;
     using mlir::ShapedType::Trait<TensorDescType>::getElementTypeBitWidth;
@@ -176,6 +176,11 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
         return scatter_attr.getChunkSize().getInt();
       return 1;
     }
+
+    // This returns a vector type that represents the fragment of data owned by
+    // a work item in SIMT mode if this tensor descriptor is used in a XeGPU
+    // load/store operation.
+    FailureOr<VectorType> getDistributedVectorType();
   }];
 
   let hasCustomAssemblyFormat = true;