Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b9a6d98
init code
Jianhui-Li Jul 1, 2025
2465050
add tests
Jianhui-Li Jul 2, 2025
1077871
git-clang-format
Jianhui-Li Jul 2, 2025
42baa22
add more tests
Jianhui-Li Jul 2, 2025
204d347
git-clang-format
Jianhui-Li Jul 2, 2025
2793c81
add ui64 case support
Jianhui-Li Jul 12, 2025
f23ea03
modify ui64 test
Jianhui-Li Jul 12, 2025
0bb958b
Merge branch 'main' into dialect-assembly-format
Jianhui-Li Jul 12, 2025
6793689
remove unnecessary comments
Jianhui-Li Jul 12, 2025
4a96c71
fix VectorToXeGPU tests
Jianhui-Li Jul 14, 2025
689a8a5
tweak default offset value
Jianhui-Li Jul 14, 2025
02d3795
git-clang-format
Jianhui-Li Jul 14, 2025
01718f4
add builders
Jianhui-Li Jul 15, 2025
5ef6ca9
git-clang-format
Jianhui-Li Jul 15, 2025
26a222d
Merge branch 'main' into dialect-assembly-format
Jianhui-Li Jul 15, 2025
882313f
simplify custom parser
Jianhui-Li Jul 15, 2025
456534a
add comma before shape and strides
Jianhui-Li Jul 15, 2025
b6f016e
tie the offsets rank to input tensor shape instead of tdesc
Jianhui-Li Jul 15, 2025
cd518d2
git-clang-format
Jianhui-Li Jul 15, 2025
546a3f7
add verifier for invalid cases
Jianhui-Li Jul 15, 2025
7846955
git-clang-format
Jianhui-Li Jul 16, 2025
ded9552
add comments
Jianhui-Li Jul 16, 2025
97b6e39
simplify custom print
Jianhui-Li Jul 17, 2025
ed1d48e
git-clang-format
Jianhui-Li Jul 17, 2025
b3edff6
Merge branch 'main' into dialect-assembly-format
Jianhui-Li Jul 17, 2025
d3e935b
use simpler interface for DenseI64ArrayAttr
Jianhui-Li Jul 17, 2025
205fea7
address feedback
Jianhui-Li Jul 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,34 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
Variadic<Index>: $offsets,
Variadic<Index>: $shape,
Variadic<Index>: $strides,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_shape,
OptionalAttr<DenseI64ArrayAttr>: $const_strides
);
let results = (outs XeGPU_TensorDesc: $TensorDesc);

let assemblyFormat = [{
$source ``
custom<DynamicIndexList>($offsets, $const_offsets)
(`,` custom<DynamicIndexList>($shape, $const_shape)^
`,` custom<DynamicIndexList>($strides, $const_strides))?
custom<OptionalDynamicIndexList>($offsets, $const_offsets)
(`,` `shape` `:` custom<DynamicIndexList>($shape, $const_shape)^
`,` `strides``:` custom<DynamicIndexList>($strides, $const_strides))?
attr-dict `:` type($source) `->` qualified(type($TensorDesc))
}];

let results = (outs XeGPU_TensorDesc: $TensorDesc);

let hasVerifier = 1;

let builders = [
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,

OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
"llvm::ArrayRef<OpFoldResult>": $shape,
"llvm::ArrayRef<OpFoldResult>": $strides)>,

OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
"llvm::ArrayRef<OpFoldResult>": $shape,
"llvm::ArrayRef<OpFoldResult>": $strides)>,

OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
"llvm::ArrayRef<OpFoldResult>": $offsets)>,

Expand Down Expand Up @@ -163,9 +174,29 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
}

/// Wrapper for matching with OffsetSizeAndStrideOpInterface.
/// Returns `const_offsets` when present. Offsets are optional in the
/// assembly format, but the OffsetSizeAndStrideOpInterface trait
/// verification assumes offsets are always present, so when absent this
/// lazily materializes a placeholder attribute of the proper rank.
ArrayRef<int64_t> getStaticOffsets(){
  auto attr = getConstOffsetsAttr();

  if (attr)
    return attr;

  // Sentinel meaning "the user passed no offset at all", as opposed to
  // ShapedType::kDynamic, which means "offset passed as an SSA value".
  // NOTE(review): temporary until offsets move from create_nd_tdesc to
  // load_nd; remove together with the custom parser/printer.
  constexpr int64_t kOffsetNotProvided = std::numeric_limits<int64_t>::max();

  // The offsets rank follows the rank of the memory shape.
  // getMixedSizes() abstracts over the two source forms: it returns the
  // memref shape for memref sources, and the explicit shape operands for
  // integer (e.g. ui64) sources.
  int64_t rank = getMixedSizes().size();

  setConstOffsets(
      llvm::SmallVector<int64_t, 4>(rank, kOffsetNotProvided));

  return getConstOffsetsAttr();
}


/// wrapper for matching with OffsetSizeAndStrideOpInterface
/// If source is IntegerType or `const_shape` is filled,
/// it will return `const_shape`, such that mixes of `shape`
Expand Down
122 changes: 120 additions & 2 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,64 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
//===----------------------------------------------------------------------===//
// XeGPU_CreateNdDescOp
//===----------------------------------------------------------------------===//

/// Builds a CreateNdDescOp from a statically-shaped memref source.
/// Shape and strides are derived from the memref type, so no dynamic
/// operands or constant shape/stride attributes are recorded, and no
/// offsets are set.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source) {
  // `ty` is only consumed by the assert, which is compiled out of
  // release binaries; [[maybe_unused]] silences the resulting warning.
  [[maybe_unused]] auto ty = source.getType();
  assert(ty.hasStaticShape() && "expecting a memref with static shape");

  build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
        ValueRange({}) /* empty dynamic shape */,
        ValueRange({}) /* empty dynamic strides */,
        builder.getDenseI64ArrayAttr({}) /* const offsets */,
        builder.getDenseI64ArrayAttr({}) /* empty const shape*/,
        builder.getDenseI64ArrayAttr({}) /* empty const strides*/);
}

/// Builds a CreateNdDescOp from a memref source with explicitly-provided
/// shape and strides (used when the memref type alone does not describe
/// the memory, e.g. dynamic shapes). No offsets are recorded.
/// Each OpFoldResult is split into its static (attribute) and dynamic
/// (SSA value) parts before delegating to the generated builder.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<MemRefType> source,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  // Shape and strides describe the same memory, so both must be present
  // and of matching rank.
  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
         "expecting non-empty shape and strides of equal rank");

  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
        staticStridesAttr);
}

/// Builds a CreateNdDescOp from an integer source (e.g. a ui64 base
/// address) with explicitly-provided shape and strides; the integer
/// carries no layout, so the shape/strides operands are mandatory.
/// No offsets are recorded. Each OpFoldResult is split into its static
/// (attribute) and dynamic (SSA value) parts before delegating to the
/// generated builder.
void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                           Type tdesc, TypedValue<IntegerType> source,
                           llvm::ArrayRef<OpFoldResult> shape,
                           llvm::ArrayRef<OpFoldResult> strides) {
  // Shape and strides describe the same memory, so both must be present
  // and of matching rank.
  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
         "expecting non-empty shape and strides of equal rank");

  llvm::SmallVector<int64_t> staticShape;
  llvm::SmallVector<int64_t> staticStrides;
  llvm::SmallVector<Value> dynamicShape;
  llvm::SmallVector<Value> dynamicStrides;

  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);

  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
        staticStridesAttr);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Type tdesc, TypedValue<MemRefType> source,
llvm::ArrayRef<OpFoldResult> offsets) {
Expand All @@ -125,8 +183,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
ValueRange({}) /* empty dynamic shape */,
ValueRange({}) /* empty dynamic strides */,
staticOffsets /* const offsets */, {} /* empty const shape*/,
{} /* empty const strides*/);
builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
{} /* empty const shape*/, {} /* empty const strides*/);
}

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

where is the new build methods implemented?

    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,

    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
                   "llvm::ArrayRef<OpFoldResult>": $shape,
                   "llvm::ArrayRef<OpFoldResult>": $strides)>,

    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
                   "llvm::ArrayRef<OpFoldResult>": $shape,
                   "llvm::ArrayRef<OpFoldResult>": $strides)>,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

Expand Down Expand Up @@ -221,6 +279,66 @@ LogicalResult CreateNdDescOp::verify() {
return success();
}

/// Parses an optional bracketed, comma-separated list mixing SSA values
/// and integer constants, e.g. `[%x, 8]`. SSA entries are appended to
/// `values` and recorded as ShapedType::kDynamic in `integers`; literal
/// entries are stored directly in `integers`. If no `[` is present the
/// whole list is treated as absent and `integers` is left unset.
/// NOTE(review): temporary until offsets move from create_nd_tdesc to
/// load_nd; remove together with getStaticOffsets()'s placeholder logic.
ParseResult parseOptionalDynamicIndexList(
    OpAsmParser &parser,
    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {

  SmallVector<int64_t, 4> integerVals;
  // Parses a single list element: an (optionally typed) SSA operand, or
  // an integer literal if no operand is present.
  auto parseIntegerOrValue = [&]() {
    OpAsmParser::UnresolvedOperand operand;
    auto res = parser.parseOptionalOperand(operand);

    if (res.has_value() && succeeded(res.value())) {
      values.push_back(operand);
      integerVals.push_back(ShapedType::kDynamic);
      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
        return failure();
    } else {
      int64_t integer;
      if (failed(parser.parseInteger(integer)))
        return failure();
      integerVals.push_back(integer);
    }

    return success();
  };

  // The list is optional: if values are given there must be a leading
  // left bracket; otherwise nothing is consumed and `integers` stays
  // unset so the caller can detect the absent-offsets case.
  if (parser.parseOptionalLSquare().succeeded()) {
    if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
        parser.parseRSquare())
      return parser.emitError(parser.getNameLoc())
             << "expected a list of SSA values or integers";
    integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
  }
  return success();
}

/// Prints the bracketed list parsed by parseOptionalDynamicIndexList:
/// entries in `integers` equal to ShapedType::kDynamic print the next
/// operand from `values` (with its type when `valueTypes` is non-empty);
/// all other entries print as integer literals. Prints nothing when the
/// list only holds the "offsets not provided" placeholder.
void printOptionalDynamicIndexList(OpAsmPrinter &printer, Operation *op,
                                   OperandRange values,
                                   ArrayRef<int64_t> integers,
                                   TypeRange valueTypes = TypeRange()) {

  // int64_t's max value is a placeholder meaning the user did not pass
  // any offsets (see getStaticOffsets()); suppress printing so the op
  // round-trips to its offset-less form. NOTE(review): temporary until
  // offsets move from create_nd_tdesc to load_nd.
  constexpr int64_t kOffsetNotProvided = std::numeric_limits<int64_t>::max();
  if (values.empty() && llvm::all_of(integers, [](int64_t i) {
        return i == kOffsetNotProvided;
      }))
    return;

  printer << '[';
  unsigned dynamicValIdx = 0;
  llvm::interleaveComma(integers, printer, [&](int64_t integer) {
    if (ShapedType::isDynamic(integer)) {
      // The generated parser keeps `values` (and `valueTypes`, when used)
      // in lock-step with the kDynamic entries of `integers`, so a single
      // running index is safe for both.
      printer << values[dynamicValIdx];
      if (!valueTypes.empty())
        printer << " : " << valueTypes[dynamicValIdx];
      ++dynamicValIdx;
    } else {
      printer << integer;
    }
  });
  printer << ']';
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
// CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
// CHECK: return %[[VEC]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
// CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
// CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
// CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
// CHECK-SAME: [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
// CHECK-SAME: memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
// CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>

Expand Down
45 changes: 43 additions & 2 deletions mlir/test/Dialect/XeGPU/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ gpu.func @create_nd_tdesc_1(%src: memref<24x32xf32>) {
gpu.func @create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], [%[[arg2]], %[[arg1]]], [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], shape : [%[[arg2]], %[[arg1]]], strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
%1 = xegpu.create_nd_tdesc %src[%x, %y], shape:[%h, %w], strides: [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
gpu.return
}

Expand Down Expand Up @@ -62,6 +62,47 @@ gpu.func @create_nd_tdesc_7(%src: memref<8x24x32x48x64xf32>) {
}


// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index

// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

gpu.return
}

// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {

%c1 = arith.constant 1 : index
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
%2 = xegpu.create_nd_tdesc %src, shape : [%h, %w], strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>

gpu.return
}

// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})

gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {

%c1 = arith.constant 1 : index
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[%arg3, %arg4], shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.create_nd_tdesc %src[%x, %y], shape:[%h, %w], strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>

gpu.return
}

// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0, shape : [%arg2, %arg1], strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
%2 = xegpu.create_nd_tdesc %src, shape:[%h, %w], strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>

gpu.return
}

// CHECK: gpu.func @prefetch_nd(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @prefetch_nd(%src: memref<24x32xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/XeGPU/subgroup-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -150,16 +150,16 @@ gpu.module @test {
// CHECK: (%[[ARG0:[0-9a-zA-Z]+]]: ui64, %[[ARG1:[0-9a-zA-Z]+]]: ui64, %[[ARG2:[0-9a-zA-Z]+]]: index,
// CHECK-SAME: %[[ARG3:[0-9a-zA-Z]+]]: index, %[[ARG4:[0-9a-zA-Z]+]]: index,
// CHECK-SAME: %[[ARG5:[0-9a-zA-Z]+]]: index, %[[ARG6:[0-9a-zA-Z]+]]: index, %[[ARG7:[0-9a-zA-Z]+]]: index) {
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK: %[[T0:.*]] = xegpu.create_nd_tdesc %[[ARG0]][{{.*}}], shape : [%[[ARG2]], %[[ARG3]]], strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK: %[[T1:.*]] = xegpu.load_nd %[[T0]] : !xegpu.tensor_desc<16x16xf16> -> vector<16xf16>
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], [%[[ARG2]], %[[ARG3]]], [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK: %[[T2:.*]] = xegpu.create_nd_tdesc %[[ARG1]][{{.*}}], shape : [%[[ARG2]], %[[ARG3]]], strides : [%[[ARG4]], %[[ARG5]]] : ui64 -> !xegpu.tensor_desc<16x16xf16>
// CHECK: xegpu.store_nd %[[T1]], %[[T2]] : vector<16xf16>, !xegpu.tensor_desc<16x16xf16>
gpu.module @test {
gpu.func @create_nd_tdesc_non_memref(%arg0: ui64, %arg1: ui64, %arg2: index, %arg3: index, %arg4: index, %arg5: index, %arg6: index, %arg7: index) {
%c0 = arith.constant 0 : index
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%0 = xegpu.create_nd_tdesc %arg0[%c0, %c0], shape:[%arg2, %arg3], strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%1 = xegpu.load_nd %0 {layout_result_0 = #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>} : !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>> -> vector<16x16xf16>
%2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], [%arg2, %arg3], [%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
%2 = xegpu.create_nd_tdesc %arg1[%c0, %c0], shape:[%arg2, %arg3], strides:[%arg4, %arg5] : ui64 -> !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
xegpu.store_nd %1, %2 : vector<16x16xf16>, !xegpu.tensor_desc<16x16xf16, #xegpu.layout<lane_layout = [1, 16], lane_data = [1, 1]>>
gpu.return
}
Expand Down
Loading