Skip to content

Commit 79f8442

Browse files
committed
Use subgroup 2d layout to compute tile sizes 1/?
1 parent 2f338af commit 79f8442

File tree

3 files changed

+106
-19
lines changed

3 files changed

+106
-19
lines changed

third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,7 @@ def Subgroup2DBlockEncodingAttr : DistributedEncoding<"Subgroup2DBlockEncoding",
317317

318318
let extraClassDeclaration = extraDistributedDeclaration # [{
319319
SmallVector<unsigned> getRepOrderForOperand(int opIdx) const;
320+
static SmallVector<unsigned, 3> getInstrShapeForLayout(DistributedEncodingTrait layout, ArrayRef<int64_t> shape, bool memoryRowMajor, unsigned kWidth, MLIRContext* context);
320321
}];
321322

322323
let hasCustomAssemblyFormat = 1;

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,92 @@ Subgroup2DBlockEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {
662662
return subgroup2DBlockToLinearLayout(shape, *this, getKWidth());
663663
}
664664

665+
/// Compute the 2D-block-load instruction shape {tileHeight, tileWidth,
/// vBlocks} for a dot operand (or accumulator) with the given layout.
/// \p kWidth is the element size in bytes; \p memoryRowMajor describes the
/// in-memory order of the loaded tensor.
SmallVector<unsigned, 3> Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
    DistributedEncodingTrait layout, ArrayRef<int64_t> tensorShape,
    bool memoryRowMajor, unsigned kWidth, MLIRContext *context) {
  const size_t rank = tensorShape.size();

  // Derive the thread order from the linear-layout form of the encoding.
  std::optional<LinearLayout> linearForm = layout.toLinearLayout(tensorShape);
  assert(linearForm.has_value() && "invalid dot layout to linear layout");
  SmallVector<unsigned> threadOrder =
      LinearEncodingAttr::get(context, *linearForm).getThreadOrder();

  const bool valueRowMajor =
      threadOrder[rank - 2] == 1 && threadOrder[rank - 1] == 0;
  assert((valueRowMajor ||
          (threadOrder[rank - 2] == 0 && threadOrder[rank - 1] == 1)) &&
         "Only row_major or column_major is allowed");
  // A transposing load is needed whenever the value order and the memory
  // order disagree.
  const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;

  // opIdx 2 denotes the accumulator (C) operand: in that case `layout` is
  // the DPAS layout itself rather than a dot-operand wrapper.
  auto dotEncoding = dyn_cast<DotOperandEncodingAttr>(layout);
  const unsigned opIdx = dotEncoding ? dotEncoding.getOpIdx() : 2;

  DpasEncodingAttr dpasLayout =
      dotEncoding ? cast<DpasEncodingAttr>(dotEncoding.getParent())
                  : cast<DpasEncodingAttr>(layout);
  assert(dpasLayout && "only dpas layout is supported");

  // Pick the per-instruction shape for the operand being loaded.
  // TODO: can this be moved into the DpasEncodingAttr layout?
  SmallVector<unsigned> instShape;
  switch (opIdx) {
  case 0:
    instShape = dpasLayout.getDPASInstShapeA();
    break;
  case 1:
    instShape = dpasLayout.getDPASInstShapeB();
    break;
  case 2:
    instShape = dpasLayout.getDPASInstShapeC();
    break;
  default:
    llvm_unreachable("invalid opidx");
  }

  // threadOrder entries are 0/1 here (asserted above), so they select which
  // instruction dimension maps to the tile width vs. height.
  unsigned tileWidth = instShape[threadOrder[rank - 2]];
  unsigned tileHeight = instShape[threadOrder[rank - 1]];

  // The accumulator is loaded one DPAS tile at a time.
  if (opIdx == 2)
    return {tileHeight, tileWidth, 1};

  // For the A and B matrices, enlarge the tile size to support multiple DPAS
  // operands per load.
  ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
  SmallVector<int64_t> repetitions =
      dpasLayout.getDPASRepetitions(tensorShape, opIdx);

  const bool isOperandA = opIdx == 0;
  // opIdx is 0 (A) or 1 (B) from here on.
  const unsigned outerDim = opIdx ? rank - 1 : rank - 2;
  const unsigned repIdx = opIdx ? 1 : 2;
  // NOTE(review): operandsPerTileX is computed but currently unused —
  // presumably reserved for a follow-up of this "1/?" change; confirm.
  unsigned operandsPerTileX =
      isOperandA ? repCluster[outerDim] : repetitions[repIdx];
  unsigned operandsPerTileY =
      isOperandA ? repetitions[repIdx] : repCluster[outerDim];

  if (isTransposeRequired) {
    std::swap(tileWidth, tileHeight);

    const unsigned threadsPerWarp = dpasLayout.getThreadsPerWarp();
    operandsPerTileX =
        (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;

    // Narrow the tile to respect the HW limit on transposed-load width;
    // kWidth is in bytes, so (kWidth * 8) is the element size in bits.
    // TODO(review): document the exact HW limitation this encodes.
    tileWidth = tileWidth / (32 / (kWidth * 8));

    // The transposed 2D load only supports one operand per column.
    operandsPerTileY = 1;
  }
  (void)operandsPerTileX;

  // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
  // by enlarging the number of blocks.
  const unsigned bytesPerRowPerOp = tileWidth * kWidth;
  operandsPerTileY = std::min(operandsPerTileY, 64 / bytesPerRowPerOp);

  return {tileHeight, tileWidth, operandsPerTileY};
}
750+
665751
//===----------------------------------------------------------------------===//
666752
// Dialect Interface
667753
//===----------------------------------------------------------------------===//

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1439,6 +1439,14 @@ struct LoadOpConversion
14391439

14401440
Type eltTy = tensorType.getElementType();
14411441
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
1442+
1443+
auto tileParams = Subgroup2DBlockEncodingAttr::getInstrShapeForLayout(
1444+
cast<DistributedEncodingTrait>(encoding), tensorType.getShape(),
1445+
memoryRowMajor, elemSizeInBits / 8, rewriter.getContext());
1446+
unsigned tileHeight = tileParams[0];
1447+
const unsigned tileWidth = tileParams[1];
1448+
const unsigned vBlocks = tileParams[2];
1449+
14421450
DpasEncodingAttr dpasLayout = getDpasLayout(tensorType);
14431451
const ArrayRef<int64_t> tensorShape = tensorType.getShape();
14441452
unsigned numElems = getTotalElemsPerThread(resultType);
@@ -1476,8 +1484,7 @@ struct LoadOpConversion
14761484

14771485
Value elemSizeInBytes = b.i32_val(elemSizeInBits / 8);
14781486

1479-
SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
1480-
int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
1487+
const unsigned elemsPerLane = tileWidth * tileHeight / threadsPerWarp;
14811488
Type load2DGenXType =
14821489
LLVM::getVectorType(IntegerType::get(ctx, elemSizeInBits),
14831490
elemsPerLane); // make it opaque type.
@@ -1527,12 +1534,12 @@ struct LoadOpConversion
15271534
for (int repM = 0; repM < repCluster[0]; ++repM) {
15281535

15291536
Value offsetY =
1530-
b.add(warpId0Offset, b.i32_val(m * replicaStride[0] +
1531-
repM * elemsPerInstr[0]));
1537+
b.add(warpId0Offset,
1538+
b.i32_val(m * replicaStride[0] + repM * tileHeight));
15321539
for (int repN = 0; repN < repCluster[1]; ++repN) {
15331540
Value offsetX =
1534-
b.add(warpId1Offset, b.i32_val(n * replicaStride[1] +
1535-
repN * elemsPerInstr[1]));
1541+
b.add(warpId1Offset,
1542+
b.i32_val(n * replicaStride[1] + repN * tileWidth));
15361543

15371544
auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
15381545
loc, load2DGenXType,
@@ -1543,9 +1550,9 @@ struct LoadOpConversion
15431550
/*x*/ b.trunc(i32_ty, offsetX),
15441551
/*y*/ b.trunc(i32_ty, offsetY),
15451552
/*elem_size_in_bits*/ elemSizeInBits,
1546-
/*tile_width*/ elemsPerInstr[1],
1547-
/*tile_height*/ elemsPerInstr[0],
1548-
/*v_blocks*/ 1,
1553+
/*tile_width*/ tileWidth,
1554+
/*tile_height*/ tileHeight,
1555+
/*v_blocks*/ vBlocks,
15491556
/*transpose*/ false,
15501557
/*vnni_transform*/ false);
15511558
if (failed(load2dOp.verify())) {
@@ -1659,9 +1666,6 @@ struct LoadOpConversion
16591666
offsetBaseY] =
16601667
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
16611668

1662-
unsigned tileWidth = elemsPerDPASInst[threadOrder[rank - 2]];
1663-
unsigned tileHeight = elemsPerDPASInst[threadOrder[rank - 1]];
1664-
16651669
MLIRContext *ctx = rewriter.getContext();
16661670
const StringAttr dimOuterStr = S("dim" + std::to_string(dimOuter));
16671671
const StringAttr dimInnerStr = S("dim" + std::to_string(dimInner));
@@ -1739,7 +1743,6 @@ struct LoadOpConversion
17391743
llvm::dbgs() << "tile layout done\n";
17401744
});
17411745

1742-
unsigned vBlocks = 1;
17431746
unsigned numOperandsOuterDimPerLoad = 1;
17441747
unsigned numOperandsInnerDimPerLoad = 1;
17451748

@@ -1756,25 +1759,24 @@ struct LoadOpConversion
17561759
if (!usePackedType)
17571760
return failure();
17581761

1759-
std::swap(tileHeight, tileWidth);
1760-
17611762
if (oneMatrixPerLoadForBT) {
17621763
// Only load 1 operand per inst on row.
17631764
numOperandsPer2DLoadM = 1;
1765+
tileHeight = elemsPerDPASInst[threadOrder[rank - 2]];
17641766
} else {
17651767
// We can decompose the matrix returned by transposed large 2d load
17661768
// when threads per warp < column size. Otherwise we have to load one
17671769
// operand per inst.
17681770
// Note: the tileHeight and numOperandsPer2DLoadM are the column size
17691771
// now.
1770-
numOperandsPer2DLoadM =
1771-
(threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
1772+
numOperandsPer2DLoadM = (threadsPerWarp <= tileHeight) ? repCluster[rank - 1] : 1;
17721773
}
17731774
// The transpose 2d load only support 1 operand per inst on column.
17741775
// (vBlocks = 1)
17751776
numOperandsPer2DloadN = 1;
17761777
}
17771778

1779+
// TODO: move this logic to the instr shape computation
17781780
// PVC 2D load supports 32 rows at most. Load multiple dot operands in by
17791781
// enlarging the tileHeight.
17801782
numOperandsPer2DLoadM = std::min(numOperandsPer2DLoadM, 32 / tileHeight);
@@ -1785,7 +1787,6 @@ struct LoadOpConversion
17851787
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
17861788
numOperandsPer2DloadN =
17871789
std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
1788-
vBlocks = numOperandsPer2DloadN;
17891790

17901791
numOperandsOuterDimPerLoad =
17911792
isOperandA ? numOperandsPer2DLoadM : numOperandsPer2DloadN;
@@ -1960,7 +1961,6 @@ struct LoadOpConversion
19601961
if (isTransposeRequired) {
19611962
// adjust the block io parameter to align HW's limitations on
19621963
// transposing load.
1963-
tileWidth = tileWidth / (32 / originalElemBits);
19641964
elemSizeInBits = 32;
19651965
}
19661966
Value elemSizeInBytes = b.i32_val(originalElemBits / 8);

0 commit comments

Comments
 (0)