[mlir][xegpu] Remove OffsetSizeAndStrideOpInterface from CreateNdDescOp #152773

Merged (3 commits, Aug 8, 2025). Changes shown from 2 commits.
171 changes: 63 additions & 108 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -29,7 +29,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
void printProperties(::mlir::MLIRContext *ctx,
::mlir::OpAsmPrinter &p, const Properties &prop,
::mlir::ArrayRef<::llvm::StringRef> elidedProps) {

DictionaryAttr propAttr = dyn_cast_if_present<mlir::DictionaryAttr>(getPropertiesAsAttr(ctx, prop));

// filter out the elidedProps from propAttr, and get the resultAttr
@@ -43,7 +43,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
}

if (!filteredAttrs.empty()) {
p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
p << "<" << DictionaryAttr::get(ctx, filteredAttrs) << ">";
}
}

@@ -60,8 +60,7 @@ class XeGPU_Op<string mnemonic, list<Trait> traits = []>:
}


-def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface,
-                          AttrSizedOperandSegments, OffsetSizeAndStrideOpInterface]> {
+def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface, AttrSizedOperandSegments]> {

Review thread on this change:

Reviewer: Is the op description section aligned with the changes now?

Contributor: The change only impacts the implementation.

Contributor (Author): I updated the description, marking offsets deprecated.
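Note (editorial, not part of this diff): a minimal sketch of the deprecated-offsets direction, mirroring the forms exercised in the updated tests below. The descriptor is created without offsets, and the offsets are passed to the consuming op instead.

```mlir
// Hypothetical example: offsets supplied at the point of use, not at creation.
%0 = xegpu.create_nd_tdesc %src : memref<8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.load_nd %0[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
```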


let summary = "Create nd-tensor descriptor operation";
let description = [{
@@ -181,82 +180,38 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
return getType().getShape();
}

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    OperandRange getSizes() {
-      return getShape();
+    SmallVector<OpFoldResult> getMixedOffsets() {
+      auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
+      auto dynamics = getOffsets();
+      if (statics.size() == 0 && dynamics.size() == 0)
+        return {};
+      return getMixedValues(statics, dynamics, getContext());
     }

-    ArrayRef<int64_t> getStaticOffsets(){
-      auto attr = getConstOffsetsAttr();
-      if (attr)
-        return attr;
-
-      int64_t rank = getMixedSizes().size();
-      setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
-      attr = getConstOffsetsAttr();
-      return attr;
-    }
+    SmallVector<OpFoldResult> getMixedSizes() {
+      SmallVector<int64_t> statics;
+
+      /// Get the static sizes/shape, the value passed to const_shape
+      /// will overide the value in memref shape.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = llvm::to_vector(memrefTy.getShape());
+      if (auto attr = getConstShapeAttr())
+        statics = llvm::to_vector(attr.asArrayRef());
+
+      return getMixedValues(statics, getShape(), getContext());
+    }

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_shape` is filled,
-    /// it will return `const_shape`, such that mixes of `shape`
-    /// and `const_shape` will be used to represent the shape of
-    /// source operand. They overide static shape from source memref type.
-    ArrayRef<int64_t> getStaticSizes() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static llvm::SmallVector<int64_t, 4> emptyShape;
-
-      auto attr = getConstShapeAttr();
-      if (attr)
-        return attr;
-
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyShape;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticSizes");
-      return memrefType.getShape();
-    }
+    SmallVector<OpFoldResult> getMixedStrides() {
+      SmallVector<int64_t> statics;

-    /// wrapper for matching with OffsetSizeAndStrideOpInterface
-    /// If source is IntegerType or `const_strides` is filled, it
-    /// will return `const_strides`, such that mixes of `strides`
-    /// and `const_strides` will be used to represent the strides of
-    /// source operand. They overide static strides from source memref type.
-    ArrayRef<int64_t> getStaticStrides() {
-      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects valid return value and perform checks
-      static llvm::SmallVector<int64_t, 4> emptyStrides;
-
-      auto attr = getConstStridesAttr();
-      if (attr)
-        return attr;
-
-      if (llvm::isa<IntegerType>(getSourceType()))
-        return emptyStrides;
-
-      auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
-      assert(memrefType && "Incorrect use of getStaticStrides");
-      auto [strides, _] = memrefType.getStridesAndOffset();
-      // reuse the storage of ConstStridesAttr since strides from
-      // memref is not persistant
-      setConstStrides(strides);
-      attr = getConstStridesAttr();
-      return attr;
-    }
+      /// Get the static strides, the value passed to const_strides
+      /// will overide the value in memref.
+      if (auto memrefTy = llvm::dyn_cast<MemRefType>(getSourceType()))
+        statics = memrefTy.getStridesAndOffset().first;
+      if (auto attr = getConstStridesAttr())
+        statics = llvm::to_vector(attr.asArrayRef());

-    /// Return the expected rank of each of the `static_offsets`,
-    /// `static_shape` and `static_strides` attributes.
-    std::array<unsigned, 3> getArrayAttrMaxRanks() {
-      unsigned rank;
-      if (auto ty = llvm::dyn_cast<MemRefType>(getSourceType())) {
-        rank = ty.getRank();
-      } else {
-        rank = (unsigned)getMixedOffsets().size();
-      }
-      return {rank, rank, rank};
+      return getMixedValues(statics, getStrides(), getContext());
     }

/// Return the number of leading operands before the `offsets`,
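Note (editorial, not part of the diff): each getMixed* helper above merges the constant attribute with the dynamic operands into a single OpFoldResult list via getMixedValues, so static and dynamic components can be mixed freely. A hypothetical sketch, assuming the op's `shape :` / `strides :` assembly syntax:

```mlir
// %w is a dynamic size/stride; 128 and 1 fold to constants, so
// getMixedSizes() yields {128, %w} and getMixedStrides() yields {%w, 1}.
%0 = xegpu.create_nd_tdesc %ptr, shape : [128, %w], strides : [%w, 1]
    : ui64 -> !xegpu.tensor_desc<128x128xf32>
```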
@@ -314,15 +269,15 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
}];

let assemblyFormat = [{
$TensorDesc ``
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc))
}];

let builders = [
OpBuilder<(ins "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
OpBuilder<(ins "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

@@ -370,7 +325,7 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [

let arguments = (ins XeGPU_TensorDesc: $TensorDesc,
Variadic<Index>: $offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<UnitAttr>: $packed,
OptionalAttr<DenseI64ArrayAttr>: $transpose,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
@@ -390,16 +345,16 @@
}];

let assemblyFormat = [{
$TensorDesc ``
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` qualified(type($TensorDesc)) `->` type($value)
}];

let builders = [
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
OpBuilder<(ins "Type": $value, "Value": $TensorDesc,
"UnitAttr": $packed, "DenseI64ArrayAttr": $transpose,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

@@ -442,7 +397,7 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
let arguments = (ins XeGPU_ValueType: $value,
XeGPU_TensorDesc: $TensorDesc,
Variadic<Index>: $offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<XeGPU_CacheHintAttr>: $l1_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l2_hint,
OptionalAttr<XeGPU_CacheHintAttr>: $l3_hint);
@@ -458,16 +413,16 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
}];

let assemblyFormat = [{
$value `,`
$TensorDesc ``
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
prop-dict attr-dict `:` type($value) `,` qualified(type($TensorDesc))
}];

let builders = [
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
OpBuilder<(ins "Value": $value, "Value": $TensorDesc,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

@@ -635,12 +590,12 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
l3_hint = #xegpu.cache_hint<cached>}
: !xegpu.tensor_desc<16xf16>
```

Example 2:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "prefetch with scattered TensorTdesc".
The source operand could be a raw pointer (uint64_t).
Please refer to create_tdesc for the restriction of memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%0 = arith.constant dense<[0, 16, 32, 64]> : vector<4xindex>
@@ -676,16 +631,16 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
}];

let assemblyFormat = [{
$source
(`[` $offsets^ `]`)?
prop-dict
attr-dict `:` type(operands)
}];

let builders = [
OpBuilder<(ins "Value": $source,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

@@ -723,7 +678,7 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>,
vector<16xi1> -> vector<16x8xf32>
```

Example 3 (SIMT mode):
```mlir
%2 = xegpu.load %1, %0 <{l1_hint = #xegpu.cache_hint<cached>,
Expand All @@ -732,12 +687,12 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
: !xegpu.tensor_desc<16x8xf32, #xegpu.scatter_tdesc_attr<memory_space=global, chunk_size=8>>
vector<16xi1> -> vector<8xf32>
```

Example 4:
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "load with scattered TensorTdesc".
The source operand could be a raw pointer (uint64_t). Please refer to create_tdesc
for the restriction of memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%offsets = vector.step : vector<16xindex>
@@ -794,14 +749,14 @@ def XeGPU_LoadGatherOp : XeGPU_Op<"load", [MemoryEffects<[MemRead]>]> {
let assemblyFormat = [{
$source
(`[` $offsets^ `]`)? `,`
$mask prop-dict
attr-dict `:` type(operands) `->` type($value)
}];

let builders = [
OpBuilder<(ins "Type": $value, "Value": $source, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

@@ -848,7 +803,7 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
A variant accepts memref as base pointer and an offset instead of scattered TensorTdesc.
It combines "create scattered TensorTdesc" and "store with scattered TensorTdesc".
The dest operand could be a raw pointer (uint64_t).
Please refer to create_tdesc for the restriction of memref.
```mlir
%a = memref.alloc() : memref<1024xf32>
%val = arith.constant dense<0.0> : vector<16xf32>
@@ -901,15 +856,15 @@ def XeGPU_StoreScatterOp : XeGPU_Op<"store", [MemoryEffects<[MemWrite]>]> {
$value `,`
$dest
(`[` $offsets^ `]`)? `,`
$mask
prop-dict
attr-dict `:` type(operands)
}];

let builders = [
OpBuilder<(ins "Value": $value, "Value": $dest, "Value": $mask,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l1_hint,
"xegpu::CachePolicyAttr": $l2_hint,
"xegpu::CachePolicyAttr": $l3_hint)>
];

7 changes: 5 additions & 2 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -265,7 +265,7 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
}

LogicalResult CreateNdDescOp::verify() {
-  auto rank = (int64_t)getMixedOffsets().size();
+  int64_t rank = getMixedSizes().size();
bool invalidRank = false;
bool invalidElemTy = false;

@@ -280,6 +280,9 @@ LogicalResult CreateNdDescOp::verify() {
<< " Source: " << srcMemorySpace
<< ", TensorDesc: " << tdescMemorySpace;

+  if (int64_t offsetRank = getMixedOffsets().size())
+    invalidRank |= (offsetRank != rank);

// check source type matches the rank if it is a memref.
// It also should have the same ElementType as TensorDesc.
auto memrefTy = dyn_cast<MemRefType>(getSourceType());
@@ -291,7 +294,7 @@ LogicalResult CreateNdDescOp::verify() {
if (llvm::isa<IntegerType>(getSourceType())) {
// strides and shape must present for integer source.
if (getMixedStrides().empty() || getMixedSizes().empty())
return emitOpError("Expecting strides and shape to be present for "
return emitOpError("expecting strides and shape to be present for "
"integer source.");
}
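Note (editorial, not part of the diff): offsets on create_nd_tdesc remain optional, but when present the verifier now requires their rank to match the descriptor rank derived from getMixedSizes(). A hypothetical case it would reject:

```mlir
// 1-D offset list on a 2-D descriptor: offsetRank (1) != rank (2).
%0 = xegpu.create_nd_tdesc %src[0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32>
```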

10 changes: 5 additions & 5 deletions mlir/test/Dialect/XeGPU/invalid.mlir
@@ -52,14 +52,14 @@ func.func @create_nd_tdesc_7(%src: memref<128x128xf32>) {

// -----
func.func @create_nd_tdesc_8(%src: ui64) {
-  // expected-error@+1 {{'xegpu.create_nd_tdesc' op Expecting strides and shape to be present for integer source}}
+  // expected-error@+1 {{'xegpu.create_nd_tdesc' op expecting strides and shape to be present for integer source}}
%1 = xegpu.create_nd_tdesc %src : ui64-> !xegpu.tensor_desc<128x128xf32>
return
}

// -----
func.func @create_nd_tdesc_9(%src: ui64) {
-  // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank}}
+  // expected-error@+1 {{expecting strides and shape to be present for integer source}}
%1 = xegpu.create_nd_tdesc %src[0, 0] : ui64-> !xegpu.tensor_desc<128x128xf32>
return
}
@@ -149,7 +149,7 @@ func.func @subgroup_load_nd_offset_2(%src: memref<4x8x16xf16>, %x : index) {
}

// -----
func.func @subgroup_load_nd_offset_3(%src: memref<4x8x16xf16>, %x : index) {
%3 = xegpu.create_nd_tdesc %src: memref<4x8x16xf16> -> !xegpu.tensor_desc<8x16xf16>
%5 = xegpu.load_nd %3[0, 0] : !xegpu.tensor_desc<8x16xf16> -> vector<8x16xf16>
// expected-error@+1 {{Mismatched ranks between offsets and tensor descriptor}}
@@ -418,7 +418,7 @@ func.func @store_scatter_offset_wi_1(%src: memref<?xf16>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
// expected-error@+1 {{value elements must match chunk size}}
xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<?xf16>, vector<1xindex>, vector<1xi1>
return
}
@@ -429,7 +429,7 @@ func.func @store_scatter_offset_wi_2(%src: memref<4x4xf16>) {
%offsets = arith.constant dense<[0]> : vector<1xindex>
%mask = arith.constant dense<1>: vector<1xi1>
// expected-error@+1 {{Expecting the dest is a 1D memref or pointer}}
xegpu.store %val, %src[%offsets], %mask
: vector<4xf16>, memref<4x4xf16>, vector<1xindex>, vector<1xi1>
return
}