
[mlir][xegpu] Remove OffsetSizeAndStrideOpInterface from CreateNdDescOp #152773


Merged
merged 3 commits on Aug 8, 2025
171 changes: 63 additions & 108 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
void printProperties(::mlir::MLIRContext *ctx,
::mlir::OpAsmPrinter &p, const Properties &prop,
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {

DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));

// filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
}

if (!filteredAttrs.empty()) {
p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
}
}

@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
}


-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
-                          AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {

Reviewer:

Is the op description section aligned with the changes now?

Contributor:

The change only impacts the implementation.

Contributor (Author):

I updated the description, marking offsets deprecated.
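For illustration only, a minimal sketch of the two styles the updated description now distinguishes (types, shapes, and values here are assumed, not taken from this PR):

```mlir
%c0 = arith.constant 0 : index
// Deprecated style: offsets supplied when the descriptor is created.
%td0 = xegpu.create_nd_tdesc %src[%c0, %c0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// Preferred style: create the descriptor without offsets and pass them to the access op.
%td1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%v = xegpu.load_nd %td1[%c0, %c0] : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
```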


let summary = "Create nd-tensor descriptor operation";
let description = [{
@@ -181,82 +180,38 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
return getType().getShape();
}

-      /// wrapper for matching with OffsetSizeAndStrideOpInterface
-      OperandRange getSizes() {
-        return getShape();
-      }
-
-      ArrayRef<int64_t> getStaticOffsets(){
-        auto attr = getConstOffsetsAttr();
-        if (attr)
-          return attr;
-
-        int64_t rank = getMixedSizes().size();
-        setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
-        attr = getConstOffsetsAttr();
-        return attr;
-      }
-
-      /// wrapper for matching with OffsetSizeAndStrideOpInterface
-      /// If source is IntegerType or `const_shape` is filled,
-      /// it will return `const_shape`, such that mixes of `shape`
-      /// and `const_shape` will be used to represent the shape of
-      /// source operand. They overide static shape from source memref type.
-      ArrayRef<int64_t> getStaticSizes() {
-        /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-        static llvm::SmallVector<int64_t, 4> emptyShape;
-
-        auto attr = getConstShapeAttr();
-        if (attr)
-          return attr;
-
-        if (llvm::isa<IntegerType>(getSourceType()))
-          return emptyShape;
-
-        auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-        assert(memrefType && "Incorrect use of getStaticSizes");
-        return memrefType.getShape();
-      }
-
-      /// wrapper for matching with OffsetSizeAndStrideOpInterface
-      /// If source is IntegerType or `const_strides` is filled, it
-      /// will return `const_strides`, such that mixes of `strides`
-      /// and `const_strides` will be used to represent the strides of
-      /// source operand. They overide static strides from source memref type.
-      ArrayRef<int64_t> getStaticStrides() {
-        /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-        static llvm::SmallVector<int64_t, 4> emptyStrides;
-
-        auto attr = getConstStridesAttr();
-        if (attr)
-          return attr;
-
-        if (llvm::isa<IntegerType>(getSourceType()))
-          return emptyStrides;
-
-        auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-        assert(memrefType && "Incorrect use of getStaticStrides");
-        auto [strides, _] = memrefType.getStridesAndOffset();
-        // reuse the storage of ConstStridesAttr since strides from
-        // memref is not persistant
-        setConstStrides(strides);
-        attr = getConstStridesAttr();
-        return attr;
-      }
-
-      /// Return the expected rank of each of the `static_offsets`,
-      /// `static_shape` and `static_strides` attributes.
-      std::array<unsigned, 3> getArrayAttrMaxRanks() {
-        unsigned rank;
-        if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
-          rank = ty.getRank();
-        } else {
-          rank = (unsigned)getMixedOffsets().size();
-        }
-        return {rank, rank, rank};
-      }
+      SmallVector<OpFoldResult> getMixedOffsets() {
+        auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+        auto dynamics = getOffsets();
+        if (statics.size() == 0 && dynamics.size() == 0)
+          return {};
+        return getMixedValues(statics, dynamics, getContext());
+      }
+
+      SmallVector<OpFoldResult> getMixedSizes() {
+        SmallVector<int64_t> statics;
+
+        /// Get the static sizes/shape, the value passed to const_shape
+        /// will override the value in memref shape.
+        if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+          statics = llvm::to_vector(memrefTy.getShape());
+        if (auto attr = getConstShapeAttr())
+          statics = llvm::to_vector(attr.asArrayRef());
+
+        return getMixedValues(statics, getShape(), getContext());
+      }
+
+      SmallVector<OpFoldResult> getMixedStrides() {
+        SmallVector<int64_t> statics;
+
+        /// Get the static strides, the value passed to const_strides
+        /// will override the value in memref.
+        if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+          statics = memrefTy.getStridesAndOffset().first;
+        if (auto attr = getConstStridesAttr())
+          statics = llvm::to_vector(attr.asArrayRef());
+
+        return getMixedValues(statics, getStrides(), getContext());
+      }

-      /// Return the number of leading operands before the `offsets`,
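The three getMixed* accessors above share one pattern: collect the static values (from the const_* attributes, falling back to the memref type for sizes and strides), then merge them with the dynamic operands via getMixedValues. A hedged sketch of the two source forms this supports (operand names and shapes assumed):

```mlir
// memref source: sizes and strides can be recovered from the memref type.
%t0 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// integer source: shape and strides must be given explicitly.
%t1 = xegpu.create_nd_tdesc %ptr, shape : [24, 32], strides : [32, 1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
```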
@@ -314,15 +269,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
}];

let assemblyFormat = [{
    $TensorDesc ``
    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc))
}];

let builders = [
OpBuilder<(ins "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
OpBuilder<(ins "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
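A usage sketch matching this assembly format (descriptor type and offsets assumed for illustration):

```mlir
xegpu.prefetch_nd %tdesc[%c0, %c0] <{l1_hint = #xegpu.cache_hint<cached>}> : !xegpu.tensor_desc<8x16xf16>
```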

@@ -370,7 +325,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [

let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
Variadic<Index>: $offsets,
                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<UnitAttr>: $packed,
OptionalAttr<DenseI64ArrayAttr>: $transpose,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +345,16 @@ }];
}];

let assemblyFormat = [{
    $TensorDesc ``
    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
}];

let builders = [
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

@@ -442,7 +397,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
let arguments = (ins XeGPU_ValueType: $value,
XeGPU_TensorDesc: $TensorDesc,
Variadic<Index>: $offsets,
                       OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +413,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
}];

let assemblyFormat = [{
    $value `,`
    $TensorDesc ``
    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
}];

let builders = [
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];
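A usage sketch matching the store_nd format above (value and descriptor types assumed for illustration):

```mlir
xegpu.store_nd %val, %tdesc[%c0, %c0] : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
```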

@@ -635,12 +590,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16>
```

Example 2:
    A variant accepts a memref as the base pointer and an offset instead of a scattered TensorDesc.
    It combines "create scattered TensorDesc" and "prefetch with scattered TensorDesc".
    The source operand could be a raw pointer (uint64_t).
    Please refer to create_tdesc for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +631,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
}];

let assemblyFormat = [{
    $source
(`[` $offsets^ `]`)?
prop-dict
    attr-dict `:` type(operands)
}];

let builders = [
OpBuilder<(ins "Value": $source,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

@@ -723,7 +678,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<16x8xf32>
```

Example 3 (SIMT mode):
```mlir
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
@@ -732,12 +687,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
vector<16xi1> -> vector<8xf32>
```

Example 4:
    A variant accepts a memref as the base pointer and an offset instead of a scattered TensorDesc.
    It combines "create scattered TensorDesc" and "load with scattered TensorDesc".
    The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
    for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%offsets = vector.step : vector<16xindex>
@@ -794,14 +749,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)? `,`
    $mask prop-dict
attr-dict `:` type(operands) `->` type($value)
}];

let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

@@ -848,7 +803,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
    A variant accepts a memref as the base pointer and an offset instead of a scattered TensorDesc.
    It combines "create scattered TensorDesc" and "store with scattered TensorDesc".
    The dest operand could be a raw pointer (uint64_t).
    Please refer to create_tdesc for the restrictions on the memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +856,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
$value `,`
$dest
(`[` $offsets^ `]`)? `,`
    $mask
    prop-dict
attr-dict `:` type(operands)
}];

let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

19 changes: 8 additions & 11 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -265,8 +265,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
 }

 LogicalResult CreateNdDescOp::verify() {
-  auto rank = (int64_t)getMixedOffsets().size();
-  bool invalidRank = false;
+  size_t rank = getMixedSizes().size();
+  bool invalidRank = rank != getMixedStrides().size();
   bool invalidElemTy = false;

   // Memory space of created TensorDesc should match with the source.
@@ -280,31 +280,28 @@ LogicalResult CreateNdDescOp::verify() {
<< " Source: " << srcMemorySpace
<< ", TensorDesc: " << tdescMemorySpace;

if (size_t offsetRank = getMixedOffsets().size())
invalidRank |= (offsetRank != rank);

// check source type matches the rank if it is a memref.
// It also should have the same ElementType as TensorDesc.
auto memrefTy = dyn_cast<MemRefType>(getSourceType());
if (memrefTy) {
invalidRank |= (memrefTy.getRank() != rank);
if (auto memrefTy = dyn_cast<MemRefType>(getSourceType()))
invalidElemTy |= memrefTy.getElementType() != getElementType();
}

if (llvm::isa<IntegerType>(getSourceType())) {
// strides and shape must present for integer source.
if (getMixedStrides().empty() || getMixedSizes().empty())
return emitOpError("Expecting strides and shape to be present for "
return emitOpError("expecting strides and shape to be present for "
"integer source.");
}

// mismatches among shape, strides, and offsets are
// already handeled by OffsetSizeAndStrideOpInterface.
// So they are not check here.
if (invalidRank)
return emitOpError(
"Expecting the rank of shape, strides, offsets, and source (if source "
"is a memref) should match with each other.");

// check result TensorDesc rank
if (getType().getRank() > rank)
if (getType().getRank() > (int64_t)rank)
return emitOpError(
"Expecting the TensorDesc rank is not greater than the "
"ranks of shape, strides, offsets or the memref source.");