Merged

Changes from 16 commits
136 changes: 118 additions & 18 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -628,35 +628,65 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
As compared to prefetch_nd, which works on non-scattered TensorDesc,
it works on scattered TensorDesc instead.

Example:
Example 1:
```mlir
xegpu.prefetch %tdesc {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16>
```

Example 2:
A variant accepts a memref as the base pointer plus a vector of offsets instead of a scattered TensorDesc. It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc". The source operand can also be a raw pointer (uint64_t). Please refer to create_tdesc for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
xegpu.prefetch %a[%0] {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: memref<1024xf32>, vector<4xindex>
```

}];

let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
let arguments = (ins XeGPU_GatherScatterSourceType: $source,
Optional<XeGPU_OffsetType>: $offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);

let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getSourceType() {
return getSource().getType();
}

Value getTensorDesc() {
Contributor: why is this method needed?

Contributor Author: to minimize the change during the transition.

Contributor: nit: this is a bit risky, because the caller might expect a tensor_desc but this may return a memref or I64. I would add a TODO note.

Contributor: Since it returns one specific variant of the source, I'd also suggest it return a TypedValue<xegpu::TensorDescType>.

return getSource();
}

xegpu::TensorDescType getTensorDescType() {
Contributor: here dyn_cast is involved. Maybe better to rename the function to TryGetTensorDesc to avoid confusion?

Contributor Author: The function name needs to stay the same to minimize the change.

return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
}];
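Because `getTensorDescType()` now performs a `dyn_cast`, it returns a null `TensorDescType` whenever the source is a memref or a raw ui64 pointer. A minimal sketch of the resulting caller pattern (illustrative only; `prefetchOp` is an assumed variable, not code from this patch):

```cpp
// Sketch: callers can no longer assume a tensor_desc source.
if (auto tdescTy = prefetchOp.getTensorDescType()) {
  // Source is a scattered tensor_desc; tdescTy is valid.
} else {
  // Source is a 1D memref or a raw ui64 pointer; handle separately.
}
```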

let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)?
prop-dict
attr-dict `:` type(operands)
}];

let builders = [
OpBuilder<(ins "Value": $source,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

let hasVerifier = 1;
}

def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
]> {
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let summary = "load a set of scattered data points from memory.";

let description = [{ It (aka. load) loads data per work-item. The output
@@ -687,6 +717,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<16x8xf32>
```

Example 3 (SIMT mode):
```mlir
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -695,19 +726,42 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
vector<16xi1> -> vector<8xf32>
```

Example 4:
A variant accepts a memref as the base pointer plus a vector of offsets instead of a scattered TensorDesc. It combines "create scattered TensorDesc" and "load with scattered TensorDesc". The source operand can also be a raw pointer (uint64_t). Please refer to create_tdesc for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%offsets = vector.step : vector<16xindex>
%mask = vector.constant_mask [16]: vector<16xi1>
%val = xegpu.load %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: memref<1024xf32>, vector<16xindex>, vector<16xi1> -> vector<16xf32>
```

}];

let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
let arguments = (ins XeGPU_GatherScatterSourceType: $source,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let results = (outs XeGPU_ValueType: $value);

let extraClassDeclaration = extraBaseClassDeclaration # [{

Type getSourceType() {
return getSource().getType();
}

Value getTensorDesc() {
return getSource();
}

xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getSourceType());
}

mlir::Type getElementType() {
@@ -725,15 +779,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [

}];

let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
`:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)? `,`
$mask prop-dict
attr-dict `:` type(operands) `->` type($value)
}];

let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

let hasVerifier = 1;
}

def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
]> {
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
let summary = "store data to scattered memory locations.";
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
@@ -768,19 +831,43 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
l3_hint = #xegpu.cache_hint<write_through>}>
: vector<8xf32>, !xegpu.tensor_desc<16x8xf32, #xegpu.scattered_tdesc_attr<chunk_size=8>> vector<16xi1>
```

Example 4:
A variant accepts a memref as the base pointer plus a vector of offsets instead of a scattered TensorDesc. It combines "create scattered TensorDesc" and "store with scattered TensorDesc". The dest operand can also be a raw pointer (uint64_t). Please refer to create_tdesc for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%val = arith.constant dense<0.0> : vector<16xf32>
%offsets = vector.step : vector<16xindex>
%mask = vector.constant_mask [16]: vector<16xi1>
xegpu.store %val, %a[%offsets], %mask {l1_hint = #xegpu.cache_hint<cached>,
l2_hint = #xegpu.cache_hint<cached>,
l3_hint = #xegpu.cache_hint<cached>}
: vector<16xf32>, memref<1024xf32>, vector<16xindex>, vector<16xi1>
```

}];

let arguments = (ins
XeGPU_ValueType: $value,
XeGPU_TensorDesc: $TensorDesc,
XeGPU_GatherScatterSourceType: $dest,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);

let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getDestType() {
return getDest().getType();
}

Value getTensorDesc() {
return getDest();
}

xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getDestType());
}

VectorType getValueType() {
@@ -792,8 +879,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
}
}];

let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
`:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
let assemblyFormat = [{
$value `,`
$dest
(`[` $offsets^ `]`)? `,`
$mask
prop-dict
attr-dict `:` type(operands)
}];

let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

let hasVerifier = 1;
}
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
@@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}

def XeGPU_GatherScatterSourceType : AnyTypeOf<[XeGPU_TensorDesc, Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;

def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
111 changes: 105 additions & 6 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -110,6 +110,34 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return success();
}

static LogicalResult
isValidGatherScatterBufferParams(Type maskTy, VectorType valueTy,
int64_t chunkSize,
function_ref<InFlightDiagnostic()> emitError) {

if (!valueTy)
return emitError() << "Expecting a vector type result.";

auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);

// a valid shape for SIMT case
if (valueTy.getRank() == 1) {
if (valueTy.getNumElements() != chunkSize)
return emitError() << "value elements must match chunk size " << chunkSize
<< " for SIMT code.";
return success();
}

llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
if (chunkSize > 1)
expectedMaskShape.pop_back();
if (expectedMaskShape != maskShape)
return emitError() << "Mask should match value except the chunk size dim.";

return success();
}
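To make the mask/value shape rule above concrete, here is a tiny standalone sketch (plain C++, no MLIR dependencies; the helper name is invented for illustration) of how the expected mask shape is derived from the value shape and the chunk size:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustration only: mirrors the non-SIMT branch of the verifier helper.
// The mask shape must equal the value shape with the trailing chunk
// dimension dropped when chunkSize > 1.
static std::vector<int64_t> expectedMaskShape(std::vector<int64_t> valueShape,
                                              int64_t chunkSize) {
  if (chunkSize > 1)
    valueShape.pop_back(); // drop the chunk dimension
  return valueShape;
}

int main() {
  // value vector<16x8xf32> with chunk_size = 8 expects mask vector<16xi1>.
  assert((expectedMaskShape({16, 8}, 8) == std::vector<int64_t>{16}));
  // value vector<16xf32> with chunk_size = 1 expects mask vector<16xi1>.
  assert((expectedMaskShape({16}, 1) == std::vector<int64_t>{16}));
  return 0;
}
```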

//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
@@ -644,8 +672,15 @@ LogicalResult CreateDescOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");

if (tdescTy) {
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
} else {
if (getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");
}
Contributor:

Suggested change:

if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
if (getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");

Contributor Author: The suggested code doesn't work since the Tdesc can be 2D.

Contributor:

Suggested change:

if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
if (!tdescTy && getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");


if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();
@@ -659,6 +694,13 @@ LogicalResult PrefetchOp::verify() {
return success();
}

void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, source, Value(), l1_hint, l2_hint, l3_hint);
}
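For reference, a call site for this convenience builder might look like the fragment below (a sketch, not a complete program; `builder`, `loc`, and a scattered tensor_desc value `src` are assumed to exist, and the cache-policy attribute spelling is assumed from the dialect's CachePolicy enum):

```cpp
// Fragment sketch: prefetch via the new convenience builder, which forwards
// a null offsets Value to the full builder.
auto cached = xegpu::CachePolicyAttr::get(builder.getContext(),
                                          xegpu::CachePolicy::CACHED);
builder.create<xegpu::PrefetchOp>(loc, src, /*l1_hint=*/cached,
                                  /*l2_hint=*/cached, /*l3_hint=*/cached);
```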

//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
@@ -667,6 +709,15 @@ LogicalResult LoadGatherOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();

if (tdescTy) {
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
} else {
if (getRankOf(getSource()) > 1)
return emitOpError(
"Expecting the source is a 1D memref or pointer (uint64_t).");
}
Contributor: check the suggested change above. In summary, we should try to combine conditions and return early rather than nesting if conditions.

Contributor Author: changed


if (!isReadHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();

@@ -676,8 +727,27 @@ LogicalResult LoadGatherOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();

return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });
if (tdescTy)
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });
auto srcTy = getSourceType();
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(srcTy);

if (memTy && (valueTy.getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";

return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}

void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
Type valueType, Value source, Value mask,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, valueType, source, Value(), mask, IntegerAttr(),
l1_hint, l2_hint, l3_hint);
}
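Similarly, a hypothetical use of the load convenience builder (fragment only; `builder`, `loc`, a scattered tensor_desc value `src`, and a vector<16xi1> mask `mask` are assumed):

```cpp
// Fragment sketch: gather-load via the convenience builder; offsets and
// chunk_size are left unset, matching the old (TensorDesc, mask) form.
auto cached = xegpu::CachePolicyAttr::get(builder.getContext(),
                                          xegpu::CachePolicy::CACHED);
auto resTy = VectorType::get({16}, builder.getF32Type());
Value loaded = builder.create<xegpu::LoadGatherOp>(loc, resTy, src, mask,
                                                   cached, cached, cached);
```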

//===----------------------------------------------------------------------===//
@@ -688,6 +758,15 @@ LogicalResult StoreScatterOp::verify() {
auto maskTy = getMaskType();
auto valueTy = getValueType();

if (tdescTy) {
if (!tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");
} else {
if (getRankOf(getDest()) > 1)
return emitOpError(
"Expecting the dest is a 1D memref or pointer (uint64_t).");
}
Contributor: check the suggested code change.


if (!isWriteHintOrNone(getL1HintAttr()))
return emitOpError("invalid l1_hint: ") << getL1HintAttr();

@@ -697,8 +776,28 @@ LogicalResult StoreScatterOp::verify() {
if (!isWriteHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();

return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });
if (tdescTy)
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });

auto destTy = getDestType();
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(destTy);

if (memTy && (valueTy.getElementType() != memTy.getElementType()))
return emitError() << "Value should have the same element type as MemRef.";

return isValidGatherScatterBufferParams(maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}

void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
Value value, Value dest, Value mask,
xegpu::CachePolicyAttr l1_hint,
xegpu::CachePolicyAttr l2_hint,
xegpu::CachePolicyAttr l3_hint) {
build(builder, state, value, dest, Value(), mask, IntegerAttr(), l1_hint,
l2_hint, l3_hint);
}
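And a hypothetical use of the store convenience builder (fragment only; `builder`, `loc`, a vector value `val`, a scattered tensor_desc value `dst`, and a mask `mask` are assumed; WRITE_BACK is assumed to be a valid write hint):

```cpp
// Fragment sketch: scatter-store via the convenience builder; offsets and
// chunk_size default to none.
auto wb = xegpu::CachePolicyAttr::get(builder.getContext(),
                                      xegpu::CachePolicy::WRITE_BACK);
builder.create<xegpu::StoreScatterOp>(loc, val, dst, mask, /*l1_hint=*/wb,
                                      /*l2_hint=*/wb, /*l3_hint=*/wb);
```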

//===----------------------------------------------------------------------===//