Commit 1bf31c3
[MLIR][XeGPU] Update XeGPU create_tdesc, update_offset, load, store and prefetch. (#154653)
This PR tightens some loose ends in the XeGPU op definitions. The changes are backward compatible, except that:

- Offsets are now required for load/store/prefetch when the source/dest is not a scatter tensor descriptor, enforcing a previously implicit assumption.
- Likewise, offsets are now rejected when the source/dest is a scatter tensor descriptor.
- Additionally, i64, i32, and ui32 are now allowed as source/dest for load/store/prefetch. This matches the tensor descriptor, which accepts i64, i32, and ui32 base addresses in addition to ui64.
- create_tdesc and update_offset are now explicitly stated to be invalid in SIMT mode; both remain available at subgroup level in non-SIMT mode.
- prefetch gains an offset_align_byte attribute, used with an integer pointer source to enable address calculation with offsets.

New test cases are added for the newly enforced checks. One other minor implementation change: the XeGPU scatter tensor descriptor only allows a 1D base memref. This was checked in the op's verify() method; it is now moved into the TableGen (ODS) definition.
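The first two enforced rules can be sketched as follows. This is a hypothetical illustration based on the rules described above, not a test case from the PR; the `%src`, `%offsets`, `%mask`, and `%tdesc` names and the exact types are placeholders:

```mlir
// Plain 1D memref (or pointer) source: offsets are now mandatory.
%v = xegpu.load %src[%offsets], %mask
    : memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>

// Scatter tensor descriptor source: offsets are now rejected,
// because they were already baked in by the producer create_tdesc.
%w = xegpu.load %tdesc, %mask
    : !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>, vector<16xi1> -> vector<16xf32>
```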
1 parent 3e39820 commit 1bf31c3

File tree: 5 files changed, +340 −102 lines

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 137 additions & 57 deletions
@@ -70,28 +70,32 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
 future). Elements in the subview continuous in each dimension. It encodes the
 following important information for supporting Intel hardware features:
 
-* source: an object representing (starting address/pointer of) a memory region.
+Arguments:
+- `source`: an object representing (starting address/pointer of) a memory region.
   It can be either a memref object, or simply a pointer represented by uint64_t type.
   For the case of dynamic memrefs or pointer, the shape and layout information of the
   memory region should be explicitly passed via `shape` and `strides` parameters.
 
-* offsets: index values represents offsets from the "source" at the each dimension
+- `offsets`: index values represents offsets from the "source" at the each dimension
   at which the subview of the target memory will be created. It is encoded via
   "offsets" and "const_offsets", such that it can accept various forms, such as,
   operands (e.g., [%c0, %c]) and attributes (e.g., [2, 4]).
 
-* shape: the shape information of the memory region pointed by the "source". It is
+- `shape`: the shape information of the memory region pointed by the "source". It is
   typically encoded via the MemRefType of the source, e.g., memref<4096x4096xf16>.
   But if "source" is simply a pointer represented as uint64_t type, or a memref
   type without shape information e.g., memref<?x?xf16>, the shape information has
   to be explicitly passed via the "shape" and "const_shape" arguments.
 
-* strides: the strides of the memory region pointed by the "source". Similar to shape,
+- `strides`: the strides of the memory region pointed by the "source". Similar to shape,
   it is typically encoded via the MemRefType of the source too. But if "source" is
   simply a pointer represented as uint64_t type, or a memref type without shape
   information e.g., memref<?x?xf16>, the strides information has to be explicitly
   passed via the "strides" and "const_strides" argument.
 
+Results:
+- `res`: nd tensor descriptor
+
 Example 1 (suppose the tensor shape inferred by the compiler is 8x16):
 ```mlir
 %0 = memref.alloc() : memref<1024x1024xf32>
@@ -560,12 +564,17 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
 (scattered) subviews, allowing each work-item in a subgroup specifying their own offset.
 It accepts the following parameters:
 
-* source: a 1D memref or pointer (uint64_t) represents the flattened memory object.
-* offsets: a vector containing offsets of each access point. Its size
+Arguments:
+- `source`: a 1D memref or pointer (i64, i32, ui64, ui32) represents the flattened
+  memory object.
+- `offsets`: a vector containing offsets of each access point. Its size
   is fixed to the hardware supportted subgroup size, e.g., 16 on PVC,
   implying each element in the vector corresponds to a work-item (SIMT lane)
   in the subgroup.
 
+Results:
+- `res`: scattered tensor descriptor
+
 The first dimension of the result TensorDesc corresponds to work-items, so it should
 match the dimension of offsets. It may also has a second dimension corresponding to
 the chunk_size if the chunk size is larger than 1.
@@ -596,8 +605,8 @@ def XeGPU_CreateDescOp: XeGPU_Op<"create_tdesc", [Pure, ViewLikeOpInterface]> {
 ```
 }];
 
-  let arguments = (ins XeGPU_BaseAddrType: $source,
-                       XeGPU_OffsetType: $offsets);
+  let arguments = (ins XeGPU_GatherScatterBaseAddrType:$source,
+                       XeGPU_OffsetType:$offsets);
   let results = (outs XeGPU_TensorDesc:$TensorDesc);
 
   let builders = [
@@ -655,6 +664,18 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
 As compared to prefetch_nd, which works on non-scattered TensorDesc,
 it works on scattered TensorDesc instead.
 
+Arguments:
+- `source`: represents the memory region to be loaded from, which can be either a
+  tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
+  In case of tensor_desc, offsets come from the producer create_tdesc op.
+  tensor_desc cannot be used in SIMT mode.
+- `offsets`: represents offsets from source. required if `source` in not a TensorDescType.
+  offsets is a vector of `index` type and vector length is either the subgroup size
+  or 1 in SIMT mode. scalar offset is also valid for SIMT mode.
+- `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache.
+- `offset_align_byte`: required if `source` is a pointer. If `source` is not a pointer,
+  it is not allowed. Represents the alignment in bytes of each offset in offsets.
+
 Example 1:
 ```mlir
 xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
@@ -666,7 +687,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
 Example 2:
 A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
 It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
-The source operand could be a raw pointer (uint64_t).
+The source operand could be a raw pointer (ui64, ui32, i64, i32).
 Please refer to create_tdesc for the restriction of memref.
 ```mlir
 %a = memref.alloc() : memref<1024xf32>
@@ -677,13 +698,33 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
     : memref<1024xf32>, vector<4xindex>
 ```
 
+Example 3 (SIMT mode):
+SIMT mode only accepts the offsets variant.
+```mlir
+xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>,
+                       l2_hint = #xegpu.cache_hint<cached>,
+                       l3_hint = #xegpu.cache_hint<cached>}
+    : memref<256xf32>, vector<1xindex>
+```
+
+Example 4 (SIMT mode):
+SIMT mode only accepts the offsets variant.
+```mlir
+xegpu.prefetch %0[%1] {l1_hint = #xegpu.cache_hint<cached>,
+                       l2_hint = #xegpu.cache_hint<cached>,
+                       l3_hint = #xegpu.cache_hint<cached>,
+                       offset_align_byte = 2}
+    : i64, vector<1xindex>
+```
+
 }];
 
-  let arguments = (ins XeGPU_GatherScatterSourceType: $source,
-                       Optional<XeGPU_OffsetType>: $offsets,
-                       OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
-                       OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
-                       OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
+  let arguments = (ins XeGPU_GatherScatterSourceType:$source,
+                       Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
+                       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
+                       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
+                       OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint,
+                       OptionalAttr<I64Attr>:$offset_align_byte);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
     Type getSourceType() {
@@ -731,8 +772,26 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
 The mask operand masks out memory access so that it is safe to pass out-of-boundary
 addresses/offsets as long as they are masked. It applies to slots of SIMD lanes.
 
-In SIMT mode, the result vector represents the data to be loaded by each work-item.
-Each work-item recieves a `chunk_size` number of elements.
+In SIMT mode, the result is a 1D vector that represents the data to be loaded by
+each work-item. If size is not 1, size should be equal to the chunk size,
+
+Arguments:
+- `source`: represents the memory region to be loaded from, which can be either a
+  tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
+  In case of tensor_desc, offsets come from the producer create_tdesc op.
+  tensor_desc cannot be used in SIMT mode.
+- `offsets`: represents offsets from source. required if `source` in not a TensorDescType.
+  offsets is a vector of `index` type and vector length is either the subgroup size
+  or 1 in SIMT mode. scalar offset is also valid for SIMT mode.
+- `mask`: is a vector of `i1` type, which is used to mask out the memory access.
+  mask is a vector of size equal to the subgroup size, or 1 in SIMT mode.
+  scalar mask is also valid for SIMT mode.
+- `chunk_size`: (optional) represents contiguous number of elements to load from per work item.
+- `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache.
+
+Results:
+- `res`: represents loaded data
+
 
 Example 1:
 ```mlir
@@ -752,19 +811,10 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
       vector<16xi1> -> vector<16x8xf32>
 ```
 
-Example 3 (SIMT mode):
-```mlir
-%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
-                         l2_hint = #xegpu.cache_hint<uncached>,
-                         l3_hint = #xegpu.cache_hint<uncached>}>
-    : !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
-      vector<16xi1> -> vector<8xf32>
-```
-
-Example 4:
+Example 3:
 A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
 It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
-The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
+The source operand could be a raw pointer (ui64, ui32, i64, i32). Please refer to create_tdesc
 for the restriction of memref.
 ```mlir
 %a = memref.alloc() : memref<1024xf32>
@@ -776,16 +826,25 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
     : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32>
 ```
 
+Example 4 (SIMT mode):
+SIMT mode only accepts the offsets variant. chunk_size can be inferred from result
+type. In this example, chunk_size is 8.
+```mlir
+%2 = xegpu.load %1[%2], %0 <{l1_hint = #xegpu.cache_hint<cached>,
+                             l2_hint = #xegpu.cache_hint<uncached>,
+                             l3_hint = #xegpu.cache_hint<uncached>}>
+    : memref<128xf32>, vector<1xindex>, vector<1xi1> -> vector<8xf32>
+```
+
 }];
 
-  let arguments = (ins XeGPU_GatherScatterSourceType: $source,
-                       Optional<XeGPU_OffsetType>: $offsets,
-                       XeGPU_MaskType: $mask,
-                       OptionalAttr<I64Attr>: $chunk_size,
-                       OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
-                       OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
-                       OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
-  let results = (outs XeGPU_ValueType: $value);
+  let arguments = (ins XeGPU_GatherScatterSourceType:$source,
+                       Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
+                       AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
+                       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
+                       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
+                       OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
+  let results = (outs AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value);
 
   let extraClassDeclaration = extraBaseClassDeclaration # [{
 
@@ -838,15 +897,31 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
 
 def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
   let summary = "store data to scattered memory locations.";
-  let description = [{ It (aka. store) stores data to scattered memory locations. The value is
+  let description =
+      [{ It (aka. store) stores data to scattered memory locations. The value is
   typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
   a 2D vector instead. For the later case, dim-1 of the value correspods to the simd lanes
   and the dim-0 of the value corresponds to the chunk size stored per lane. So `store_scatter`
   has transpose effect, which is similar to `load_gather`. Therefore, a transpose attribute is
   introduced on purpose, making sure users are aware of this implicit transformation.
 
-In SIMT mode, the input vector represents the data to be stored by each work-item.
-Each work-item stores a `chunk_size` number of elements.
+In SIMT mode, the result is a 1D vector that represents the data to be stored by
+each work-item. If size is not 1, size should be equal to the chunk size.
+
+Arguments:
+- `value`: represents the data to be stored.
+- `dest`: represents the memory region to be stored to, which can be either a
+  tensor_desc or a 1D memref or pointer (ui64, ui32, i64 or i32).
+  In case of tensor_desc, offsets come from the producer create_tdesc op.
+  tensor_desc cannot be used in SIMT mode.
+- `offsets`: represents offsets from dest. required if `source` in not a TensorDescType.
+  offsets is a vector of `index` type and vector length is either the subgroup size
+  or 1 in SIMT mode. scalar offset is also valid for SIMT mode.
+- `mask`: is a vector of `i1` type, which is used to mask out the memory access.
+  mask is a vector of size equal to the subgroup size, or 1 in SIMT mode.
+  scalar mask is also valid for SIMT mode.
+- `chunk_size`: (optional) represents contiguous number of elements to store to per work item.
+- `l1_hint`, `l2_hint`, `l3_hint`: are optional cache hints for each level of cache.
 
 Example 1:
 ```mlir
@@ -864,15 +939,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
     : vector<16x8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>>, vector<16xi1>
 ```
 
-Example 3 (SIMT mode):
-```mlir
-xegpu.store %0, %1, %2 <{l1_hint = #xegpu.cache_hint<uncached>,
-                         l2_hint = #xegpu.cache_hint<write_back>,
-                         l3_hint = #xegpu.cache_hint<write_through>}>
-    : vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
-```
-
-Example 4:
+Example 3:
 A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
 It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
 The dest operand could be a raw pointer (uint64_t).
@@ -888,19 +955,27 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
     : memref<1024xf32>, vector<16xi1>, vector<16xindex> -> vector<16xf32>
 ```
 
+Example 4 (SIMT mode):
+SIMT mode only accepts the offsets variant. chunk_size can be inferred from value
+type. In this example, chunk_size is 8.
+```mlir
+xegpu.store %0, %1[%2], %3 <{l1_hint = #xegpu.cache_hint<uncached>,
+                             l2_hint = #xegpu.cache_hint<write_back>,
+                             l3_hint = #xegpu.cache_hint<write_through>}>
+    : vector<8xf32>, memref<256xf32>, vector<1xindex>, vector<1xi1>
+```
+
 }];
 
-  let arguments = (ins
-    XeGPU_ValueType: $value,
-    XeGPU_GatherScatterSourceType: $dest,
-    Optional<XeGPU_OffsetType>: $offsets,
-    XeGPU_MaskType: $mask,
-    OptionalAttr<I64Attr>: $chunk_size,
-    OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
-    OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
-    OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
+  let arguments = (ins AnyTypeOf<[XeGPU_ValueType, XeGPU_ScalarType]>:$value,
+                       XeGPU_GatherScatterSourceType:$dest,
+                       Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>:$offsets,
+                       AnyTypeOf<[XeGPU_MaskType, I1]>:$mask, OptionalAttr<I64Attr>:$chunk_size,
+                       OptionalAttr<XeGPU_CacheHintAttr>:$l1_hint,
+                       OptionalAttr<XeGPU_CacheHintAttr>:$l2_hint,
+                       OptionalAttr<XeGPU_CacheHintAttr>:$l3_hint);
 
-  let extraClassDeclaration = extraBaseClassDeclaration # [{
+  let extraClassDeclaration = extraBaseClassDeclaration#[{
     Type getDestType() {
       return getDest().getType();
     }
@@ -916,6 +991,11 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
       return dyn_cast<xegpu::TensorDescType>(getDestType());
     }
 
+    mlir::Type getElementType() {
+      auto type = getValue().getType();
+      return getElementTypeOrSelf(type);
+    }
+
     VectorType getValueType() {
       return llvm::dyn_cast<VectorType>(getValue().getType());
     }
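The relaxed operand constraints in the diff above (`Optional<AnyTypeOf<[XeGPU_OffsetType, Index]>>` for offsets and `AnyTypeOf<[XeGPU_MaskType, I1]>` for masks) also admit scalar operands in SIMT mode, per the "scalar offset is also valid for SIMT mode" notes in the op descriptions. A hypothetical sketch (operand names and the exact printed form are placeholders, not taken from the PR's tests):

```mlir
// SIMT-mode load using a scalar index offset and a scalar i1 mask,
// as an alternative to the vector<1xindex> / vector<1xi1> forms
// shown in the committed examples.
%d = xegpu.load %src[%off], %m
    : memref<256xf32>, index, i1 -> vector<8xf32>
```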

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td

Lines changed: 7 additions & 2 deletions
@@ -16,13 +16,17 @@ include "mlir/IR/BuiltinTypes.td"
 def XeGPU_IntType: AnyTypeOf<[I1, I8, I16, I32, I64, SI1, SI8, SI16, SI32, SI64, UI1, UI8, UI16, UI32, UI64]>;
 def XeGPU_FloatType: AnyTypeOf<[F16, F32, F64, BF16, TF32]>;
 def XeGPU_ScalarType: AnyTypeOf<[XeGPU_IntType, XeGPU_FloatType]>;
-def XeGPU_BaseAddrType: AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64, UI32, I64, I32]>;
+def XeGPU_PointerType : AnyTypeOf<[UI64, UI32, I64, I32]>;
+def XeGPU_BaseAddrType
+    : AnyTypeOf<[Non0RankedMemRefOf<[XeGPU_ScalarType]>, XeGPU_PointerType]>;
 def XeGPU_DpasOprType: FixedVectorOfRankAndType<[1, 2, 3], [XeGPU_ScalarType]>;
 def XeGPU_DpasResType: FixedVectorOfRankAndType<[1, 2], [XeGPU_ScalarType]>;
 def XeGPU_OffsetType: FixedVectorOfNonZeroRankOf<[Index]>;
 def XeGPU_MaskType: FixedVectorOfNonZeroRankOf<[I1]>;
 def XeGPU_ValueType: FixedVectorOfNonZeroRankOf<[XeGPU_ScalarType]>;
 def XeGPU_VectorType: VectorOfRankAndType<[1,2,3,4,5,6], [XeGPU_ScalarType]>;
+def XeGPU_GatherScatterBaseAddrType
+    : AnyTypeOf<[MemRefRankOf<[XeGPU_ScalarType], [1]>, XeGPU_PointerType]>;
 
 // common base class for types in XeGPU dialect
 class XeGPUTypeDef<string name, string typeMnemonic, list<Trait> traits = [],
@@ -189,7 +193,8 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
   let genVerifyDecl = 1;
 }
 
-def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
+def XeGPU_GatherScatterSourceType
+    : AnyTypeOf<[XeGPU_TensorDesc, XeGPU_GatherScatterBaseAddrType]>;
 
 def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
   let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
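Because `XeGPU_GatherScatterBaseAddrType` above uses `MemRefRankOf<[XeGPU_ScalarType], [1]>`, the 1D-base-memref restriction is now enforced by ODS type checking rather than by hand-written verify() code, as the commit message notes. A hypothetical illustration of what this means at the IR level (names are placeholders):

```mlir
// Rank-1 base memref: accepted by the new source type constraint.
%flat = memref.alloc() : memref<1024xf32>
%td = xegpu.create_tdesc %flat, %offsets
    : memref<1024xf32>, vector<16xindex>
    -> !xegpu.tensor_desc<16xf32, #xegpu.scatter_tdesc_attr<>>

// A rank-2 base such as memref<32x32xf32> would now fail ODS operand
// type verification for create_tdesc, instead of tripping a custom
// verify() diagnostic as before.
```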
