Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 80 additions & 17 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ include "mlir/Dialect/XeGPU/IR/XeGPUTypes.td"
include "mlir/Interfaces/ShapedOpInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/Dialect/GPU/IR/CompilationAttrInterfaces.td"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure whether this header include is necessary; so far, I haven't seen any changes requiring it. Maybe it can be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed


// Base class for dialect operations. This operation inherits from the base
// `Op` class in OpBase.td, and provides:
Expand Down Expand Up @@ -638,25 +639,44 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {

}];

let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
Optional<XeGPU_OffsetType>: $offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);

let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getSourceType() {
return getSource().getType();
}

Value getTensorDesc() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why this method is needed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

to minimize the change during the transition.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: this is a bit risky, because the caller might expect a tensor_desc but this may return a memref or i64. I would add a TODO note.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since it returns one specific variant of the source, I'd also suggest it return a TypedValue<xegpu::TensorDescType>

return getSource();
}

xegpu::TensorDescType getTensorDescType() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here a dyn_cast is involved. Maybe it would be better to rename the function to TryGetTensorDesc to avoid confusion?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function name needs to keep the same to minimize the change.

return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getSourceType());
}
}];

let assemblyFormat = "$TensorDesc prop-dict attr-dict `:` qualified(type($TensorDesc))";
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)?
prop-dict
attr-dict `:` type(operands)
}];

let builders = [
OpBuilder<(ins "Value": $source,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

let hasVerifier = 1;
}

def XeGPU_LoadGatherOp : XeGPU_Op<"load", [
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemRead]>
]> {
def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let summary = "load a set of scattered data points from memory.";

let description = [{ It (aka. load) load data per each work-item. The output
Expand Down Expand Up @@ -698,16 +718,27 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [

}];

let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
let arguments = (ins XeGPU_TensorDesc_or_MemRef: $source,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
let results = (outs XeGPU_ValueType: $value);

let extraClassDeclaration = extraBaseClassDeclaration # [{

Type getSourceType() {
return getSource().getType();
}

Value getTensorDesc() {
return getSource();
}

xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getSourceType());
}

mlir::Type getElementType() {
Expand All @@ -725,15 +756,24 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [

}];

let assemblyFormat = [{$TensorDesc `,` $mask prop-dict attr-dict
`:` qualified(type($TensorDesc)) `,` type($mask) `->` type($value)}];
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)? `,`
$mask prop-dict
attr-dict `:` type(operands) `->` type($value)
}];

let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

let hasVerifier = 1;
}

def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
AllElementTypesMatch<["value", "TensorDesc"]>, MemoryEffects<[MemWrite]>
]> {
def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
let summary = "store data to scattered memory locations.";
let description = [{ It (aka. store) stores data to scattered memory locations. The value is
typically a 1D vector. But when the chunk size of the TensorDesc is larger than 1, it will be
Expand Down Expand Up @@ -772,15 +812,25 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [

let arguments = (ins
XeGPU_ValueType: $value,
XeGPU_TensorDesc: $TensorDesc,
XeGPU_TensorDesc_or_MemRef: $dest,
Optional<XeGPU_OffsetType>: $offsets,
XeGPU_MaskType: $mask,
OptionalAttr<I64Attr>: $chunk_size,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);

let extraClassDeclaration = extraBaseClassDeclaration # [{
Type getDestType() {
return getDest().getType();
}

Value getTensorDesc() {
return getDest();
}

xegpu::TensorDescType getTensorDescType() {
return getTensorDesc().getType();
return dyn_cast<xegpu::TensorDescType>(getDestType());
}

VectorType getValueType() {
Expand All @@ -792,8 +842,21 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [
}
}];

let assemblyFormat = [{$value `,` $TensorDesc `,` $mask prop-dict attr-dict
`:` type($value) `,` qualified(type($TensorDesc)) `,` type($mask)}];
let assemblyFormat = [{
$value `,`
$dest
(`[` $offsets^ `]`)? `,`
$mask
prop-dict
attr-dict `:` type(operands)
}];

let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

let hasVerifier = 1;
}
Expand Down
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def XeGPU_TensorDesc: XeGPUTypeDef<"TensorDesc", "tensor_desc",
let genVerifyDecl = 1;
}

def XeGPU_TensorDesc_or_MemRef : AnyTypeOf<[XeGPU_TensorDesc,Non0RankedMemRefOf<[XeGPU_ScalarType]>, UI64]>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: keep a consistent naming convention: XeGPU_TensorDescOrMemRef

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed


def XeGPU_Nbarrier: XeGPUTypeDef<"Nbarrier", "nbarrier", [], "mlir::Type"> {
let summary = "!xegpu.nbarrier a custom XeGPU type representing a barrier.";
Expand Down
116 changes: 111 additions & 5 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,66 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
return success();
}

static LogicalResult
isValidGatherScatterMemRefParams(Type maskTy, VectorType valueTy,
MemRefType memTy, int64_t chunkSize,
function_ref<InFlightDiagnostic()> emitError) {

if (!valueTy)
return emitError() << "Expecting a vector type result.";

auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);
auto memShape = getShapeOf(memTy);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems memShape is not used. Maybe it can be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed


if (valueTy.getElementType() != memTy.getElementType())
return emitError() << "Value should have the same element type as MemRef.";

// a valid shape for SIMT case
if (valueTy.getRank() == 1) {
if (valueTy.getNumElements() != chunkSize)
return emitError() << "value elements must match chunk size " << chunkSize
<< " for SIMT code.";
return success();
}

llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
if (chunkSize > 1)
expectedMaskShape.pop_back();
if (expectedMaskShape != maskShape)
return emitError() << "Mask should match value except the chunk size dim.";

return success();
}

static LogicalResult
isValidGatherScatterRawptrParams(Type maskTy, VectorType valueTy,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These two functions are very similar. I think we can reuse/refactor isValidGatherScatterParams to achieve this; I don't see a need to define two new functions.

If that is hard to do, at least consider moving the common logic into a helper and reusing the helper.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

refactored to one function.

int64_t chunkSize,
function_ref<InFlightDiagnostic()> emitError) {

if (!valueTy)
return emitError() << "Expecting a vector type result.";

auto maskShape = getShapeOf(maskTy);
auto valueShape = getShapeOf(valueTy);

// a valid shape for SIMT case
if (valueTy.getRank() == 1) {
if (valueTy.getNumElements() != chunkSize)
return emitError() << "value elements must match chunk size " << chunkSize
<< " for SIMT code.";
return success();
}

llvm::SmallVector<int64_t> expectedMaskShape(valueShape);
if (chunkSize > 1)
expectedMaskShape.pop_back();
if (expectedMaskShape != maskShape)
return emitError() << "Mask should match value except the chunk size dim.";

return success();
}

//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -644,7 +704,7 @@ LogicalResult CreateDescOp::verify() {
//===----------------------------------------------------------------------===//
LogicalResult PrefetchOp::verify() {
auto tdescTy = getTensorDescType();
if (!tdescTy.isScattered())
if (tdescTy && !tdescTy.isScattered())
return emitOpError("Expects a scattered TensorDesc.\n");

if (!isReadHintOrNone(getL1HintAttr()))
Expand All @@ -659,6 +719,13 @@ LogicalResult PrefetchOp::verify() {
return success();
}

// Convenience builder for a PrefetchOp that takes no explicit offsets: the
// optional $offsets operand is left empty (null Value) and the cache hints
// are forwarded unchanged to the autogenerated builder.
void PrefetchOp::build(OpBuilder &builder, OperationState &state, Value source,
                       xegpu::CachePolicyAttr l1_hint,
                       xegpu::CachePolicyAttr l2_hint,
                       xegpu::CachePolicyAttr l3_hint) {
  // Delegate to the generated builder; Value() stands in for the absent
  // offsets operand.
  build(builder, state, source, /*offsets=*/Value(), l1_hint, l2_hint,
        l3_hint);
}

//===----------------------------------------------------------------------===//
// XeGPU_LoadGatherOp
//===----------------------------------------------------------------------===//
Expand All @@ -676,8 +743,27 @@ LogicalResult LoadGatherOp::verify() {
if (!isReadHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();

return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });
if (tdescTy)
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });
auto srcTy = getSourceType();
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(srcTy);

if (memTy)
return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
[&]() { return emitOpError(); });
return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}

// Convenience builder for a LoadGatherOp without explicit offsets and without
// a chunk-size attribute: both optional pieces are left unset and everything
// else is forwarded to the autogenerated builder.
void LoadGatherOp::build(OpBuilder &builder, OperationState &state,
                         Type valueType, Value source, Value mask,
                         xegpu::CachePolicyAttr l1_hint,
                         xegpu::CachePolicyAttr l2_hint,
                         xegpu::CachePolicyAttr l3_hint) {
  // Value() marks the missing $offsets operand; IntegerAttr() leaves the
  // optional $chunk_size attribute unset.
  build(builder, state, valueType, source, /*offsets=*/Value(), mask,
        /*chunk_size=*/IntegerAttr(), l1_hint, l2_hint, l3_hint);
}

//===----------------------------------------------------------------------===//
Expand All @@ -697,8 +783,28 @@ LogicalResult StoreScatterOp::verify() {
if (!isWriteHintOrNone(getL3HintAttr()))
return emitOpError("invalid l3_hint: ") << getL3HintAttr();

return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });
if (tdescTy)
return isValidGatherScatterParams(maskTy, valueTy, tdescTy,
[&]() { return emitOpError(); });

auto destTy = getDestType();
uint64_t chunkSize = static_cast<int64_t>(getChunkSize().value_or(1));
auto memTy = dyn_cast<MemRefType>(destTy);

if (memTy)
return isValidGatherScatterMemRefParams(maskTy, valueTy, memTy, chunkSize,
[&]() { return emitOpError(); });
return isValidGatherScatterRawptrParams(maskTy, valueTy, chunkSize,
[&]() { return emitOpError(); });
}

// Convenience builder for a StoreScatterOp without explicit offsets and
// without a chunk-size attribute: both optional pieces are left unset and the
// remaining operands/hints are forwarded to the autogenerated builder.
void StoreScatterOp::build(OpBuilder &builder, OperationState &state,
                           Value value, Value dest, Value mask,
                           xegpu::CachePolicyAttr l1_hint,
                           xegpu::CachePolicyAttr l2_hint,
                           xegpu::CachePolicyAttr l3_hint) {
  // Value() marks the missing $offsets operand; IntegerAttr() leaves the
  // optional $chunk_size attribute unset.
  build(builder, state, value, dest, /*offsets=*/Value(), mask,
        /*chunk_size=*/IntegerAttr(), l1_hint, l2_hint, l3_hint);
}

//===----------------------------------------------------------------------===//
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@ struct UnrollLoadGatherOp : public UnrollPattern<xegpu::LoadGatherOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();

if (!tdescTy.isScattered())
if (!tdescTy || !tdescTy.isScattered())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Looks like the new version is not yet supported by this pattern? Maybe adding a TODO note would help here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
Expand Down Expand Up @@ -546,7 +546,7 @@ struct UnrollPrefetchOp : public UnrollPattern<xegpu::PrefetchOp> {
Location loc = op.getLoc();
xegpu::TensorDescType tdescTy = op.getTensorDescType();

if (!tdescTy.isScattered())
if (!tdescTy || !tdescTy.isScattered())
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
Expand Down Expand Up @@ -575,7 +575,7 @@ struct UnrollStoreScatterOp : public UnrollPattern<xegpu::StoreScatterOp> {
VectorType valueTy = llvm::dyn_cast<VectorType>(op.getValue().getType());
xegpu::TensorDescType tdescTy = op.getTensorDescType();

if (!tdescTy.isScattered())
if (!tdescTy || !tdescTy.isScattered())
return failure();

std::optional<SmallVector<int64_t>> targetShape = getTargetShape(op);
Expand Down
Loading