Skip to content

Commit 68d6866

Browse files
authored
[mlir][XeGPU] add WgToSg distribution pattern for load_matrix and store_matrix. (#154403)
1 parent 32a5adb commit 68d6866

File tree

11 files changed

+430
-286
lines changed

11 files changed

+430
-286
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPU.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
#include "mlir/Bytecode/BytecodeOpInterface.h"
1313
#include "mlir/Dialect/Arith/IR/Arith.h"
14+
#include "mlir/Dialect/Utils/IndexingUtils.h"
1415
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1516
#include "mlir/IR/BuiltinTypes.h"
1617
#include "mlir/IR/Dialect.h"
@@ -23,6 +24,7 @@
2324
namespace mlir {
2425
namespace xegpu {
2526
class TensorDescType;
27+
class DistributeLayoutAttr;
2628
class LayoutAttr;
2729
class SliceAttr;
2830
} // namespace xegpu

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -175,22 +175,36 @@ def XeGPU_FenceScopeAttr:
175175
let assemblyFormat = "$value";
176176
}
177177

178-
def LayoutTrait: AttrInterface<"LayoutTrait"> {
178+
def DistributeLayoutAttr: AttrInterface<"DistributeLayoutAttr"> {
179179
let cppNamespace = "::mlir::xegpu";
180180
let description = [{
181181
Common trait for all XeGPU layouts.
182182
}];
183183

184184
let methods = [
185+
InterfaceMethod<"Check the availability of workgroup level layouts",
186+
"bool",
187+
"isForWorkgroup">,
185188
InterfaceMethod<"Get the rank of attribute",
186189
"int64_t",
187190
"getRank">,
191+
InterfaceMethod<"Get the num of effective subgroups",
192+
"int64_t",
193+
"getNumSubgroups", (ins), [{
194+
std::optional<SmallVector<int64_t>> sgLayout = llvm::cast<ConcreteAttr>(tablegen_opaque_val).getSgLayoutAsInt();
195+
if (sgLayout.has_value())
196+
return computeProduct(*sgLayout);
197+
return 0;
198+
}], [{}]>,
188199
InterfaceMethod<"Get the SgLayout field of the attribute as integer array",
189200
"std::optional<SmallVector<int64_t>>",
190201
"getSgLayoutAsInt">,
191202
InterfaceMethod<"Get the SgData field of the attribute as integer array",
192203
"std::optional<SmallVector<int64_t>>",
193204
"getSgDataAsInt">,
205+
InterfaceMethod<"Derive a new layout by dropping sgLayout and sgData",
206+
"xegpu::DistributeLayoutAttr",
207+
"dropSgLayoutAndData">,
194208
InterfaceMethod<[{Delinearizes a linear subgroup ID into its multidimensional
195209
indices based on the effective subgroup layout.}],
196210
"FailureOr<SmallVector<Value>>",
@@ -206,7 +220,7 @@ def LayoutTrait: AttrInterface<"LayoutTrait"> {
206220
];
207221
}
208222

209-
def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
223+
def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [DistributeLayoutAttr]> {
210224
let summary = [{
211225
Describes the data distribution to subgroups and work-items for a tensor
212226
specified by the tensor descriptor.
@@ -328,12 +342,12 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
328342
];
329343

330344
let extraClassDeclaration = [{
331-
bool isWgLayout() {
345+
bool isForWorkgroup() {
332346
return getSgLayout() != nullptr;
333347
}
334348

335-
bool isSgLayout() {
336-
return !isWgLayout();
349+
bool isForSubgroup() {
350+
return !isForWorkgroup();
337351
}
338352

339353
int64_t getRank() {
@@ -393,7 +407,7 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout", [LayoutTrait]> {
393407
}
394408

395409

396-
def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
410+
def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
397411
let summary = [{Describes the data distribution and sharing among subgroups or work-items.}];
398412

399413
let description = [{
@@ -420,7 +434,7 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
420434
}];
421435

422436
let parameters = (ins
423-
"xegpu::LayoutTrait": $parent,
437+
"xegpu::DistributeLayoutAttr": $parent,
424438
"DenseI64ArrayAttr": $dims
425439
);
426440

@@ -438,16 +452,16 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
438452
return parent.getOrder();
439453
}
440454

441-
bool isWgLayout() const {
455+
bool isForWorkgroup() const {
442456
SliceAttr attr = flatten();
443457
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
444-
return parent.isWgLayout();
458+
return parent.isForWorkgroup();
445459
}
446460

447-
bool isSgLayout() const {
461+
bool isForSubgroup() const {
448462
SliceAttr attr = flatten();
449463
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
450-
return parent.isSgLayout();
464+
return parent.isForSubgroup();
451465
}
452466

453467
/// Returns the SgLayout of the attribute, computed by applying
@@ -474,6 +488,20 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [LayoutTrait]> {
474488
return std::nullopt;
475489
}
476490

491+
SliceAttr dropSgLayoutAndData() {
492+
SliceAttr attr = flatten();
493+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
494+
parent = parent.dropSgLayoutAndData();
495+
return SliceAttr::get(getContext(), parent, attr.getDims());
496+
}
497+
498+
SliceAttr dropInstData() {
499+
SliceAttr attr = flatten();
500+
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
501+
parent = parent.dropInstData();
502+
return SliceAttr::get(getContext(), parent, attr.getDims());
503+
}
504+
477505
/// flatten a nested SliceAttr, e.g., for 2-level nested SliceAttr
478506
/// #xegpu.slice<#xegpu.slice<#xegpu.layout<sg_layout = [4, 8, 12]>, dims = [0]>, dims = [0]>
479507
/// it will coalese two slice operations and return a simplified SliceAttr

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 74 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,14 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
232232
return static_cast<unsigned>(MemorySpace::Global);
233233
}
234234

235+
xegpu::DistributeLayoutAttr getLayoutAttr() {
236+
return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getType().getLayout());
237+
}
238+
239+
ArrayRef<int64_t> getDataShape() {
240+
return getTensorDescShape();
241+
}
242+
235243
}];
236244
}
237245

@@ -262,6 +270,23 @@ def XeGPU_PrefetchNdOp : XeGPU_Op<"prefetch_nd", []> {
262270
xegpu::TensorDescType getTensorDescType() {
263271
return getTensorDesc().getType();
264272
}
273+
274+
SmallVector<OpFoldResult> getMixedOffsets() {
275+
auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
276+
auto dynamics = getOffsets();
277+
if (statics.size() == 0 && dynamics.size() == 0)
278+
return {};
279+
return getMixedValues(statics, dynamics, getContext());
280+
}
281+
282+
xegpu::DistributeLayoutAttr getLayoutAttr() {
283+
return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
284+
}
285+
286+
ArrayRef<int64_t> getDataShape() {
287+
return getTensorDescType().getShape();
288+
}
289+
265290
}];
266291

267292
let assemblyFormat = [{
@@ -343,6 +368,24 @@ def XeGPU_LoadNdOp : XeGPU_Op<"load_nd", [
343368
xegpu::TensorDescType getTensorDescType() {
344369
return getTensorDesc().getType();
345370
}
371+
372+
SmallVector<OpFoldResult> getMixedOffsets() {
373+
auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
374+
auto dynamics = getOffsets();
375+
if (statics.size() == 0 && dynamics.size() == 0)
376+
return {};
377+
return getMixedValues(statics, dynamics, getContext());
378+
}
379+
380+
xegpu::DistributeLayoutAttr getLayoutAttr() {
381+
return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
382+
}
383+
384+
ArrayRef<int64_t> getDataShape() {
385+
return getTensorDescType().getShape();
386+
}
387+
388+
346389
}];
347390

348391
let assemblyFormat = [{
@@ -417,6 +460,23 @@ def XeGPU_StoreNdOp : XeGPU_Op<"store_nd", [
417460
xegpu::TensorDescType getTensorDescType() {
418461
return getTensorDesc().getType();
419462
}
463+
464+
SmallVector<OpFoldResult> getMixedOffsets() {
465+
auto statics = getConstOffsets().value_or(SmallVector<int64_t>());
466+
auto dynamics = getOffsets();
467+
if (statics.size() == 0 && dynamics.size() == 0)
468+
return {};
469+
return getMixedValues(statics, dynamics, getContext());
470+
}
471+
472+
xegpu::DistributeLayoutAttr getLayoutAttr() {
473+
return dyn_cast_if_present<xegpu::DistributeLayoutAttr>(getTensorDescType().getLayout());
474+
}
475+
476+
ArrayRef<int64_t> getDataShape() {
477+
return getTensorDescType().getShape();
478+
}
479+
420480
}];
421481

422482
let assemblyFormat = [{
@@ -640,6 +700,7 @@ def XeGPU_PrefetchOp : XeGPU_Op<"prefetch", []> {
640700
xegpu::TensorDescType getTensorDescType() {
641701
return dyn_cast<xegpu::TensorDescType>(getSourceType());
642702
}
703+
643704
}];
644705

645706
let assemblyFormat = [{
@@ -1150,7 +1211,7 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
11501211
let arguments = (ins XeGPU_MemDesc:$mem_desc,
11511212
Variadic<Index>: $offsets,
11521213
DenseI64ArrayAttr: $const_offsets,
1153-
OptionalAttr<LayoutTrait>:$layout
1214+
OptionalAttr<DistributeLayoutAttr>:$layout
11541215
);
11551216
let results = (outs XeGPU_ValueType:$res);
11561217
let assemblyFormat = [{
@@ -1175,12 +1236,16 @@ def XeGPU_LoadMatrixOp: XeGPU_Op<"load_matrix", [MemoryEffects<[MemRead]>,
11751236

11761237
let builders = [
11771238
OpBuilder<(ins "Type":$res, "TypedValue<MemDescType>": $mem_desc,
1178-
"llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
1239+
"llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr": $layout)>,
11791240
];
11801241
let extraClassDeclaration = [{
11811242
SmallVector<OpFoldResult> getMixedOffsets() {
11821243
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
11831244
}
1245+
1246+
ArrayRef<int64_t> getDataShape() {
1247+
return getRes().getType().getShape();
1248+
}
11841249
}];
11851250

11861251
let hasVerifier = 1;
@@ -1194,7 +1259,7 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
11941259
XeGPU_MemDesc:$mem_desc,
11951260
Variadic<Index>: $offsets,
11961261
DenseI64ArrayAttr: $const_offsets,
1197-
OptionalAttr<LayoutTrait>:$layout
1262+
OptionalAttr<DistributeLayoutAttr>:$layout
11981263
);
11991264
let assemblyFormat = [{ $data `,` $mem_desc `` custom<DynamicIndexList>($offsets, $const_offsets)
12001265
prop-dict attr-dict `` `:` type(operands)}];
@@ -1213,12 +1278,17 @@ def XeGPU_StoreMatrixOp: XeGPU_Op<"store_matrix", [MemoryEffects<[MemWrite]>,
12131278
}];
12141279
let builders = [
12151280
OpBuilder<(ins "Value" : $data, "TypedValue<MemDescType>": $mem_desc,
1216-
"llvm::ArrayRef<OpFoldResult>": $offsets, "LayoutTrait": $layout)>,
1281+
"llvm::ArrayRef<OpFoldResult>": $offsets, "DistributeLayoutAttr": $layout)>,
12171282
];
12181283
let extraClassDeclaration = [{
12191284
SmallVector<OpFoldResult> getMixedOffsets() {
12201285
return getMixedValues(getConstOffsets(), getOffsets(), getContext());
12211286
}
1287+
1288+
ArrayRef<int64_t> getDataShape() {
1289+
return getData().getType().getShape();
1290+
}
1291+
12221292
}];
12231293

12241294
let hasVerifier = 1;

mlir/include/mlir/Dialect/XeGPU/Utils/XeGPUUtils.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#define MLIR_DIALECT_XEGPU_UTILS_XEGPUUTILS_H_
1111

1212
#include "mlir/IR/BuiltinTypes.h"
13+
#include "mlir/IR/OpDefinition.h"
1314
namespace mlir {
1415

1516
class VectorType;
@@ -128,6 +129,20 @@ void doSCFStructuralTypeConversionWithTensorType(Operation *op,
128129
/// if no GPU module parent or XeVM target attribute exists.
129130
std::optional<std::string> getChipStr(Operation *op);
130131

132+
/// Generates element-wise addition ops of two arrays with automatic alignment.
133+
/// When the input arrays have different sizes, the shorter array is
134+
/// right-aligned with the longer array, and the unmatched leading elements from
135+
/// the longer array are preserved unchanged. This is commonly used for offset
136+
/// computation where higher-dimensional offsets need to be added to
137+
/// lower-dimensional adjustments.
138+
///
139+
/// Example:
140+
/// lhs = [l1, l2, l3], rhs = [r1, r2]
141+
/// Result: [11, l2+r1, l3+r2]
142+
SmallVector<OpFoldResult> addWithRightAligned(OpBuilder &builder, Location loc,
143+
ArrayRef<OpFoldResult> lhs,
144+
ArrayRef<OpFoldResult> rhs);
145+
131146
} // namespace xegpu
132147

133148
} // namespace mlir

mlir/lib/Dialect/XeGPU/IR/XeGPUDialect.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
271271
Value linearId) {
272272
// delinearizeSubgroupId is only available for
273273
// workgroup-level layout attribute
274-
if (!isWgLayout())
274+
if (!isForWorkgroup())
275275
return failure();
276276

277277
// TODO: handle order attribute
@@ -290,12 +290,13 @@ LayoutAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
290290
return affine::delinearizeIndex(builder, loc, linearId, dims);
291291
}
292292

293-
/// Implements LayoutTrait::getOffsets to generate instructions for
294-
/// computing multi-dimensional offsets when distributed by LayoutAttr.
293+
/// Implements DistributeLayoutAttr::getOffsets to generate
294+
/// instructions for computing multi-dimensional offsets when distributed by
295+
/// LayoutAttr.
295296
FailureOr<SmallVector<SmallVector<Value>>>
296297
LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
297298
ArrayRef<int64_t> shape) {
298-
if (!isWgLayout())
299+
if (!isForWorkgroup())
299300
return failure();
300301

301302
SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();
@@ -322,7 +323,7 @@ LayoutAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
322323
//===----------------------------------------------------------------------===//
323324
LogicalResult
324325
SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
325-
xegpu::LayoutTrait parent, DenseI64ArrayAttr dims) {
326+
xegpu::DistributeLayoutAttr parent, DenseI64ArrayAttr dims) {
326327
if (!parent || !dims)
327328
return emitError() << "expected parent layout and dims attribute";
328329

@@ -340,7 +341,7 @@ SliceAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
340341
}
341342

342343
SliceAttr SliceAttr::flatten() const {
343-
xegpu::LayoutTrait parent = getParent();
344+
xegpu::DistributeLayoutAttr parent = getParent();
344345
SmallVector<DenseI64ArrayAttr> slicedDims({getDims()});
345346

346347
while (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(parent)) {
@@ -375,13 +376,14 @@ SliceAttr::delinearizeSubgroupId(OpBuilder &builder, Location loc,
375376
return parent.delinearizeSubgroupId(builder, loc, linearId);
376377
}
377378

378-
/// Implements LayoutTrait::getOffsets to generate instructions for
379-
/// computing multi-dimensional offsets when distributed by SliceAttr.
379+
/// Implements DistributeLayoutAttr::getOffsets to generate
380+
/// instructions for computing multi-dimensional offsets when distributed by
381+
/// SliceAttr.
380382
FailureOr<SmallVector<SmallVector<Value>>>
381383
SliceAttr::getOffsets(OpBuilder &builder, Location loc, Value linearId,
382384
ArrayRef<int64_t> shape) {
383385
assert(getRank() == static_cast<int64_t>(shape.size()) && "invalid shape.");
384-
if (!isWgLayout())
386+
if (!isForWorkgroup())
385387
return failure();
386388

387389
SmallVector<int64_t> sgLayout = getSgLayoutAsInt().value();

0 commit comments

Comments
 (0)