Skip to content

Commit c7b3773

Browse files
authored
Move tile size computation for subgroup 2d block io encoding (#4461)
Moves the tile size computation from LLVM lowering to a static method on the Subgroup 2D block encoding layout. This allows us to create the layout with the desired tile sizes at a higher level in the pass hierarchy. Moving this functionality now allows us to test using the existing layouts and lowering, ensuring no regressions. There is some cleanup that could be done but I opted for generic objects for now (e.g. `SmallVector`) for flexibility.
1 parent ff637a1 commit c7b3773

File tree

3 files changed

+105
-17
lines changed

3 files changed

+105
-17
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding",
317317

318318
let extraClassDeclaration = extraDistributedDeclaration # [{
319319
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
320+
static SmallVector<unsigned, 3> getInstrShapeForLayout(DistributedEncodingTrait layout, ArrayRef<int64_t> shape, bool memoryRowMajor, unsigned kWidth, MLIRContext* context);
320321
}];
321322

322323
let hasCustomAssemblyFormat = 1;

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,92 @@ Subgroup2DBlockEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
662662
return subgroup2DBlockToLinearLayout(shape, *this, getKWidth());
663663
}
664664

665+
/// Computes the 2D block-IO instruction shape for \p layout as
/// {tileHeight, tileWidth, numBlocks} (v_blocks), sized to the DPAS
/// instruction shape of the operand the layout describes.
///
/// \param layout          a DotOperandEncodingAttr (operand A/B) or a
///                        DpasEncodingAttr directly (treated as operand C).
/// \param tensorShape     shape of the tensor being loaded/stored.
/// \param memoryRowMajor  whether the in-memory layout is row major.
/// \param kWidth          element width in BYTES here (callers pass
///                        elemSizeInBits / 8) — note the name suggests the
///                        dot-layout kWidth; TODO confirm intended units.
/// \param context         MLIRContext used to build the linear-encoding attr.
SmallVector<unsigned, 3> Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
    DistributedEncodingTrait layout, ArrayRef<int64_t> tensorShape,
    bool memoryRowMajor, unsigned kWidth, MLIRContext *context) {
  const auto rank = tensorShape.size();

  // Derive the thread order from the linear-layout form of the encoding; the
  // last two dims decide whether register values are row or column major.
  std::optional<LinearLayout> llEncoding = layout.toLinearLayout(tensorShape);
  assert(llEncoding.has_value() && "invalid dot layout to linear layout");
  LinearEncodingAttr llAttr = LinearEncodingAttr::get(context, *llEncoding);
  SmallVector<unsigned> threadOrder = llAttr.getThreadOrder();

  const bool valueRowMajor =
      (threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0);
  assert((valueRowMajor ||
          (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) &&
         "Only row_major or column_major is allowed");
  // A transposed load is needed when the value order and the memory order
  // disagree.
  const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;

  // opIdx 0/1 come from a dot-operand encoding; a bare DPAS layout is the
  // accumulator/result, encoded here as opIdx == 2.
  auto dotEncodingAttr = dyn_cast<DotOperandEncodingAttr>(layout);
  const unsigned opIdx = dotEncodingAttr ? dotEncodingAttr.getOpIdx() : 2;

  // TODO: can this be moved into the DpasEncodingAttr layout?
  auto getDPASInstShape = [](const auto dpasLayout, const unsigned opIdx) {
    switch (opIdx) {
    case 0:
      return dpasLayout.getDPASInstShapeA();
    case 1:
      return dpasLayout.getDPASInstShapeB();
    case 2:
      return dpasLayout.getDPASInstShapeC();
    default:
      llvm_unreachable("invalid opidx");
    }
  };

  // For A/B the DPAS layout is the dot operand's parent; otherwise the layout
  // itself must already be a DPAS encoding.
  DpasEncodingAttr dpasLayout =
      dotEncodingAttr ? cast<DpasEncodingAttr>(dotEncodingAttr.getParent())
                      : cast<DpasEncodingAttr>(layout);
  assert(dpasLayout && "only dpas layout is supported");

  const SmallVector<unsigned> dpasInstShape =
      getDPASInstShape(dpasLayout, opIdx);
  const SmallVector<unsigned> elemsPerDPASInst = {dpasInstShape[0],
                                                  dpasInstShape[1]};
  // Map the instruction shape through the thread order: width follows the
  // fastest-varying dim, height the other.
  unsigned tileWidth = elemsPerDPASInst[threadOrder[rank - 2]];
  unsigned tileHeight = elemsPerDPASInst[threadOrder[rank - 1]];

  // The C operand is loaded/stored one DPAS tile at a time (single block).
  if (opIdx == 2) {
    return {tileHeight, tileWidth, 1};
  }

  // For the A and B matrices, enlarge the tile size to support multiple DPAS
  // operands
  ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
  SmallVector<int64_t> numReps =
      dpasLayout.getDPASRepetitions(tensorShape, opIdx);

  const bool isOperandA = opIdx == 0;
  // Outer dim of the operand: rows (rank-2) for A, columns (rank-1) for B.
  const unsigned dimOuter = bool(opIdx) ? rank - 1 : rank - 2;
  // numReps index: 2 (inner reps) for A (opIdx == 0), 1 for B (opIdx == 1).
  // NOTE(review): dpasOperandsPerTileX is computed (and reassigned below) but
  // never consumed before returning — presumably reserved for a follow-up
  // change; confirm whether it should feed the returned shape.
  unsigned dpasOperandsPerTileX =
      isOperandA ? repCluster[dimOuter] : numReps[unsigned(opIdx) ? 1 : 2];
  unsigned dpasOperandsPerTileY =
      isOperandA ? numReps[unsigned(opIdx) ? 1 : 2] : repCluster[dimOuter];

  if (isTransposeRequired) {
    // Transposed load: the HW tile's width/height swap roles relative to the
    // value layout.
    std::swap(tileWidth, tileHeight);

    const unsigned threadsPerWarp = dpasLayout.getThreadsPerWarp();
    // Only pack multiple operands per load when a full column fits the warp;
    // otherwise load one operand per instruction.
    dpasOperandsPerTileX =
        (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;

    // Transposed 2D loads operate on 32-bit elements, so the element-count
    // width shrinks by the packing factor 32 / (kWidth * 8).
    // TODO(review): document the exact HW limitation this encodes (original
    // comment read "what are those...?").
    tileWidth = tileWidth / (32 / (kWidth * 8));

    // Transposed loads do not use multiple blocks per instruction.
    dpasOperandsPerTileY = 1;
  }

  // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
  // by enlarging the number of blocks.
  const unsigned totalBytesPerRowPerDPASOp = tileWidth * kWidth;
  dpasOperandsPerTileY =
      std::min(dpasOperandsPerTileY, 64 / totalBytesPerRowPerDPASOp);
  const unsigned numBlocks = dpasOperandsPerTileY;

  return {tileHeight, tileWidth, numBlocks};
}
750+
665751
//===----------------------------------------------------------------------===//
666752
// Dialect Interface
667753
//===----------------------------------------------------------------------===//

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,6 +1439,14 @@ struct LoadOpConversion
14391439

14401440
Type eltTy = tensorType.getElementType();
14411441
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
1442+
1443+
auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
1444+
cast<DistributedEncodingTrait>(encoding), tensorType.getShape(),
1445+
memoryRowMajor, elemSizeInBits / 8, rewriter.getContext());
1446+
unsigned tileHeight = tileParams[0];
1447+
const unsigned tileWidth = tileParams[1];
1448+
const unsigned vBlocks = tileParams[2];
1449+
14421450
DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);
14431451
const ArrayRef<int64_t> tensorShape = tensorType.getShape();
14441452
unsigned numElems = getTotalElemsPerThread(resultType);
@@ -1476,8 +1484,7 @@ struct LoadOpConversion
14761484

14771485
Value elemSizeInBytes = b.i32_val(elemSizeInBits / 8);
14781486

1479-
SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
1480-
int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
1487+
const unsigned elemsPerLane = tileWidth * tileHeight / threadsPerWarp;
14811488
Type load2DGenXType =
14821489
LLVM::getVectorType(IntegerType::get(ctx, elemSizeInBits),
14831490
elemsPerLane); // make it opaque type.
@@ -1527,12 +1534,12 @@ struct LoadOpConversion
15271534
for (int repM = 0; repM < repCluster[0]; ++repM) {
15281535

15291536
Value offsetY =
1530-
b.add(warpId0Offset, b.i32_val(m * replicaStride[0] +
1531-
repM * elemsPerInstr[0]));
1537+
b.add(warpId0Offset,
1538+
b.i32_val(m * replicaStride[0] + repM * tileHeight));
15321539
for (int repN = 0; repN < repCluster[1]; ++repN) {
15331540
Value offsetX =
1534-
b.add(warpId1Offset, b.i32_val(n * replicaStride[1] +
1535-
repN * elemsPerInstr[1]));
1541+
b.add(warpId1Offset,
1542+
b.i32_val(n * replicaStride[1] + repN * tileWidth));
15361543

15371544
auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
15381545
loc, load2DGenXType,
@@ -1543,9 +1550,9 @@ struct LoadOpConversion
15431550
/*x*/ b.trunc(i32_ty, offsetX),
15441551
/*y*/ b.trunc(i32_ty, offsetY),
15451552
/*elem_size_in_bits*/ elemSizeInBits,
1546-
/*tile_width*/ elemsPerInstr[1],
1547-
/*tile_height*/ elemsPerInstr[0],
1548-
/*v_blocks*/ 1,
1553+
/*tile_width*/ tileWidth,
1554+
/*tile_height*/ tileHeight,
1555+
/*v_blocks*/ vBlocks,
15491556
/*transpose*/ false,
15501557
/*vnni_transform*/ false);
15511558
if (failed(load2dOp.verify())) {
@@ -1659,9 +1666,6 @@ struct LoadOpConversion
16591666
offsetBaseY] =
16601667
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
16611668

1662-
unsigned tileWidth = elemsPerDPASInst[threadOrder[rank - 2]];
1663-
unsigned tileHeight = elemsPerDPASInst[threadOrder[rank - 1]];
1664-
16651669
MLIRContext *ctx = rewriter.getContext();
16661670
const StringAttr dimOuterStr = S("dim" + std::to_string(dimOuter));
16671671
const StringAttr dimInnerStr = S("dim" + std::to_string(dimInner));
@@ -1739,7 +1743,6 @@ struct LoadOpConversion
17391743
llvm::dbgs() << "tile layout done\n";
17401744
});
17411745

1742-
unsigned vBlocks = 1;
17431746
unsigned numOperandsOuterDimPerLoad = 1;
17441747
unsigned numOperandsInnerDimPerLoad = 1;
17451748

@@ -1756,11 +1759,10 @@ struct LoadOpConversion
17561759
if (!usePackedType)
17571760
return failure();
17581761

1759-
std::swap(tileHeight, tileWidth);
1760-
17611762
if (oneMatrixPerLoadForBT) {
17621763
// Only load 1 operand per inst on row.
17631764
numOperandsPer2DLoadM = 1;
1765+
tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
17641766
} else {
17651767
// We can decompose the matrix returned by transposed large 2d load
17661768
// when threads per warp < column size. Otherwise we have to load one
@@ -1775,6 +1777,7 @@ struct LoadOpConversion
17751777
numOperandsPer2DloadN = 1;
17761778
}
17771779

1780+
// TODO: move this logic to the instr shape computation
17781781
// PVC 2D load supports 32 rows at most. Load multiple dot operands by
17791782
// enlarging the tileHeight.
17801783
numOperandsPer2DLoadM = std::min(numOperandsPer2DLoadM, 32 / tileHeight);
@@ -1785,7 +1788,6 @@ struct LoadOpConversion
17851788
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
17861789
numOperandsPer2DloadN =
17871790
std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
1788-
vBlocks = numOperandsPer2DloadN;
17891791

17901792
numOperandsOuterDimPerLoad =
17911793
isOperandA ? numOperandsPer2DLoadM : numOperandsPer2DloadN;
@@ -1960,7 +1962,6 @@ struct LoadOpConversion
19601962
if (isTransposeRequired) {
19611963
// adjust the block io parameter to align HW's limitations on
19621964
// transposing load.
1963-
tileWidth = tileWidth / (32 / originalElemBits);
19641965
elemSizeInBits = 32;
19651966
}
19661967
Value elemSizeInBytes = b.i32_val(originalElemBits / 8);

0 commit comments

Comments
 (0)