154 changes: 136 additions & 18 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -628,35 +628,71 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
Compared to prefetch_nd, which works on non-scattered TensorDesc,
this op works on scattered TensorDesc instead.

Example:
Example 1:
```mlir
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16>
```

Example 2:
A variant accepts a memref as the base pointer and an offset vector instead of a scattered TensorDesc.
It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc".
The source operand could also be a raw pointer (uint64_t).
Please refer to create_tdesc for the restrictions on memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: memref<1024xf32>, vector<4xindex>
```

}];

let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
let arguments = (ins XeGPU_GatherScatterSourceType: $source,
Optional<XeGPU_OffsetType>: $offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);

let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getSourceType() {
return getSource().getType();
}

TypedValue<xegpu::TensorDescType> getTensorDesc() {
if (auto tdescType = getTensorDescType()) {
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
}
return TypedValue<xegpu::TensorDescType>();
}

xegpu::TensorDescType getTensorDescType() {
Review comment (Contributor): here dyn_cast is involved. Maybe better to rename the function to TryGetTensorDesc to avoid confusion?

Reply (Contributor, Author): The function name needs to stay the same to minimize the change.

return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
}];

let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)?
prop-dict
attr-dict `:` type(operands)
}];

let builders = [
OpBuilder<(ins "Value": $source,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

let hasVerifier = 1;
}
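A minimal sketch (not part of the patch) of how a call site might handle the widened source operand, assuming the usual TableGen-generated accessors `getSource()` and `getOffsets()`; the function name `rewritePrefetch` is illustrative only. It shows why keeping the `getTensorDescType()` name while switching its body to `dyn_cast` lets pre-existing TensorDesc-only call sites compile unchanged, as argued in the review thread above.

```c++
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"

using namespace mlir;

// Hypothetical caller: branch on whether the prefetch source is still a
// tensor_desc or one of the new memref/ui64 forms.
LogicalResult rewritePrefetch(xegpu::PrefetchOp prefetchOp) {
  // getTensorDescType() now wraps a dyn_cast, so it yields a null type when
  // the source is a memref or a raw ui64 pointer rather than a tensor_desc.
  if (auto tdescTy = prefetchOp.getTensorDescType()) {
    // Legacy path: code written against the old TensorDesc-only operand.
    (void)prefetchOp.getTensorDesc(); // still a TypedValue<TensorDescType>
    return success();
  }
  // New path: memref or ui64 source plus an optional offset vector.
  Value source = prefetchOp.getSource();
  Value offsets = prefetchOp.getOffsets(); // null when offsets are omitted
  (void)source;
  (void)offsets;
  return success();
}
```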

def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
]> {
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let summary = "load a set of scattered data points from memory.";

let description = [{ It (aka. load) loads data per work-item. The output
Expand Down Expand Up @@ -687,6 +723,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<16x8xf32>
```

Example 3 (SIMT mode):
```mlir
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -695,19 +732,48 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
vector<16xi1> -> vector<8xf32>
```

Example 4:
A variant accepts a memref as the base pointer and an offset vector instead of a scattered TensorDesc.
It combines "create scattered TensorDesc" and "load with scattered TensorDesc".
The source operand could also be a raw pointer (uint64_t). Please refer to create_tdesc
for the restrictions on memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%offsets = vector.step : vector<16xindex>
%mask = vector.constant_mask [16] : vector<16xi1>
%val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
```

}];

let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
let arguments = (ins XeGPU_GatherScatterSourceType: $source,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let results = (outs XeGPU_ValueType: $value);

let extraClassDeclaration = extraBaseClassDeclaration # [{

Type getSourceType() {
return getSource().getType();
}

TypedValue<xegpu::TensorDescType> getTensorDesc() {
if (auto tdescType = getTensorDescType()) {
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getSource());
}
return TypedValue<xegpu::TensorDescType>();
}

xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getSourceType());
}

mlir::Type getElementType() {
@@ -725,15 +791,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [

}];

let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
`:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)? `,`
$mask prop-dict
attr-dict `:` type(operands) `->` type($value)
}];

let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

let hasVerifier = 1;
}
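A rough usage sketch (an assumption, not code from this patch) of the convenience builder declared above. It assumes the declared `OpBuilder` maps onto a generated `build(value type, source, mask, l1_hint, l2_hint, l3_hint)` overload and that the cache-policy enum is spelled `xegpu::CachePolicy::CACHED`; since the builder takes no offsets, it covers the legacy tensor_desc form, while the memref-plus-offsets form would go through the generic all-operands build.

```c++
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper: emit a gather load from a scattered tensor_desc with
// all three cache hints set to "cached".
Value buildGatherLoad(OpBuilder &b, Location loc, Type resultTy, Value tdesc,
                      Value mask) {
  auto cached = xegpu::CachePolicyAttr::get(b.getContext(),
                                            xegpu::CachePolicy::CACHED);
  auto loadOp = b.create<xegpu::LoadGatherOp>(loc, resultTy, tdesc, mask,
                                              /*l1_hint=*/cached,
                                              /*l2_hint=*/cached,
                                              /*l3_hint=*/cached);
  return loadOp.getResult();
}
```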

def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
]> {
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
let summary = "store data to scattered memory locations.";
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -768,19 +843,49 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
l3_hint = #xegpu.cache_hint<write_through>}>
: vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
```

Example 4:
A variant accepts a memref as the base pointer and an offset vector instead of a scattered TensorDesc.
It combines "create scattered TensorDesc" and "store with scattered TensorDesc".
The dest operand could also be a raw pointer (uint64_t).
Please refer to create_tdesc for the restrictions on memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%val = arith.constant dense<0.0> : vector<16xf32>
%offsets = vector.step : vector<16xindex>
%mask = vector.constant_mask [16] : vector<16xi1>
xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
```

}];

let arguments = (ins
XeGPU_ValueType: $value,
XeGPU_TensorDesc: $TensorDesc,
XeGPU_GatherScatterSourceType: $dest,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);

let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getDestType() {
return getDest().getType();
}

TypedValue<xegpu::TensorDescType> getTensorDesc() {
if (auto tdescType = getTensorDescType()) {
return llvm::cast<TypedValue<xegpu::TensorDescType>>(getDest());
}
return TypedValue<xegpu::TensorDescType>();
}

xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getDestType());
}

VectorType getValueType() {
@@ -792,8 +897,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
}
}];

let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
`:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
let assemblyFormat = [{
$value `,`
$dest
(`[` $offsets^ `]`)? `,`
$mask
prop-dict
attr-dict `:` type(operands)
}];

let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

let hasVerifier = 1;
}
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}

def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc, Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
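
As a rough illustration (an assumption, not code from the patch), the new constraint admits three kinds of source/dest types; an equivalent C++ check might look like the sketch below. The helper name `isValidGatherScatterSource` is hypothetical, and the scalar element-type check on the memref is elided for brevity.

```c++
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/IR/BuiltinTypes.h"

using namespace mlir;

// Mirrors XeGPU_GatherScatterSourceType: a tensor_desc, a non-0-ranked
// memref (of an XeGPU scalar element type), or a raw ui64 pointer.
static bool isValidGatherScatterSource(Type type) {
  if (isa<xegpu::TensorDescType>(type))
    return true;
  if (auto memrefTy = dyn_cast<MemRefType>(type))
    return memrefTy.getRank() != 0; // Non0RankedMemRefOf<[...]>
  return type.isUnsignedInteger(64); // UI64 raw pointer
}
```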

def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";