Skip to content

Commit cbcfd61

Browse files
committed
address comments
1 parent 1db0ca5 commit cbcfd61

File tree

2 files changed

+37
-29
lines changed

2 files changed

+37
-29
lines changed

mlir/include/mlir/Dialect/XeGPU/Transforms/Passes.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ def XeGPUSubgroupDistribute : Pass<"xegpu-subgroup-distribute"> {
3030
}
3131

3232
def XeGPULayoutPropagate : Pass<"xegpu-layout-propagate"> {
33-
let summary = "Propagate XeGPU layout information";
33+
let summary = "Propagate and assign XeGPU layout information";
3434
let description = [{
3535
This pass propagates the XeGPU layout information across ops. Starting
3636
from a set of anchor operations (e.g. `dpas`, `store_nd`), this will
37-
propagate the layouts required for operands and results to the producers or
38-
consumers.
37+
propagate the layouts required for their operands to the producers. With
38+
this propagated layout information, the pass will then update the XeGPU tensor
39+
descriptor type with the layout information.
3940
}];
4041
let dependentDialects = ["memref::MemRefDialect", "xegpu::XeGPUDialect",
4142
"vector::VectorDialect"];

mlir/lib/Dialect/XeGPU/Transforms/XeGPULayoutPropagate.cpp

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
#include "llvm/ADT/STLExtras.h"
3131
#include "llvm/ADT/SmallVector.h"
3232
#include "llvm/ADT/TypeSwitch.h"
33+
#include "llvm/Support/Casting.h"
3334
#include "llvm/Support/Debug.h"
3435
#include "llvm/Support/InterleavedRange.h"
3536
#include "llvm/Support/raw_ostream.h"
@@ -103,6 +104,7 @@ struct LayoutInfo {
103104
private:
104105
LaneLayout laneLayout;
105106
LaneData laneData;
107+
xegpu::LayoutAttr layoutAttr;
106108

107109
public:
108110
LayoutInfo() = default;
@@ -186,7 +188,7 @@ struct LayoutInfoLattice : public Lattice<LayoutInfo> {
186188
/// Helper Function to get the default layout for uniform values like constants.
187189
/// For 1D vector, lane_layout is [subgroupSize] and lane_data is [1].
188190
/// For 2D vector, lane_layout is [1, subgroupSize] and lane_data is [1, 1].
189-
static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
191+
static LayoutInfo getDefaultSIMTLayoutInfo(unsigned rank) {
190192
assert((rank == 1 || rank == 2) && "Expected 1D or 2D vector.");
191193
if (rank == 1)
192194
return LayoutInfo(LaneLayout({xegpu::targetinfo::subgroupSize}),
@@ -196,7 +198,7 @@ static LayoutInfo getDefaultLayoutInfo(unsigned rank) {
196198
}
197199

198200
/// Helper to get the default layout for a vector type.
199-
static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
201+
static LayoutInfo getDefaultSIMTLayoutInfo(VectorType vectorTy) {
200202
// Expecting a 1D or 2D vector.
201203
assert((vectorTy.getRank() == 1 || vectorTy.getRank() == 2) &&
202204
"Expected 1D or 2D vector.");
@@ -205,7 +207,7 @@ static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
205207
"Expected int or float element type.");
206208
// If the rank is 1, then return default layout for 1D vector.
207209
if (vectorTy.getRank() == 1)
208-
return getDefaultLayoutInfo(1);
210+
return getDefaultSIMTLayoutInfo(1);
209211
// Packing factor is determined by the element type bitwidth.
210212
int packingFactor = 1;
211213
unsigned bitwidth = vectorTy.getElementType().getIntOrFloatBitWidth();
@@ -221,8 +223,8 @@ static LayoutInfo getDefaultLayoutInfo(VectorType vectorTy) {
221223
/// `packedSizeInBitsForDefault`
222224
/// * For B operand, the data must be packed in minimum
223225
/// `packedSizeInBitsForDpasB`
224-
static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
225-
unsigned operandNum) {
226+
static LayoutInfo getSIMTLayoutInfoForDPASOperand(VectorType vectorTy,
227+
unsigned operandNum) {
226228
Type elementTy = vectorTy.getElementType();
227229
assert(elementTy.isIntOrFloat() &&
228230
"Expected int or float type in DPAS operands");
@@ -237,7 +239,7 @@ static LayoutInfo getLayoutInfoForDPASOperand(VectorType vectorTy,
237239
return LayoutInfo(layout, data);
238240
}
239241
// Otherwise, return the default layout for the vector type.
240-
return getDefaultLayoutInfo(vectorTy);
242+
return getDefaultSIMTLayoutInfo(vectorTy);
241243
}
242244

243245
//===----------------------------------------------------------------------===//
@@ -360,17 +362,18 @@ LogicalResult LayoutInfoPropagation::visitOperation(
360362
// All other ops.
361363
.Default([&](Operation *op) {
362364
for (const LayoutInfoLattice *r : results) {
363-
for (LayoutInfoLattice *operand : operands) {
364-
// Propagate the layout of the result to the operand.
365-
if (r->getValue().isAssigned())
365+
if (r->getValue().isAssigned()) {
366+
for (LayoutInfoLattice *operand : operands) {
367+
// Propagate the layout of the result to the operand.
366368
meet(operand, *r);
369+
}
367370
}
368371
}
369372
});
370373
// Add a dependency from each result to program point after the operation.
371-
for (const LayoutInfoLattice *r : results) {
374+
for (const LayoutInfoLattice *r : results)
372375
addDependency(const_cast<LayoutInfoLattice *>(r), getProgramPointAfter(op));
373-
}
376+
374377
return success();
375378
}
376379

@@ -380,7 +383,7 @@ void LayoutInfoPropagation::visitPrefetchNdOp(
380383
// Here we assign the default layout to the tensor descriptor operand of
381384
// prefetch.
382385
auto tdescTy = prefetch.getTensorDescType();
383-
auto prefetchLayout = getDefaultLayoutInfo(
386+
auto prefetchLayout = getDefaultSIMTLayoutInfo(
384387
VectorType::get(tdescTy.getShape(), tdescTy.getElementType()));
385388
// Propagate the layout to the source tensor descriptor.
386389
propagateIfChanged(operands[0], operands[0]->meet(prefetchLayout));
@@ -395,11 +398,13 @@ void LayoutInfoPropagation::visitVectorMultiReductionOp(
395398
if (!resultLayout.isAssigned())
396399
return;
397400
// We only consider 2D -> 1D reductions at this point.
398-
assert(resultLayout.getLayout().size() == 1 &&
399-
"Expected 1D layout for reduction result.");
401+
if (resultLayout.getLayout().size() != 1) {
402+
reduction.emitWarning("Expected 1D layout for reduction result. ");
403+
return;
404+
}
400405
// Given that the result is 1D, the layout of the operand should be 2D with
401406
// default layout.
402-
LayoutInfo operandLayout = getDefaultLayoutInfo(2);
407+
LayoutInfo operandLayout = getDefaultSIMTLayoutInfo(2);
403408
propagateIfChanged(operands[0], operands[0]->meet(operandLayout));
404409
// Accumulator should have the same layout as the result.
405410
propagateIfChanged(operands[1], operands[1]->meet(resultLayout));
@@ -425,22 +430,23 @@ void LayoutInfoPropagation::visitDpasOp(
425430
ArrayRef<const LayoutInfoLattice *> results) {
426431
VectorType aTy = dpas.getLhsType();
427432
VectorType bTy = dpas.getRhsType();
428-
propagateIfChanged(operands[0],
429-
operands[0]->meet(getLayoutInfoForDPASOperand(aTy, 0)));
430-
propagateIfChanged(operands[1],
431-
operands[1]->meet(getLayoutInfoForDPASOperand(bTy, 1)));
433+
propagateIfChanged(
434+
operands[0], operands[0]->meet(getSIMTLayoutInfoForDPASOperand(aTy, 0)));
435+
propagateIfChanged(
436+
operands[1], operands[1]->meet(getSIMTLayoutInfoForDPASOperand(bTy, 1)));
432437
if (operands.size() > 2) {
433438
VectorType cTy = dpas.getAccType();
434-
propagateIfChanged(operands[2],
435-
operands[2]->meet(getLayoutInfoForDPASOperand(cTy, 2)));
439+
propagateIfChanged(
440+
operands[2],
441+
operands[2]->meet(getSIMTLayoutInfoForDPASOperand(cTy, 2)));
436442
}
437443
}
438444

439445
/// Set the layout for the value and tensor descriptor operands in StoreNdOp.
440446
void LayoutInfoPropagation::visitStoreNdOp(
441447
xegpu::StoreNdOp store, ArrayRef<LayoutInfoLattice *> operands,
442448
ArrayRef<const LayoutInfoLattice *> results) {
443-
LayoutInfo storeLayout = getDefaultLayoutInfo(store.getValueType());
449+
LayoutInfo storeLayout = getDefaultSIMTLayoutInfo(store.getValueType());
444450
// Both operands should have the same layout
445451
for (LayoutInfoLattice *operand : operands) {
446452
propagateIfChanged(operand, operand->meet(storeLayout));
@@ -539,7 +545,7 @@ void LayoutInfoPropagation::visitLoadGatherOp(
539545
tensorDescLayout = valueLayout.getTransposedLayout({1, 0});
540546
}
541547
// Mask operand should have 1D default layout.
542-
LayoutInfo maskLayout = getDefaultLayoutInfo(1);
548+
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
543549
// Propagate the new layout to the tensor descriptor operand.
544550
propagateIfChanged(operands[0], operands[0]->meet(tensorDescLayout));
545551
// Propagate the new layout to the mask operand.
@@ -556,7 +562,7 @@ void LayoutInfoPropagation::visitCreateDescOp(
556562
if (!descLayout.isAssigned())
557563
return;
558564
// For offset operand propagate 1D default layout.
559-
LayoutInfo layout = getDefaultLayoutInfo(1);
565+
LayoutInfo layout = getDefaultSIMTLayoutInfo(1);
560566
propagateIfChanged(operands[1], operands[1]->meet(layout));
561567
}
562568

@@ -575,7 +581,8 @@ void LayoutInfoPropagation::visitStoreScatterOp(
575581
"Expected the first dimension of 2D tensor descriptor to be equal to "
576582
"subgroup size.");
577583

578-
LayoutInfo valueLayout = getDefaultLayoutInfo(storeScatter.getValueType());
584+
LayoutInfo valueLayout =
585+
getDefaultSIMTLayoutInfo(storeScatter.getValueType());
579586
LayoutInfo storeScatterLayout = valueLayout;
580587
if (storeScatter.getTranspose()) {
581588
// StoreScatterOp allows transpose effect. However, at the stage of this
@@ -590,7 +597,7 @@ void LayoutInfoPropagation::visitStoreScatterOp(
590597
// Propagate the tensor descriptor layout.
591598
propagateIfChanged(operands[1], operands[1]->meet(storeScatterLayout));
592599
// Use default 1D layout for mask operand.
593-
LayoutInfo maskLayout = getDefaultLayoutInfo(1);
600+
LayoutInfo maskLayout = getDefaultSIMTLayoutInfo(1);
594601
propagateIfChanged(operands[2], operands[2]->meet(maskLayout));
595602
}
596603

0 commit comments

Comments
 (0)