[mlir][xegpu] Add definition of SliceAttr #150146

Merged 33 commits on Aug 8, 2025
Changes from 20 commits
Commits
2bc70b6
add definition draft of SliceAttr
chencha3 Jul 22, 2025
3959f9e
add layout traits
chencha3 Jul 22, 2025
2027cfc
add verifier and interface
chencha3 Jul 22, 2025
638c085
add invalid unit test
chencha3 Jul 23, 2025
91048f0
add wrappers
chencha3 Jul 23, 2025
7eaf0a6
Merge branch 'main' into xegpu_slice_attr
chencha3 Jul 23, 2025
ddc42c2
update description
chencha3 Jul 23, 2025
36e2c3a
refactor
chencha3 Jul 23, 2025
6872e6d
add delinearizeSubgroupId interface
chencha3 Jul 23, 2025
ded53b4
Merge branch 'main' into xegpu_slice_attr
chencha3 Jul 23, 2025
223fab9
fix format
chencha3 Jul 23, 2025
60e20a0
add impl of getOffsets for LayoutAttr
chencha3 Jul 24, 2025
3630966
apply getOffsets in CreateNdDescOp
chencha3 Jul 25, 2025
398d69b
cleanup
chencha3 Jul 25, 2025
08e4aa9
fix a bug
chencha3 Jul 25, 2025
62aa1dd
cleanup
chencha3 Jul 25, 2025
de0a1bb
add unit test
chencha3 Jul 25, 2025
a483699
Merge branch 'main' into xegpu_slice_attr
chencha3 Jul 25, 2025
e7f2977
fix a typo
chencha3 Jul 25, 2025
e3e4a61
add unit test
chencha3 Jul 25, 2025
4d72663
Merge branch 'main' into xegpu_slice_attr
chencha3 Aug 4, 2025
3f59105
fix conflicts
chencha3 Aug 4, 2025
129312a
address comments
chencha3 Aug 4, 2025
0865612
add support for nested SliceAttr
chencha3 Aug 5, 2025
b67f2b1
add unit test for nested slice attr
chencha3 Aug 5, 2025
01e4efe
cleanup
chencha3 Aug 5, 2025
3077c6c
Update mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
chencha3 Aug 5, 2025
d1f7bac
update docs
chencha3 Aug 6, 2025
27da02a
add check for order attribute
chencha3 Aug 6, 2025
e49e1cf
clean up
chencha3 Aug 6, 2025
59de450
clean up
chencha3 Aug 6, 2025
1b16552
address comments
chencha3 Aug 8, 2025
0511e1b
cleanup
chencha3 Aug 8, 2025
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
@@ -12,3 +12,9 @@ mlir_tablegen(XeGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(XeGPUEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRXeGPUEnumsIncGen)
add_dependencies(mlir-headers MLIRXeGPUEnumsIncGen)

set(LLVM_TARGET_DEFINITIONS XeGPUAttrs.td)
mlir_tablegen(XeGPUAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(XeGPUAttrInterface.cpp.inc -gen-attr-interface-defs)
add_public_tablegen_target(MLIRXeGPUAttrInterfaceIncGen)
add_dependencies(mlir-headers MLIRXeGPUAttrInterfaceIncGen)
Contributor: newline

Contributor Author: Fixed.

8 changes: 5 additions & 3 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -15,24 +15,26 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

namespace mlir {
namespace xegpu {
class TensorDescType;
class LayoutAttr;
} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>

#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.h.inc>

#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>

#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.h.inc>

120 changes: 119 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,7 +175,32 @@ def XeGPU_FenceScopeAttr:
let assemblyFormat = "$value";
}

def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
def LayoutTrait: AttrInterface<"LayoutTrait"> {
let cppNamespace = "::mlir::xegpu";
let description = [{
Common trait for all XeGPU layouts.
}];

let methods = [
InterfaceMethod<"Get the effective sg layout",
"std::optional<SmallVector<int64_t>>",
"getEffectiveSgLayout">,
InterfaceMethod<"Get the effective sg data",
"std::optional<SmallVector<int64_t>>",
"getEffectiveSgData">,
InterfaceMethod<"Delinearize the Subgroup Id",
"FailureOr<SmallVector<Value>>",
"delinearizeSubgroupId",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,

InterfaceMethod<"Get the local offset to be accessed by the given subgroup Id",
"FailureOr<SmallVector<SmallVector<Value>>>",
Contributor: What is the need for a vector<vector<>> here?

Contributor Author: Because each subgroup may be assigned multiple blocks.

"getOffsets",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
];
}

def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
let summary = [{
Describes the data distribution to subgroups and work-items for a tensor
specified by the tensor descriptor.
@@ -330,12 +355,105 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
getLaneLayout(), getLaneData(), getOrder());
}

std::optional<SmallVector<int64_t>> getEffectiveSgLayout() const {
if (DenseI32ArrayAttr layout = getSgLayout())
return llvm::to_vector_of<int64_t>(layout.asArrayRef());
return std::nullopt;
}

std::optional<SmallVector<int64_t>> getEffectiveSgData() const {
if (DenseI32ArrayAttr data = getSgData())
return llvm::to_vector_of<int64_t>(data.asArrayRef());
return std::nullopt;
}

FailureOr<SmallVector<Value>>
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);

FailureOr<SmallVector<SmallVector<Value>>>
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);

}];

let assemblyFormat = "`<` struct(params) `>`";
let genVerifyDecl = 1;
}


def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
Contributor: I think it is desirable to allow nested slice attributes to match the staged reduction use case, where one reduction may follow another.

Contributor Author: Will add it.

Contributor Author: Support for nested SliceAttr has been enabled. @Jianhui-Li

let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];

let description = [{
Like LayoutAttr, SliceAttr describes data distribution among subgroups or work-items.
However, whereas LayoutAttr requires the data to have the same rank as the attribute,
SliceAttr permits the data to have a lower rank. In this case, compute units in the
specified dimensions share the data, provided that the remaining ranks match the data
rank. SliceAttr is commonly used by operations such as vector.multi_reduction and
vector.broadcast.

Example:
```
#l = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
#r = #xegpu.slice<#l, dim = 0>

%exp = math.exp %input {layout_result_0 = #l}: vector<256x128xf32>
%red = vector.multi_reduction<add>, %exp, %acc [0] {layout_result_0 = #r}: vector<256x128xf32> to vector<128xf32>
Contributor: Better to add a comment here explaining the output layout of %red.

Here the output sg layout is still [8, 4]; however, data is shared along dim 0. So effectively there are 8 slices of [1, 4] SG segments, each owning [1, 32] data. For example, SGs [0-7][0] own the same 1x32 data segment.

Contributor Author: Updated it.

%bcast = vector.broadcast %red {layout_result_0 = #l} : vector<128xf32> to vector<256x128xf32>
```
}];

let parameters = (ins
"xegpu::LayoutAttr": $parent,
"DenseI64ArrayAttr": $dims
);

let extraClassDeclaration = [{

int64_t getRank() const {
return getParent().getRank() - getDims().size();
}

DenseI32ArrayAttr getOrder() const {
return getParent().getOrder();
}

bool isWgLayout() const {
return getParent().isWgLayout();
}

bool isSgLayout() const {
return getParent().isSgLayout();
}

std::optional<SmallVector<int64_t>> getEffectiveSgLayout() const {
if (auto layout = getParent().getEffectiveSgLayout()) {
ArrayRef<int64_t> dims = getDims().asArrayRef();
return XeGPUDialect::dropDims(llvm::ArrayRef<int64_t>(*layout), dims);
}
return std::nullopt;
}

std::optional<SmallVector<int64_t>> getEffectiveSgData() const {
if (auto data = getParent().getEffectiveSgData()) {
ArrayRef<int64_t> dims = getDims().asArrayRef();
return XeGPUDialect::dropDims(llvm::ArrayRef<int64_t>(*data), dims);
}
return std::nullopt;
}

FailureOr<SmallVector<Value>>
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);

FailureOr<SmallVector<SmallVector<Value>>>
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);

}];

let assemblyFormat = "`<` $parent `,` `dims` `=` $dims `>`";
let genVerifyDecl = 1;
}

def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
let summary = [{Specifies a half-open range}];
let description = [{
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -41,6 +41,18 @@ def XeGPU_Dialect : Dialect {
/// Checks if the given shape can be evenly distributed based on the layout
/// and data factors provided by the LayoutAttr.
static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);

/// drops the data in the specified dimension, and return the rest. e.g.,
/// for data = [32, 64, 8], dropPositions = [0, 2], it will return [64]
template<typename T, typename U>
static llvm::SmallVector<T> dropDims(llvm::ArrayRef<T> data, llvm::ArrayRef<U> dropPositions) {
Contributor (@Jianhui-Li, Jul 25, 2025): nit: dropDims -> sliceDims? data -> dims, dropPosition -> dropDims

Contributor Author: Changed it to static llvm::SmallVector<T> slice(llvm::ArrayRef<T> shape, llvm::ArrayRef<U> dims). Does that look good to you?

llvm::SmallVector<T> result;
for (auto [i, v]: llvm::enumerate(data)) {
if (!llvm::is_contained(dropPositions, i))
result.push_back(v);
}
return result;
}
}];
}
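
Illustrative note, not part of the PR: the behaviour of the dropDims helper above can be checked outside of LLVM with a minimal sketch. The name dropDimsSketch is hypothetical, and std::vector stands in for llvm::ArrayRef and llvm::SmallVector as an assumption made only to keep the example self-contained.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for XeGPUDialect::dropDims, written against the
// standard library only. Keeps data[i] unless i appears in dropPositions.
template <typename T>
std::vector<T> dropDimsSketch(const std::vector<T> &data,
                              const std::vector<int64_t> &dropPositions) {
  std::vector<T> result;
  for (std::size_t i = 0; i < data.size(); ++i) {
    const bool dropped =
        std::find(dropPositions.begin(), dropPositions.end(),
                  static_cast<int64_t>(i)) != dropPositions.end();
    if (!dropped)
      result.push_back(data[i]);
  }
  return result;
}
```

With data = {32, 64, 8} and dropPositions = {0, 2} the sketch returns {64}, matching the doc comment in the diff. Applied to sg_layout = [8, 4] with dims = [0] it returns [4], which is how SliceAttr::getEffectiveSgLayout derives its result from the parent layout.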

1 change: 1 addition & 0 deletions mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -7,6 +7,7 @@ add_mlir_dialect_library(MLIRXeGPUDialect

DEPENDS
MLIRXeGPUIncGen
MLIRXeGPUAttrInterfaceIncGen
MLIRXeGPUAttrsIncGen
MLIRXeGPUEnumsIncGen

Contributor: Don't you need to link MLIRIndexDialect and MLIRAffineUtils here?

Contributor Author: Good catch. Yes, they are needed.

175 changes: 175 additions & 0 deletions mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp
@@ -6,6 +6,9 @@
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/IR/XeGPUTargetInfo.h"
@@ -211,6 +214,178 @@ LayoutAttr::verify(llvm::function_ref<mlir::InFlightDiagnostic()> emitError,
return success();
}

FailureOr<SmallVector<Value>>
LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
Value linearId) {
// delinearizeSubgroupId is only available for
// workgroup-level layout attribute
if (!isWgLayout())
return failure();

// TODO: handle order attribute
Contributor: Maybe assert that the order attribute is not set?

Contributor Author: Fixed.

auto dims =
llvm::map_to_vector(*getEffectiveSgLayout(), [&](int64_t d) -> Value {
return builder.createOrFold<arith::ConstantIndexOp>(loc, d);
});

return affine::delinearizeIndex(builder, loc, linearId, dims);
}
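
Illustrative note, not part of the PR: the delinearization step above can be pictured with plain integers instead of MLIR Values. The sketch below assumes the row-major convention (the last dimension varies fastest) and a linear id in the range [0, product of sgLayout). The name delinearizeSketch is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical integer-only analogue of delinearizeSubgroupId: splits a
// linear subgroup id into one id per dimension of sgLayout, assuming the
// last dimension varies fastest (row-major).
std::vector<int64_t> delinearizeSketch(int64_t linearId,
                                       const std::vector<int64_t> &sgLayout) {
  std::vector<int64_t> ids(sgLayout.size());
  for (std::size_t i = sgLayout.size(); i-- > 0;) {
    assert(sgLayout[i] > 0 && "layout entries must be positive");
    ids[i] = linearId % sgLayout[i]; // position along dimension i
    linearId /= sgLayout[i];         // carry the rest to the outer dims
  }
  return ids;
}
```

For sg_layout = [8, 4], linear id 5 maps to (1, 1) and linear id 31 maps to (7, 3) under this convention.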

FailureOr<SmallVector<SmallVector<Value>>>
LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
Contributor: Better to add a comment on the purpose of this function.

Contributor Author: Fixed.

ArrayRef<int64_t> shape) {
if (!isWgLayout())
return failure();

auto sgLayout = getEffectiveSgLayout().value();
SmallVector<int64_t> sgShape;
if (auto maybeSgShape = getEffectiveSgData())
sgShape = maybeSgShape.value();
else if (auto ratio = computeShapeRatio(shape, sgLayout))
sgShape = ratio.value();
else
return failure();

// distUnit[i] is the minimum value between shape[i] and
// sgLayout[i] * sgShape[i]
SmallVector<int64_t> distUnit = llvm::map_to_vector(
llvm::zip_equal(shape, computeElementwiseMul(sgLayout, sgShape)),
[](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });

// delinearize Ids
auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
SmallVector<Value> sgIds = *maybeIds;

// nd local offset, localOffset[i] = sgId[i] * sgShape[i]
SmallVector<Value> localOffsets = llvm::map_to_vector(
llvm::zip(sgIds, sgShape), [&](const auto &t) -> Value {
return builder.createOrFold<index::MulOp>(
loc, std::get<0>(t),
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});

SmallVector<SmallVector<Value>> offsets;
for (SmallVector<int64_t> unitOffs : StaticTileOffsetRange(shape, distUnit)) {
SmallVector<Value> base =
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
return builder.create<arith::ConstantIndexOp>(loc, d);
});

SmallVector<Value> adds = llvm::map_to_vector(
llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
std::get<1>(t));
});

SmallVector<Value> mods = llvm::map_to_vector(
llvm::zip_equal(adds, shape), [&](const auto &t) -> Value {
return builder.createOrFold<index::RemUOp>(
loc, std::get<0>(t),
builder.create<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});

offsets.push_back(mods);
}

return offsets;
}
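
Illustrative note, not part of the PR: the arithmetic in LayoutAttr::getOffsets can be followed with plain integers. The sketch below is a simplified model that takes an already delinearized subgroup id and assumes tiles are visited in row-major order. The names Shape and getOffsetsSketch are hypothetical. It also shows why the result is a vector of vectors: when sgLayout[i] * sgData[i] is smaller than shape[i], each subgroup receives more than one block.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using Shape = std::vector<int64_t>;

// Hypothetical integer-only analogue of LayoutAttr::getOffsets. sgIds is the
// already delinearized subgroup id, one entry per dimension.
std::vector<Shape> getOffsetsSketch(const Shape &sgIds, const Shape &sgLayout,
                                    const Shape &sgData, const Shape &shape) {
  const std::size_t rank = shape.size();
  assert(sgIds.size() == rank && sgLayout.size() == rank &&
         sgData.size() == rank && "all inputs must share one rank");

  // distUnit[i] = min(shape[i], sgLayout[i] * sgData[i]): the region covered
  // by one round of distribution across all subgroups.
  Shape distUnit(rank), numTiles(rank);
  int64_t total = 1;
  for (std::size_t i = 0; i < rank; ++i) {
    distUnit[i] = std::min(shape[i], sgLayout[i] * sgData[i]);
    numTiles[i] = (shape[i] + distUnit[i] - 1) / distUnit[i];
    total *= numTiles[i];
  }

  std::vector<Shape> offsets;
  for (int64_t t = 0; t < total; ++t) {
    // Recover the n-d tile index from t, last dimension fastest.
    Shape offset(rank);
    int64_t rem = t;
    for (std::size_t i = rank; i-- > 0;) {
      const int64_t tile = rem % numTiles[i];
      rem /= numTiles[i];
      // Tile base plus the subgroup's local offset, wrapped into the shape.
      offset[i] = (tile * distUnit[i] + sgIds[i] * sgData[i]) % shape[i];
    }
    offsets.push_back(offset);
  }
  return offsets;
}
```

For shape = [256, 128], sg_layout = [8, 4] and subgroup id (1, 1): with sg_data = [32, 32] a single distribution unit covers the whole shape and the sketch yields one offset, (32, 32). With sg_data = [16, 16] the distribution unit is [128, 64], so the same subgroup gets four offsets: (16, 16), (16, 80), (144, 16) and (144, 80).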

//===----------------------------------------------------------------------===//
// XeGPU_SliceAttr
//===----------------------------------------------------------------------===//
LogicalResult
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
xegpu::LayoutAttr parent, DenseI64ArrayAttr dims) {
if (!parent || !dims)
return emitError() << "expected parent layout and dims attribute";

int rank = parent.getRank();
// check every element in dims is unique and smaller than rank
llvm::SmallDenseSet<int64_t> seen;
for (int64_t dim : dims.asArrayRef()) {
if (dim >= rank)
Contributor: nit: should we check that dim >= 0 here?

Contributor Author: Good catch. Fixed.

return emitError() << "invalid dim (" << dim << ") in slice attribute.";
if (!seen.insert(dim).second)
return emitError() << "repeated dim (" << dim << ") in slice attribute.";
}
return success();
}
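
Illustrative note, not part of the PR: the diff shown here still contains only the upper-bound check, while the review thread indicates that a lower-bound check was added afterwards. The sketch below shows what a verifier with both checks could look like using only the standard library. The function name verifySliceDims and the std::optional<std::string> error convention are assumptions made for this example and do not reflect the final code.

```cpp
#include <cstdint>
#include <optional>
#include <set>
#include <string>
#include <vector>

// Hypothetical standalone analogue of SliceAttr::verify. Returns an error
// message on failure and std::nullopt on success.
std::optional<std::string> verifySliceDims(int64_t parentRank,
                                           const std::vector<int64_t> &dims) {
  std::set<int64_t> seen;
  for (int64_t dim : dims) {
    // Reject negative dims as well as dims at or beyond the parent rank.
    if (dim < 0 || dim >= parentRank)
      return "invalid dim (" + std::to_string(dim) + ") in slice attribute.";
    // Each dim may appear at most once.
    if (!seen.insert(dim).second)
      return "repeated dim (" + std::to_string(dim) + ") in slice attribute.";
  }
  return std::nullopt;
}
```

Under these assumptions, dims = [0] is accepted for a rank-2 parent, while dims = [-1], dims = [2] and dims = [1, 1] are rejected.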

FailureOr<SmallVector<Value>>
SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
Value linearId) {
return getParent().delinearizeSubgroupId(builder, loc, linearId);
}

FailureOr<SmallVector<SmallVector<Value>>>
SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
ArrayRef<int64_t> shape) {
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
if (!isWgLayout())
return failure();

auto sgLayout = getEffectiveSgLayout().value();

SmallVector<int64_t> sgShape;
if (auto maybeSgShape = getEffectiveSgData())
sgShape = maybeSgShape.value();
else if (auto ratio = computeShapeRatio(shape, sgLayout))
sgShape = ratio.value();
Contributor: shape == ratio?

Contributor Author: Changed the variable name to derivedShape for clarification, per discussion.

Contributor: It seems you changed the last place but missed this one.

Contributor Author: Fixed.

else
return failure();

// distUnit[i] is the minimum value between shape[i] and
// sgLayout[i] * sgShape[i]
SmallVector<int64_t> distUnit = llvm::map_to_vector(
llvm::zip_equal(shape, computeElementwiseMul(sgLayout, sgShape)),
[](const auto &t) { return std::min(std::get<0>(t), std::get<1>(t)); });

// delinearize Ids
auto maybeIds = delinearizeSubgroupId(builder, loc, linearId);
if (failed(maybeIds))
return failure();
// The effective sgIds for offsets computing correspond
// to the dims that are not sliced.
ArrayRef<int64_t> dims = getDims().asArrayRef();
SmallVector<Value> sgIds =
XeGPUDialect::dropDims(ArrayRef<Value>(*maybeIds), dims);

// nd local offset, localOffset[i] = sgId[i] * sgShape[i]
SmallVector<Value> localOffsets = llvm::map_to_vector(
llvm::zip(sgIds, sgShape), [&](const auto &t) -> Value {
return builder.createOrFold<index::MulOp>(
loc, std::get<0>(t),
builder.createOrFold<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});

SmallVector<SmallVector<Value>> offsets;
for (SmallVector<int64_t> unitOffs : StaticTileOffsetRange(shape, distUnit)) {
SmallVector<Value> base =
llvm::map_to_vector(unitOffs, [&](int64_t d) -> Value {
return builder.create<arith::ConstantIndexOp>(loc, d);
});

SmallVector<Value> adds = llvm::map_to_vector(
llvm::zip_equal(base, localOffsets), [&](const auto &t) -> Value {
return builder.createOrFold<arith::AddIOp>(loc, std::get<0>(t),
std::get<1>(t));
});

SmallVector<Value> mods = llvm::map_to_vector(
llvm::zip_equal(adds, shape), [&](const auto &t) -> Value {
return builder.createOrFold<index::RemUOp>(
loc, std::get<0>(t),
builder.create<arith::ConstantIndexOp>(loc, std::get<1>(t)));
});

offsets.push_back(mods);
}

return offsets;
}
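
Illustrative note, not part of the PR: the step that distinguishes SliceAttr::getOffsets from the LayoutAttr version is that subgroup ids are delinearized with the parent layout and the ids along the sliced dimensions are then discarded. The sketch below models only that step with plain integers. The name sliceLocalOffsets is hypothetical, and sgData is assumed to be the effective (already sliced) per-subgroup data shape.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical integer-only sketch of the id handling in
// SliceAttr::getOffsets. parentIds is the subgroup id delinearized with the
// parent layout; slicedDims lists the parent dimensions removed by the slice.
std::vector<int64_t> sliceLocalOffsets(const std::vector<int64_t> &parentIds,
                                       const std::vector<int64_t> &slicedDims,
                                       const std::vector<int64_t> &sgData) {
  std::vector<int64_t> offsets;
  for (std::size_t i = 0; i < parentIds.size(); ++i) {
    const bool sliced =
        std::find(slicedDims.begin(), slicedDims.end(),
                  static_cast<int64_t>(i)) != slicedDims.end();
    if (sliced)
      continue; // ids along sliced dims do not affect the offset
    const std::size_t k = offsets.size(); // index into the sliced shape
    assert(k < sgData.size() && "rank mismatch");
    offsets.push_back(parentIds[i] * sgData[k]);
  }
  return offsets;
}
```

With a parent sg_layout of [8, 4], dims = [0] and an effective sg_data of [32], subgroups (1, 1) and (6, 1) both get local offset 32. This matches the review remark that all subgroups in the same column share one data segment.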

//===----------------------------------------------------------------------===//
// XeGPU_RangeAttr
//===----------------------------------------------------------------------===//