Skip to content

Commit c962234

Browse files
[mlir][xegpu] Add definition of SliceAttr (#150146)
--------- Co-authored-by: Charitha Saumya <[email protected]>
1 parent b4e8b8e commit c962234

File tree

14 files changed

+651
-109
lines changed

14 files changed

+651
-109
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/CMakeLists.txt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,9 @@ mlir_tablegen(XeGPUEnums.h.inc -gen-enum-decls)
1212
mlir_tablegen(XeGPUEnums.cpp.inc -gen-enum-defs)
1313
add_public_tablegen_target(MLIRXeGPUEnumsIncGen)
1414
add_dependencies(mlir-headers MLIRXeGPUEnumsIncGen)
15+
16+
set(LLVM_TARGET_DEFINITIONS XeGPUAttrs.td)
17+
mlir_tablegen(XeGPUAttrInterface.h.inc -gen-attr-interface-decls)
18+
mlir_tablegen(XeGPUAttrInterface.cpp.inc -gen-attr-interface-defs)
19+
add_public_tablegen_target(MLIRXeGPUAttrInterfaceIncGen)
20+
add_dependencies(mlir-headers MLIRXeGPUAttrInterfaceIncGen)

mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,24 +15,27 @@
1515
#include "mlir/IR/BuiltinTypes.h"
1616
#include "mlir/IR/Dialect.h"
1717
#include "mlir/IR/TypeUtilities.h"
18+
#include "mlir/IR/Value.h"
1819
#include "mlir/Interfaces/ShapedOpInterfaces.h"
1920
#include "mlir/Interfaces/SideEffectInterfaces.h"
2021
#include "mlir/Interfaces/ViewLikeInterface.h"
2122

2223
namespace mlir {
2324
namespace xegpu {
2425
class TensorDescType;
26+
class LayoutAttr;
27+
class SliceAttr;
2528
} // namespace xegpu
2629
} // namespace mlir
2730

31+
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrInterface.h.inc>
32+
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
2833
#include <mlir/Dialect/XeGPU/IR/XeGPUEnums.h.inc>
34+
2935
#define GET_ATTRDEF_CLASSES
3036
#include <mlir/Dialect/XeGPU/IR/XeGPUAttrs.h.inc>
3137
#define GET_TYPEDEF_CLASSES
3238
#include <mlir/Dialect/XeGPU/IR/XeGPUTypes.h.inc>
33-
34-
#include <mlir/Dialect/XeGPU/IR/XeGPUDialect.h.inc>
35-
3639
#define GET_OP_CLASSES
3740
#include <mlir/Dialect/XeGPU/IR/XeGPU.h.inc>
3841

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 163 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,38 @@ def XeGPU_FenceScopeAttr:
175175
let assemblyFormat = "$value";
176176
}
177177

178-
def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
178+
def LayoutTrait: AttrInterface<"LayoutTrait"> {
179+
let cppNamespace = "::mlir::xegpu";
180+
let description = [{
181+
Common trait for all XeGPU layouts.
182+
}];
183+
184+
let methods = [
185+
InterfaceMethod<"Get the rank of attribute",
186+
"int64_t",
187+
"getRank">,
188+
InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
189+
"std::optional<SmallVector<int64_t>>",
190+
"getSgLayoutAsInt">,
191+
InterfaceMethod<"Get the SgData field of the attribute as integer array",
192+
"std::optional<SmallVector<int64_t>>",
193+
"getSgDataAsInt">,
194+
InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
195+
indices based on the effective subgroup layout.}],
196+
"FailureOr<SmallVector<Value>>",
197+
"delinearizeSubgroupId",
198+
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId)>,
199+
InterfaceMethod<[{Generates instructions to compute multidimensional offsets for blocks
200+
assigned to a subgroup identified by linearId. The shape parameter
201+
represents the workgroup-level problem size. Each subgroup may access
202+
multiple blocks according to round-robin distribution rules.}],
203+
"FailureOr<SmallVector<SmallVector<Value>>>",
204+
"getOffsets",
205+
(ins "OpBuilder &": $builder, "Location":$loc, "Value":$linearId, "ArrayRef<int64_t>":$shape)>
206+
];
207+
}
208+
209+
def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
179210
let summary = [{
180211
Describes the data distribution to subgroups and work-items for a tensor
181212
specified by the tensor descriptor.
@@ -330,12 +361,143 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
330361
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
331362
getLaneLayout(), getLaneData(), getOrder());
332363
}
364+
365+
std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
366+
if (DenseI32ArrayAttr layout = getSgLayout())
367+
return llvm::to_vector_of<int64_t>(layout.asArrayRef());
368+
return std::nullopt;
369+
}
370+
371+
std::optional<SmallVector<int64_t>> getSgDataAsInt() const {
372+
if (DenseI32ArrayAttr data = getSgData())
373+
return llvm::to_vector_of<int64_t>(data.asArrayRef());
374+
return std::nullopt;
375+
}
376+
377+
/// Delinearizes a linear subgroup ID into its multidimensional indices
378+
/// based on the effective subgroup layout.
379+
FailureOr<SmallVector<Value>>
380+
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
381+
382+
/// Generates instructions to compute multidimensional offsets for blocks
383+
/// assigned to a subgroup identified by linearId. The shape parameter
384+
/// represents the workgroup-level problem size. Each subgroup may access
385+
/// multiple blocks according to round-robin distribution rules.
386+
FailureOr<SmallVector<SmallVector<Value>>>
387+
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
388+
333389
}];
334390

335391
let assemblyFormat = "`<` struct(params) `>`";
336392
let genVerifyDecl = 1;
337393
}
338394

395+
396+
def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
397+
let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
398+
399+
let description = [{
400+
Like LayoutAttr, SliceAttr describes data distribution among subgroups or work-items.
401+
However, whereas LayoutAttr requires the data to have the same rank as the attribute,
402+
SliceAttr permits the data to have a lower rank. In this case, compute units in the
403+
specified dimensions (given by `$dims`) share the data, provided that the remaining
404+
ranks match the data rank. SliceAttr is commonly used by operations such as
405+
vector.multi_reduction and vector.broadcast.
406+
407+
Example:
408+
```
409+
#l = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>
410+
#r = #xegpu.slice<#l, dim = [0]>
411+
412+
%exp = math.exp %input {layout_result_0 = #l}: vector<256x128xf32>
413+
%red = vector.multi_reduction<add>, %exp, %acc [0] {layout_result_0 = #r}: vector<256x128xf32> to vector<128xf32>
414+
%bcast = vector.broadcast %red {layout_result_0 = #l} : vector<128xf32> to vector<256x128xf32>
415+
```
416+
In this example, %red is conceptually divided into 4 vectors of type vector<32xf32>, each assigned to
417+
a group of subgroups. Each group consists of 8 subgroups from the same column of sg_layout, sharing a
418+
single reduction result of type vector<32xf32>.
419+
420+
}];
421+
422+
let parameters = (ins
423+
"xegpu::LayoutTrait": $parent,
424+
"DenseI64ArrayAttr": $dims
425+
);
426+
427+
let extraClassDeclaration = [{
428+
429+
int64_t getRank() const {
430+
SliceAttr attr = flatten();
431+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
432+
return parent.getRank() - attr.getDims().size();
433+
}
434+
435+
DenseI32ArrayAttr getOrder() const {
436+
SliceAttr attr = flatten();
437+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
438+
return parent.getOrder();
439+
}
440+
441+
bool isWgLayout() const {
442+
SliceAttr attr = flatten();
443+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
444+
return parent.isWgLayout();
445+
}
446+
447+
bool isSgLayout() const {
448+
SliceAttr attr = flatten();
449+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
450+
return parent.isSgLayout();
451+
}
452+
453+
/// Returns the SgLayout of the attribute, computed by applying
454+
/// the slice dimensions to the underlying LayoutAttr.
455+
std::optional<SmallVector<int64_t>> getSgLayoutAsInt() const {
456+
SliceAttr attr = flatten();
457+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
458+
if (auto layout = parent.getSgLayoutAsInt()) {
459+
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
460+
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*layout), dims);
461+
}
462+
return std::nullopt;
463+
}
464+
465+
/// Returns the SgData of the attribute, computed by applying
466+
/// the slice dimensions to the underlying LayoutAttr.
467+
std::optional<SmallVector<int64_t>> getSgDataAsInt() const {
468+
SliceAttr attr = flatten();
469+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
470+
if (auto data = parent.getSgDataAsInt()) {
471+
ArrayRef<int64_t> dims = attr.getDims().asArrayRef();
472+
return XeGPUDialect::slice(llvm::ArrayRef<int64_t>(*data), dims);
473+
}
474+
return std::nullopt;
475+
}
476+
477+
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
478+
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
479+
/// it will coalese two slice operations and return a simplified SliceAttr
480+
/// #xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0, 1]>
481+
SliceAttr flatten() const;
482+
483+
/// Delinearizes a linear subgroup ID into its multidimensional indices
484+
/// based on the effective subgroup layout.
485+
FailureOr<SmallVector<Value>>
486+
delinearizeSubgroupId(OpBuilder &builder, Location loc, Value linearId);
487+
488+
/// Generates instructions to compute multidimensional offsets for blocks
489+
/// assigned to a subgroup identified by linearId. The shape parameter
490+
/// represents the workgroup-level problem size. Each subgroup may access
491+
/// multiple blocks according to round-robin distribution rules.
492+
FailureOr<SmallVector<SmallVector<Value>>>
493+
getOffsets(OpBuilder &builder, Location loc, Value linearId, ArrayRef<int64_t> shape);
494+
495+
}];
496+
497+
let assemblyFormat = "`<` qualified($parent) `,` `dims` `=` $dims `>`";
498+
let genVerifyDecl = 1;
499+
}
500+
339501
def XeGPU_RangeAttr : XeGPUAttr<"Range", "range"> {
340502
let summary = [{Specifies a half-open range}];
341503
let description = [{

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUDialect.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,18 @@ def XeGPU_Dialect : Dialect {
4141
/// Checks if the given shape can be evenly distributed based on the layout
4242
/// and data factors provided by the LayoutAttr.
4343
static bool isEvenlyDistributable(llvm::ArrayRef<int64_t> shape, xegpu::LayoutAttr attr);
44+
45+
/// drops/slices the shape in the specified dims, and return the rest. e.g.,
46+
/// for shape = [32, 64, 8], dims = [0, 2], it will return [64]
47+
template<typename T, typename U>
48+
static llvm::SmallVector<T> slice(llvm::ArrayRef<T> shape, llvm::ArrayRef<U> dims) {
49+
llvm::SmallVector<T> result;
50+
for (auto [i, v]: llvm::enumerate(shape)) {
51+
if (!llvm::is_contained(dims, i))
52+
result.push_back(v);
53+
}
54+
return result;
55+
}
4456
}];
4557
}
4658

mlir/lib/Dialect/XeGPU/IR/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@ add_mlir_dialect_library(MLIRXeGPUDialect
77

88
DEPENDS
99
MLIRXeGPUIncGen
10+
MLIRXeGPUAttrInterfaceIncGen
1011
MLIRXeGPUAttrsIncGen
1112
MLIRXeGPUEnumsIncGen
1213

1314
LINK_LIBS PUBLIC
1415
MLIRArithDialect
16+
MLIRIndexDialect
17+
MLIRAffineUtils
1518
MLIRArithUtils
1619
MLIRDialectUtils
1720
MLIRIR

0 commit comments

Comments
 (0)