Commits
b9a6d98
init code
Jianhui-Li Jul 1, 2025
2465050
add tests
Jianhui-Li Jul 2, 2025
1077871
git-clang-format
Jianhui-Li Jul 2, 2025
42baa22
add more tests
Jianhui-Li Jul 2, 2025
204d347
git-clang-format
Jianhui-Li Jul 2, 2025
2793c81
add ui64 case support
Jianhui-Li Jul 12, 2025
f23ea03
modify ui64 test
Jianhui-Li Jul 12, 2025
0bb958b
Merge branch 'main' into dialect-assembly-format
Jianhui-Li Jul 12, 2025
6793689
remove unnecessary comments
Jianhui-Li Jul 12, 2025
4a96c71
fix VectorToXeGPU tests
Jianhui-Li Jul 14, 2025
689a8a5
tweak default offset value
Jianhui-Li Jul 14, 2025
02d3795
git-clang-format
Jianhui-Li Jul 14, 2025
01718f4
add builders
Jianhui-Li Jul 15, 2025
5ef6ca9
git-clang-format
Jianhui-Li Jul 15, 2025
26a222d
Merge branch 'main' into dialect-assembly-format
Jianhui-Li Jul 15, 2025
882313f
simplify custom parser
Jianhui-Li Jul 15, 2025
456534a
add comma before shape and strides
Jianhui-Li Jul 15, 2025
b6f016e
tie the offsets rank to input tensor shape instead of tdesc
Jianhui-Li Jul 15, 2025
cd518d2
git-clang-format
Jianhui-Li Jul 15, 2025
546a3f7
add verifier for invalid cases
Jianhui-Li Jul 15, 2025
7846955
git-clang-format
Jianhui-Li Jul 16, 2025
ded9552
add comments
Jianhui-Li Jul 16, 2025
97b6e39
simplify custom print
Jianhui-Li Jul 17, 2025
ed1d48e
git-clang-format
Jianhui-Li Jul 17, 2025
b3edff6
Merge branch 'main' into dialect-assembly-format
Jianhui-Li Jul 17, 2025
d3e935b
use simpler interface for DenseI64ArrayAttr
Jianhui-Li Jul 17, 2025
205fea7
address feedback
Jianhui-Li Jul 17, 2025
41 changes: 32 additions & 9 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td
@@ -110,23 +110,36 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
Variadic<Index>: $offsets,
Variadic<Index>: $shape,
Variadic<Index>: $strides,
DenseI64ArrayAttr: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
OptionalAttr<DenseI64ArrayAttr>: $const_shape,
OptionalAttr<DenseI64ArrayAttr>: $const_strides
);
let results = (outs XeGPU_TensorDesc: $TensorDesc);

let assemblyFormat = [{
$source ``
custom<DynamicIndexList>($offsets, $const_offsets)
(`,` custom<DynamicIndexList>($shape, $const_shape)^
`,` custom<DynamicIndexList>($strides, $const_strides))?
attr-dict `:` type($source) `->` qualified(type($TensorDesc))
}];

// let assemblyFormat = [{
// $source
// (custom<DynamicIndexList>($offsets, $const_offsets)^)?
// (`base_shape` `:` custom<DynamicIndexList>($shape, $const_shape)^
// `base_strides` `:` custom<DynamicIndexList>($strides, $const_strides))?
// attr-dict `:` type($source) `->` qualified(type($TensorDesc))
// }];

let hasVerifier = 1;

let hasCustomAssemblyFormat = 1;

let builders = [
OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,

OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType> ": $source,
"llvm::ArrayRef<OpFoldResult>": $shape,
"llvm::ArrayRef<OpFoldResult>": $strides)>,

OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType> ": $source,
"llvm::ArrayRef<OpFoldResult>": $shape,
"llvm::ArrayRef<OpFoldResult>": $strides)>,

OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
"llvm::ArrayRef<OpFoldResult>": $offsets)>,

@@ -163,9 +176,19 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
}

ArrayRef<int64_t> getStaticOffsets(){
return getConstOffsets();
auto attr = getConstOffsetsAttr();
if (llvm::isa<IntegerType>(getSourceType()) || attr)
Why does the Source Type matter here?

@akroviakov Jul 8, 2025
For createNdDesc, we have the same check in getStaticStrides and getStaticSizes. The op description says the following for the optional strides and shape:

"For the case of dynamic memrefs or pointer, the shape and layout information of the memory region should be explicitly passed via shape and strides parameters."

Both getStaticStrides and getStaticSizes rely on a source memref to extract strides and shape, and on the user to provide them alongside ui64.

For optional offsets, you propose in one of the comments below to use the memref shape; in that case the check starts to matter. Without a memref, the op description quoted above would also extend to the now-optional offsets.

However, since we still require shape for ui64 and we only want to match the rank, we could indirectly depend on the user input via shape:

Suggested change:
-    if (llvm::isa<IntegerType>(getSourceType()) || attr)
+    if (attr)
       return attr;
     setConstOffsets(llvm::SmallVector<int64_t, 4>(getStaticSizes().size(), std::numeric_limits<int64_t>::max()));

This should work for both memref and ui64 sources

Jianhui-Li (Owner, Author)

Did you try it out on your side? I still run into the same error for the ui64 case (if we remove the offsets [0, 0]).

Jianhui-Li (Owner, Author)
Ahh. Now with some more trying, it works. Thanks!
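Putting the resolved thread together, a minimal sketch of the accessor with the suggestion applied (assembled from the suggested change above; note the snapshot below still anchors the rank on the tensor-desc shape):

ArrayRef<int64_t> getStaticOffsets() {
  auto attr = getConstOffsetsAttr();
  if (attr)
    return attr;
  // Lazily materialize the "no offsets given" sentinel. Taking the rank
  // from getStaticSizes() covers both memref and ui64 sources, since an
  // explicit shape is mandatory for ui64.
  setConstOffsets(llvm::SmallVector<int64_t, 4>(
      getStaticSizes().size(), std::numeric_limits<int64_t>::max()));
  return getConstOffsetsAttr();
}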

return attr;

// The offsets are allowed to be empty. The trait verification of
// OffsetSizeAndStrideOpInterface assumes offsets are present, so
// const_offsets is set to INT64_MAX to indicate that the user passed no
// value (kDynamic already means an offset passed as an SSA variable).
setConstOffsets(llvm::SmallVector<int64_t, 4>(getTensorDescShape().size(), std::numeric_limits<int64_t>::max()));
Here using TensorDescShape as the anchor is not appropriate, since the TensorDesc can have a smaller rank than the MemRef. Using the shape of the MemRef is better.

//setConstOffsets(llvm::SmallVector<int64_t, 4>(getTensorDescShape().size(), mlir::ShapedType::kDynamic));
clean up


attr = getConstOffsetsAttr();
return attr;
}
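With this convention, const_offsets encodes three distinguishable states. A small hypothetical predicate (the helper name is an assumption), mirroring the check the custom printer performs further down:

// kDynamic        -> offset was passed as an SSA value
// INT64_MAX       -> offsets were omitted by the user (lazy sentinel)
// any other value -> a genuine constant offset
static bool hasUserOffsets(mlir::DenseI64ArrayAttr constOffsets) {
  return constOffsets && !constOffsets.empty() &&
         constOffsets.asArrayRef().front() !=
             std::numeric_limits<int64_t>::max();
}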


/// wrapper for matching with OffsetSizeAndStrideOpInterface
/// If source is IntegerType or `const_shape` is filled,
/// it will return `const_shape`, such that mixes of `shape`
254 changes: 250 additions & 4 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp
@@ -122,8 +122,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
ValueRange({}) /* empty dynamic shape */,
ValueRange({}) /* empty dynamic strides */,
staticOffsets /* const offsets */, {} /* empty const shape*/,
{} /* empty const strides*/);
builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
{} /* empty const shape*/, {} /* empty const strides*/);
}
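As a usage sketch, a rewrite pattern might invoke the offsets builder like this (loc, rewriter, descTy, and srcMemref are illustrative placeholders; descTy is assumed to be an already-built xegpu::TensorDescType):

// Create a descriptor over a memref at constant offset (0, 0); shape and
// strides are defaulted by the builder.
SmallVector<OpFoldResult> offsets = {rewriter.getIndexAttr(0),
                                     rewriter.getIndexAttr(0)};
auto desc = rewriter.create<xegpu::CreateNdDescOp>(
    loc, descTy, llvm::cast<TypedValue<MemRefType>>(srcMemref), offsets);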

void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
@@ -135,8 +135,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
shape.size() == strides.size() && shape.size() == offsets.size());

Type srcTy = source.getType();
assert(isa<IntegerType>(srcTy) ||
isa<MemRefType>(srcTy) && "Source has to be either int or memref.");
assert((isa<IntegerType>(srcTy) || isa<MemRefType>(srcTy)) &&
"Source has to be either int or memref.");

llvm::SmallVector<Value> dynamicOffsets;
llvm::SmallVector<Value> dynamicShape;
@@ -220,6 +220,252 @@ LogicalResult CreateNdDescOp::verify() {
return success();
}

ParseResult parseOptionalDynamicIndexList(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
DenseI64ArrayAttr &integers, DenseBoolArrayAttr &scalableFlags,
SmallVectorImpl<Type> *valueTypes = nullptr,
AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {

SmallVector<int64_t, 4> integerVals;
SmallVector<bool, 4> scalableVals;
auto parseIntegerOrValue = [&]() {
OpAsmParser::UnresolvedOperand operand;
auto res = parser.parseOptionalOperand(operand);

// When encountering `[`, assume that this is a scalable index.
scalableVals.push_back(parser.parseOptionalLSquare().succeeded());

if (res.has_value() && succeeded(res.value())) {
values.push_back(operand);
integerVals.push_back(ShapedType::kDynamic);
if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
return failure();
} else {
int64_t integer;
if (failed(parser.parseInteger(integer)))
return failure();
integerVals.push_back(integer);
}

// If this is assumed to be a scalable index, verify that there's a closing
// `]`.
if (scalableVals.back() && parser.parseOptionalRSquare().failed())
return failure();
return success();
};
if (parser.parseOptionalLSquare().succeeded()) {
if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
parser.parseRSquare())
return parser.emitError(parser.getNameLoc())
<< "expected SSA value or integer";
integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
scalableFlags = parser.getBuilder().getDenseBoolArrayAttr(scalableVals);
return success();
}
return success();
}
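One contract worth noting: when no leading `[` is present, the helper succeeds without ever setting `integers`, so callers detect an omitted list by the attribute remaining null. A minimal sketch of the calling pattern (mirroring its use in parse() below):

DenseI64ArrayAttr constOffsets;   // stays null if no '[' is seen
DenseBoolArrayAttr scalableFlags; // scalability is unused by this op
if (parseOptionalDynamicIndexList(parser, offsetsOperands, constOffsets,
                                  scalableFlags))
  return failure();
bool offsetsWereWritten = static_cast<bool>(constOffsets);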

::mlir::ParseResult CreateNdDescOp::parse(::mlir::OpAsmParser &parser,
::mlir::OperationState &result) {
::mlir::OpAsmParser::UnresolvedOperand sourceRawOperand{};
::llvm::ArrayRef<::mlir::OpAsmParser::UnresolvedOperand> sourceOperands(
&sourceRawOperand, 1);
::llvm::SMLoc sourceOperandsLoc;

::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
offsetsOperands;
::llvm::SMLoc offsetsOperandsLoc;
::mlir::DenseI64ArrayAttr const_offsetsAttr;
::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4> shapeOperands;
::llvm::SMLoc shapeOperandsLoc;
::mlir::DenseI64ArrayAttr const_shapeAttr;
::llvm::SmallVector<::mlir::OpAsmParser::UnresolvedOperand, 4>
stridesOperands;
::llvm::SMLoc stridesOperandsLoc;
::mlir::DenseI64ArrayAttr const_stridesAttr;
::mlir::Type sourceRawType{};
::llvm::ArrayRef<::mlir::Type> sourceTypes(&sourceRawType, 1);
::mlir::Type TensorDescRawType{};
::llvm::ArrayRef<::mlir::Type> TensorDescTypes(&TensorDescRawType, 1);

sourceOperandsLoc = parser.getCurrentLocation();
if (parser.parseOperand(sourceRawOperand))
return ::mlir::failure();

// skip the "offsets :" at the beginning if it exists
// if (::mlir::succeeded(parser.parseOptionalKeyword("offsets"))) {
// if (parser.parseColon())
// return ::mlir::failure();
//}
offsetsOperandsLoc = parser.getCurrentLocation();

DenseBoolArrayAttr scalableFlags;
auto odsResult = parseOptionalDynamicIndexList(
parser, offsetsOperands, const_offsetsAttr, scalableFlags);

if (odsResult)
  return ::mlir::failure();
// A null const_offsetsAttr means the offset list was omitted entirely.
if (const_offsetsAttr)
  result.getOrAddProperties<CreateNdDescOp::Properties>().const_offsets =
      const_offsetsAttr;

if (::mlir::succeeded(parser.parseOptionalKeyword("shape"))) {
if (parser.parseColon())
return ::mlir::failure();
{
shapeOperandsLoc = parser.getCurrentLocation();
auto odsResult =
parseDynamicIndexList(parser, shapeOperands, const_shapeAttr);
if (odsResult)
  return ::mlir::failure();
if (const_shapeAttr)
  result.getOrAddProperties<CreateNdDescOp::Properties>().const_shape =
      const_shapeAttr;
}

if (parser.parseKeyword("strides"))
return ::mlir::failure();
if (parser.parseColon())
return ::mlir::failure();
{
stridesOperandsLoc = parser.getCurrentLocation();
auto odsResult =
parseDynamicIndexList(parser, stridesOperands, const_stridesAttr);
if (odsResult)
  return ::mlir::failure();
if (const_stridesAttr)
  result.getOrAddProperties<CreateNdDescOp::Properties>().const_strides =
      const_stridesAttr;
}
}
{
auto loc = parser.getCurrentLocation();
if (parser.parseOptionalAttrDict(result.attributes))
return ::mlir::failure();
if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
return parser.emitError(loc)
<< "'" << result.name.getStringRef() << "' op ";
})))
return ::mlir::failure();
}
if (parser.parseColon())
return ::mlir::failure();

{
::mlir::Type type;
if (parser.parseCustomTypeWithFallback(type))
return ::mlir::failure();
sourceRawType = type;
}
if (parser.parseArrow())
return ::mlir::failure();

if (parser.parseType(TensorDescRawType))
return ::mlir::failure();

::llvm::copy(::llvm::ArrayRef<int32_t>(
{1, static_cast<int32_t>(offsetsOperands.size()),
static_cast<int32_t>(shapeOperands.size()),
static_cast<int32_t>(stridesOperands.size())}),
result.getOrAddProperties<CreateNdDescOp::Properties>()
.operandSegmentSizes.begin());

::mlir::Type odsBuildableType0 = parser.getBuilder().getIndexType();
result.addTypes(TensorDescTypes);

if (parser.resolveOperands(sourceOperands, sourceTypes, sourceOperandsLoc,
result.operands))
return ::mlir::failure();

if (parser.resolveOperands(offsetsOperands, odsBuildableType0,
offsetsOperandsLoc, result.operands))
return ::mlir::failure();

if (parser.resolveOperands(shapeOperands, odsBuildableType0, shapeOperandsLoc,
result.operands))
return ::mlir::failure();

if (parser.resolveOperands(stridesOperands, odsBuildableType0,
stridesOperandsLoc, result.operands))
return ::mlir::failure();
return ::mlir::success();
}

void CreateNdDescOp::print(::mlir::OpAsmPrinter &_odsPrinter) {
_odsPrinter << ' ';
_odsPrinter << getSource();
// Print offsets if getConstOffsetsAttr() exists, is not empty, and its first
// value is not int64_t::max.
auto constOffsetsAttr = getConstOffsetsAttr();
bool printOffsets = false;
if (constOffsetsAttr && constOffsetsAttr.size() > 0) {
auto firstVal = constOffsetsAttr.asArrayRef()[0];
if (firstVal != std::numeric_limits<int64_t>::max()) {
printOffsets = true;
}
}
if (printOffsets) {

printDynamicIndexList(_odsPrinter, *this, getOffsets(),
getConstOffsetsAttr());
}
if (((!getShape().empty()) || (getConstShapeAttr()))) {
_odsPrinter << ' ' << "shape";
_odsPrinter << ' ' << ":";
_odsPrinter << ' ';
printDynamicIndexList(_odsPrinter, *this, getShape(), getConstShapeAttr());
_odsPrinter << ' ' << "strides";
_odsPrinter << ' ' << ":";
_odsPrinter << ' ';
printDynamicIndexList(_odsPrinter, *this, getStrides(),
getConstStridesAttr());
}
::llvm::SmallVector<::llvm::StringRef, 2> elidedAttrs;
elidedAttrs.push_back("operandSegmentSizes");
elidedAttrs.push_back("const_offsets");
elidedAttrs.push_back("const_shape");
elidedAttrs.push_back("const_strides");
_odsPrinter.printOptionalAttrDict((*this)->getAttrs(), elidedAttrs);
_odsPrinter << ' ' << ":";
_odsPrinter << ' ';
{
auto type = getSource().getType();
if (auto validType = ::llvm::dyn_cast<::mlir::Type>(type))
_odsPrinter.printStrippedAttrOrType(validType);
else
_odsPrinter << type;
}
_odsPrinter << ' ' << "->";
_odsPrinter << ' ';
// _odsPrinter << getTensorDesc().getType();

_odsPrinter << "!xegpu.tensor_desc<";

auto tDesc = getTensorDesc().getType();
auto shape = tDesc.getShape();
for (int64_t dim : shape) {
if (mlir::ShapedType::isDynamic(dim))
_odsPrinter << '?';
else
_odsPrinter << dim;
_odsPrinter << 'x';
}

_odsPrinter << tDesc.getElementType();

if (auto encoding = tDesc.getEncoding())
_odsPrinter << ", " << encoding;

if (auto layout = tDesc.getLayout())
_odsPrinter << ", " << layout;

_odsPrinter << ">";
}

//===----------------------------------------------------------------------===//
// XeGPU_PrefetchNdOp
//===----------------------------------------------------------------------===//
45 changes: 43 additions & 2 deletions mlir/test/Dialect/XeGPU/ops.mlir
@@ -17,8 +17,8 @@ gpu.func @create_nd_tdesc_1(%src: memref<24x32xf32>) {
gpu.func @create_nd_tdesc_2(%src: ui64, %w : index, %h : index, %x : index, %y : index) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]], [%[[arg2]], %[[arg1]]], [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
%1 = xegpu.create_nd_tdesc %src[%x, %y], [%h, %w], [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg0]][%[[arg3]], %[[arg4]]] shape : [%[[arg2]], %[[arg1]]] strides : [%[[arg1]], %[[C]]] : ui64 -> !xegpu.tensor_desc<8x16xf32>
%1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides: [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
gpu.return
}

@@ -55,6 +55,47 @@ gpu.func @create_nd_tdesc_6(%src: memref<24x32xf32>) {
}


// CHECK: gpu.func @test_create_nd_tdesc_7(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index, %[[arg5:.*]]: memref<24x32xf32>)
gpu.func @test_create_nd_tdesc_7(%src: ui64, %w : index, %h : index, %x : index, %y : index, %src2: memref<24x32xf32>) {
//CHECK: %[[C:.*]] = arith.constant 1 : index
%c1 = arith.constant 1 : index

// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %[[arg5]] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%3 = xegpu.create_nd_tdesc %src2 : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

gpu.return
}

// CHECK: gpu.func @test_create_nd_tdesc_8(%[[arg0:.*]]: ui64, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index, %[[arg4:.*]]: index)
gpu.func @test_create_nd_tdesc_8(%src: ui64, %w : index, %h : index, %x : index, %y : index) {

%c1 = arith.constant 1 : index
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[0, 0] shape : [%arg2, %arg1] strides : [%arg1, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>
%2 = xegpu.create_nd_tdesc %src[0, 0] shape : [%h, %w] strides : [%w, %c1] : ui64 -> !xegpu.tensor_desc<8x16xf32>

gpu.return
}

// CHECK-LABEL: func @test_create_nd_tdesc_9({{.*}})

gpu.func @test_create_nd_tdesc_9(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {

%c1 = arith.constant 1 : index
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0[%arg3, %arg4] shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
%1 = xegpu.create_nd_tdesc %src[%x, %y] shape:[%h, %w] strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>

gpu.return
}

// CHECK-LABEL: func @test_create_nd_tdesc_10({{.*}})
gpu.func @test_create_nd_tdesc_10(%src: memref<?x?xf16>, %w : index, %h : index, %x : index, %y : index) {
%c1 = arith.constant 1 : index
// CHECK: %[[REG:.*]] = xegpu.create_nd_tdesc %arg0 shape : [%arg2, %arg1] strides : [%arg1, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>
%2 = xegpu.create_nd_tdesc %src shape:[%h, %w] strides:[%w, %c1] : memref<?x?xf16> -> !xegpu.tensor_desc<8x16xf16>

gpu.return
}

// CHECK: gpu.func @prefetch_nd(%[[arg0:.*]]: memref<24x32xf16>) {
gpu.func @prefetch_nd(%src: memref<24x32xf16>) {
// CHECK: %[[R0:.*]] = xegpu.create_nd_tdesc %[[arg0]][0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>