[mlir][xegpu] Add definition of SliceAttr #150146

Merged
merged 33 commits into from
Aug 8, 2025
Changes from 29 commits
Commits
2bc70b6
add definition draft of SliceAttr
chencha3 Jul 22, 2025
3959f9e
add layout traits
chencha3 Jul 22, 2025
2027cfc
add verifier and interface
chencha3 Jul 22, 2025
638c085
add invalid unit test
chencha3 Jul 23, 2025
91048f0
add wrappers
chencha3 Jul 23, 2025
7eaf0a6
Merge branch 'main' into xegpu_slice_attr
chencha3 Jul 23, 2025
ddc42c2
update description
chencha3 Jul 23, 2025
36e2c3a
refactor
chencha3 Jul 23, 2025
6872e6d
add delinearizeSubgroupId interface
chencha3 Jul 23, 2025
ded53b4
Merge branch 'main' into xegpu_slice_attr
chencha3 Jul 23, 2025
223fab9
fix format
chencha3 Jul 23, 2025
60e20a0
add impl of getOffsets for LayoutAttr
chencha3 Jul 24, 2025
3630966
apply getOffsets in CreateNdDescOp
chencha3 Jul 25, 2025
398d69b
cleanup
chencha3 Jul 25, 2025
08e4aa9
fix a bug
chencha3 Jul 25, 2025
62aa1dd
cleanup
chencha3 Jul 25, 2025
de0a1bb
add unit test
chencha3 Jul 25, 2025
a483699
Merge branch 'main' into xegpu_slice_attr
chencha3 Jul 25, 2025
e7f2977
fix a typo
chencha3 Jul 25, 2025
e3e4a61
add unit test
chencha3 Jul 25, 2025
4d72663
Merge branch 'main' into xegpu_slice_attr
chencha3 Aug 4, 2025
3f59105
fix conflicts
chencha3 Aug 4, 2025
129312a
address comments
chencha3 Aug 4, 2025
0865612
add support for nested SliceAttr
chencha3 Aug 5, 2025
b67f2b1
add unit test for nested slice attr
chencha3 Aug 5, 2025
01e4efe
cleanup
chencha3 Aug 5, 2025
3077c6c
Update mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
chencha3 Aug 5, 2025
d1f7bac
update docs
chencha3 Aug 6, 2025
27da02a
add check for order attribute
chencha3 Aug 6, 2025
e49e1cf
clean up
chencha3 Aug 6, 2025
59de450
clean up
chencha3 Aug 6, 2025
1b16552
address comments
chencha3 Aug 8, 2025
0511e1b
cleanup
chencha3 Aug 8, 2025
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt
@@ -12,3 +12,9 @@ mlir_tablegen(XeGPUEnums.h.inc -gen-enum-decls)
mlir_tablegen(XeGPUEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRXeGPUEnumsIncGen)
add_dependencies(mlir-headers MLIRXeGPUEnumsIncGen)

set(LLVM_TARGET_DEFINITIONS XeGPUAttrs.td)
mlir_tablegen(XeGPUAttrInterface.h.inc -gen-attr-interface-decls)
mlir_tablegen(XeGPUAttrInterface.cpp.inc -gen-attr-interface-defs)
add_public_tablegen_target(MLIRXeGPUAttrInterfaceIncGen)
add_dependencies(mlir-headers MLIRXeGPUAttrInterfaceIncGen)
Contributor: newline

Contributor Author: Fixed.

9 changes: 6 additions & 3 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h
@@ -15,24 +15,27 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/IR/Value.h"
#include "mlir/Interfaces/ShapedOpInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

namespace mlir {
namespace xegpu {
class TensorDescType;
class LayoutAttr;
class SliceAttr;
} // namespace xegpu
} // namespace mlir

#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>

#define GET_ATTRDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
#define GET_TYPEDEF_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.h.inc>

#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>

#define GET_OP_CLASSES
#include <mlir/Dialect/XeGPU/IR/XeGPU.h.inc>

164 changes: 163 additions & 1 deletion mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td
@@ -175,7 +175,38 @@ def XeGPU_FenceScopeAttr:
let assemblyFormat = "$value";
}

def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
def LayoutTrait: AttrInterface<"LayoutTrait"> {
let cppNamespace = "::mlir::xegpu";
let description = [{
Common trait for all XeGPU layouts.
}];

let methods = [
InterfaceMethod<"Get the rank of attribute",
"int64_t",
"getRank">,
InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
"std::optional<SmallVector<int64_t>>",
"getSgLayoutAsInt">,
InterfaceMethod<"Get the SgData field of the attribute as integer array",
"std::optional<SmallVector<int64_t>>",
"getSgDataAsInt">,
InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
indices based on the effective subgroup layout.}],
"FailureOr<SmallVector<Value>>",
"delinearizeSubgroupId",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
assigned to a subgroup identified by linearId. The shape parameter
represents the workgroup-level problem size. Each subgroup may access
multiple blocks according to round-robin distribution rules.}],
"FailureOr<SmallVector<SmallVector<Value>>>",
Contributor: what is the need for having a vector<vector<>> here?

Contributor Author: Since each subgroup may be assigned multiple blocks.

"getOffsets",
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
];
}
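
To make the `delinearizeSubgroupId` contract concrete, here is a minimal host-side sketch on plain integers. It assumes that the last dimension of `sg_layout` varies fastest (row-major order). The method in this PR emits IR `Value`s through an `OpBuilder` instead, and `delinearizeId` is a hypothetical name used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, not part of the XeGPU API: splits a linear subgroup id
// into one index per sg_layout dimension.
// Assumption: the last dimension varies fastest (row-major order).
std::vector<int64_t> delinearizeId(int64_t linearId,
                                   const std::vector<int64_t> &sgLayout) {
  assert(linearId >= 0 && "subgroup id must be non-negative");
  std::vector<int64_t> ids(sgLayout.size());
  // Peel off the fastest-varying (last) dimension first.
  for (std::size_t i = sgLayout.size(); i-- > 0;) {
    ids[i] = linearId % sgLayout[i];
    linearId /= sgLayout[i];
  }
  return ids;
}
```

Under this assumption, linear id 13 with `sg_layout = [8, 4]` maps to the indices `[3, 1]`.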

def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
let summary = [{
Describes the data distribution to subgroups and work-items for a tensor
specified by the tensor descriptor.
@@ -330,12 +361,143 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
getLaneLayout(), getLaneData(), getOrder());
}

std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
if (DenseI32ArrayAttr layout = getSgLayout())
return llvm::to_vector_of<int64_t>(layout.asArrayRef());
return std::nullopt;
}

std::optional<SmallVector<int64_t>> getSgDataAsInt() const {
if (DenseI32ArrayAttr data = getSgData())
return llvm::to_vector_of<int64_t>(data.asArrayRef());
return std::nullopt;
}

/// Delinearizes a linear subgroup ID into its multidimensional indices
/// based on the effective subgroup layout.
FailureOr<SmallVector<Value>>
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);

/// Generates instructions to compute multidimensional offsets for blocks
/// assigned to a subgroup identified by linearId. The shape parameter
/// represents the workgroup-level problem size. Each subgroup may access
/// multiple blocks according to round-robin distribution rules.
FailureOr<SmallVector<SmallVector<Value>>>
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);

}];

let assemblyFormat = "`<` struct(params) `>`";
let genVerifyDecl = 1;
}
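
The nested `SmallVector<SmallVector<Value>>` returned by `getOffsets` follows from a subgroup possibly owning several blocks. The sketch below illustrates one plausible round-robin scheme on plain integers. The name `blockOffsets`, the distribution unit (`sg_layout * sg_data`, clamped to the shape), and the wrap-around modulo are assumptions made for illustration, not a copy of the PR's implementation.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: returns one offset vector per block owned by the
// subgroup with multidimensional id `sgIds`.
// Assumption: blocks are dealt out round-robin in units of sgLayout * sgData.
std::vector<std::vector<int64_t>>
blockOffsets(const std::vector<int64_t> &sgIds,
             const std::vector<int64_t> &sgLayout,
             const std::vector<int64_t> &sgData,
             const std::vector<int64_t> &shape) {
  const std::size_t rank = shape.size();
  assert(sgIds.size() == rank && sgLayout.size() == rank &&
         sgData.size() == rank);

  // Region covered when every subgroup takes exactly one block.
  std::vector<int64_t> distUnit(rank);
  for (std::size_t i = 0; i < rank; ++i)
    distUnit[i] = std::min(shape[i], sgLayout[i] * sgData[i]);

  std::vector<std::vector<int64_t>> result;
  std::vector<int64_t> origin(rank, 0);
  while (true) {
    // Offset of this subgroup's block inside the current region.
    std::vector<int64_t> offs(rank);
    for (std::size_t i = 0; i < rank; ++i)
      offs[i] = (origin[i] + sgIds[i] * sgData[i]) % shape[i];
    result.push_back(offs);

    // Advance the region origin like an odometer, last dimension fastest.
    int d = static_cast<int>(rank) - 1;
    for (; d >= 0; --d) {
      origin[d] += distUnit[d];
      if (origin[d] < shape[d])
        break;
      origin[d] = 0;
    }
    if (d < 0)
      break; // every region has been visited
  }
  return result;
}
```

With `sg_layout = [8, 4]`, `sg_data = [32, 32]` and a workgroup shape of `[512, 128]`, subgroup `[3, 1]` receives two blocks in this sketch, at `[96, 32]` and `[352, 32]`. With shape `[256, 128]` it receives only the first.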


def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
Contributor: I think it is desirable to allow nested slice attribute to match the staged reduction use case, where reduction may follow another reduction.

Contributor Author: will add it.

Contributor Author: the support for nested SliceAttr has been enabled. @Jianhui-Li

let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];

let description = [{
Like LayoutAttr, SliceAttr describes data distribution among subgroups or work-items.
However, whereas LayoutAttr requires the data to have the same rank as the attribute,
SliceAttr permits the data to have a lower rank. In this case, compute units in the
specified dimensions (given by `$dims`) share the data, provided that the remaining
ranks match the data rank. SliceAttr is commonly used by operations such as
vector.multi_reduction and vector.broadcast.

Example:
```
#l = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
#r = #xegpu.slice<#l, dims = [0]>

%exp = math.exp %input {layout_result_0 = #l}: vector<256x128xf32>
%red = vector.multi_reduction<add>, %exp, %acc [0] {layout_result_0 = #r}: vector<256x128xf32> to vector<128xf32>
Contributor: better to add a comment here explaining the output layout of %red.

here the output sg layout is still [8, 4]. however data is shared along dim 0. So effectively there are 8 slices of [1, 4] SG segments each owning [1, 32] data. For example SGs [0-7][0] own the same 1x32 data segment.

Contributor Author: Updated it

%bcast = vector.broadcast %red {layout_result_0 = #l} : vector<128xf32> to vector<256x128xf32>
```
In this example, %red is conceptually divided into 4 vectors of type vector<32xf32>, each assigned to
a group of subgroups. Each group consists of 8 subgroups from the same column of sg_layout, sharing a
single reduction result of type vector<32xf32>.

}];

let parameters = (ins
"xegpu::LayoutTrait": $parent,
Contributor: why is parent a trait and not LayoutAttr?

Contributor Author: It is to support using both LayoutAttr and SliceAttr as parent. The latter is a nested definition.

"DenseI64ArrayAttr": $dims
);

let extraClassDeclaration = [{

int64_t getRank() const {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
return parent.getRank() - attr.getDims().size();
}

DenseI32ArrayAttr getOrder() const {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
return parent.getOrder();
}

bool isWgLayout() const {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
return parent.isWgLayout();
}

bool isSgLayout() const {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
return parent.isSgLayout();
}

/// Returns the SgLayout of the attribute, computed by applying
/// the slice dimensions to the underlying LayoutAttr.
std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
if (auto layout = parent.getSgLayoutAsInt()) {
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*layout), dims);
}
return std::nullopt;
}

/// Returns the SgData of the attribute, computed by applying
/// the slice dimensions to the underlying LayoutAttr.
std::optional<SmallVector<int64_t>> getSgDataAsInt() const {
SliceAttr attr = flatten();
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
if (auto data = parent.getSgDataAsInt()) {
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*data), dims);
}
return std::nullopt;
}

/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
/// it will coalesce the two slice operations and return a simplified SliceAttr
/// #xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0, 1]>
SliceAttr flatten() const;

/// Delinearizes a linear subgroup ID into its multidimensional indices
/// based on the effective subgroup layout.
FailureOr<SmallVector<Value>>
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);

/// Generates instructions to compute multidimensional offsets for blocks
/// assigned to a subgroup identified by linearId. The shape parameter
/// represents the workgroup-level problem size. Each subgroup may access
/// multiple blocks according to round-robin distribution rules.
FailureOr<SmallVector<SmallVector<Value>>>
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);

}];

let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
let genVerifyDecl = 1;
}
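
The index bookkeeping behind `flatten()` can be sketched on plain integers as follows. This is a minimal illustration under stated assumptions, not the PR's implementation: `dropDims` and `flattenSliceDims` are hypothetical names, and the slice levels are passed innermost first.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper: removes the entries of `values` at positions `dims`.
std::vector<int64_t> dropDims(const std::vector<int64_t> &values,
                              const std::vector<int64_t> &dims) {
  std::vector<int64_t> result;
  for (std::size_t i = 0; i < values.size(); ++i)
    if (std::find(dims.begin(), dims.end(), static_cast<int64_t>(i)) ==
        dims.end())
      result.push_back(values[i]);
  return result;
}

// Hypothetical helper: given the rank of the root layout and the `dims` of
// each nested slice (innermost first), returns the equivalent single-level
// dims expressed in the root layout's dimension numbering.
std::vector<int64_t>
flattenSliceDims(int64_t rootRank,
                 const std::vector<std::vector<int64_t>> &levels) {
  assert(rootRank >= 0);
  std::vector<int64_t> indices(static_cast<std::size_t>(rootRank));
  std::iota(indices.begin(), indices.end(), 0);

  // Track which root dimensions survive after every slice level.
  std::vector<int64_t> remaining = indices;
  for (const std::vector<int64_t> &dims : levels)
    remaining = dropDims(remaining, dims);

  // The flattened dims are exactly the root dimensions that did not survive.
  std::vector<int64_t> flattened;
  for (int64_t d : indices)
    if (std::find(remaining.begin(), remaining.end(), d) == remaining.end())
      flattened.push_back(d);
  return flattened;
}
```

For a rank-3 root layout sliced twice with `dims = [0]`, this sketch yields `[0, 1]`, which matches the two-level example given in the `flatten()` comment.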

def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
let summary = [{Specifies a half-open range}];
let description = [{
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td
@@ -41,6 +41,18 @@ def XeGPU_Dialect : Dialect {
/// Checks if the given shape can be evenly distributed based on the layout
/// and data factors provided by the LayoutAttr.
static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);

/// Drops (slices away) the specified dims of the shape and returns the rest, e.g.,
/// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
template<typename T, typename U>
static llvm::SmallVector<T> slice(llvm::ArrayRef<T> shape, llvm::ArrayRef<U> dims) {
Contributor: seems like this is unrelated to XeGPUDialect. any reason for placing it here? can it be moved to Utils?

Contributor Author: Moving it to Utils would introduce a circular dependency for the linker.

llvm::SmallVector<T> result;
for (auto [i, v]: llvm::enumerate(shape)) {
if (!llvm::is_contained(dims, i))
result.push_back(v);
}
return result;
}
}];
}
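
The behaviour of the `slice` helper can be reproduced outside LLVM with a small standalone sketch. It uses `std::vector` in place of `llvm::ArrayRef` and `llvm::SmallVector`, and `sliceShape` is a hypothetical name chosen for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical standalone counterpart of the `slice` helper: drops the
// entries of `shape` whose positions are listed in `dims`, keeps the rest.
std::vector<int64_t> sliceShape(const std::vector<int64_t> &shape,
                                const std::vector<int64_t> &dims) {
  for (int64_t d : dims)
    assert(d >= 0 && d < static_cast<int64_t>(shape.size()) &&
           "sliced dim out of range");
  std::vector<int64_t> result;
  for (std::size_t i = 0; i < shape.size(); ++i)
    if (std::find(dims.begin(), dims.end(), static_cast<int64_t>(i)) ==
        dims.end())
      result.push_back(shape[i]);
  return result;
}
```

Here `sliceShape({32, 64, 8}, {0, 2})` gives `[64]`. Applied to the reduction example in the SliceAttr description, slicing `sg_layout = [8, 4]` and `sg_data = [32, 32]` with `dims = [0]` gives `[4]` and `[32]`, which is consistent with `%red` being split into 128 / 32 = 4 pieces of `vector<32xf32>`.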

3 changes: 3 additions & 0 deletions mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt
@@ -7,11 +7,14 @@ add_mlir_dialect_library(MLIRXeGPUDialect

DEPENDS
MLIRXeGPUIncGen
MLIRXeGPUAttrInterfaceIncGen
MLIRXeGPUAttrsIncGen
MLIRXeGPUEnumsIncGen

Contributor: don't you need to link MLIRIndexDialect, MLIRAffineUtils here?

Contributor Author: Good catch. Yes, they are needed.

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRIndexDialect
MLIRAffineUtils
MLIRArithUtils
MLIRDialectUtils
MLIRIR