Skip to content

Commit b141f1c

Browse files
committed
Fix warpsPerCTA
1 parent 0b772f3 commit b141f1c

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,8 @@ Value DpasMatmulLoader<opIdx>::loadMatrix(
206206
Value offsetOuter = mul(i32_val(repOuter), repNonKDimStride);
207207
Value offsetInner = mul(i32_val(repInner), repKDimStride);
208208
Value offset = add(offsetOuter, offsetInner);
209-
SmallVector<unsigned> warpsPerCTA = dpasLayout.getWarpsPerCTA();
210-
// 3DTODO: check if repBatch * warpsPerCTA[0] is correct for the offset.
211209
if (repBatch > 0) {
210+
SmallVector<unsigned> warpsPerCTA = dpasLayout.getWarpsPerCTA();
212211
Value offsetBatch =
213212
mul(i32_val(repBatch * warpsPerCTA[0]), repBatchDimStride);
214213
offset = add(offset, offsetBatch);

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,6 @@ SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
5959
struct IntelDPASCapability dpasCap,
6060
const ArrayRef<int64_t> shape,
6161
unsigned numWarps) {
62-
auto rank = shape.size();
63-
// Early exit for batched matmul
64-
// TODO: current strategy is same as upstream, there could be better strategy
65-
// when batch < numWarps
66-
if (rank == 3)
67-
return {numWarps, 1, 1};
68-
6962
auto filter = [&dotOp](Operation *op) {
7063
return op->getParentRegion() == dotOp->getParentRegion();
7164
};
@@ -76,29 +69,39 @@ SmallVector<unsigned> getWarpsPerTile(DotOp dotOp,
7669
if (isa<DotOp>(op) && (op != dotOp))
7770
return {numWarps, 1};
7871

79-
SmallVector<unsigned> ret{1, 1};
80-
SmallVector<int64_t> shapePerWarp{dpasCap.repeatCount, dpasCap.executionSize};
72+
size_t rank = shape.size();
73+
SmallVector<unsigned> ret(rank, 1);
74+
75+
if (rank == 3) {
76+
int batchWarp = numWarps;
77+
while (batchWarp > shape[0])
78+
batchWarp /= 2;
79+
ret[0] = batchWarp;
80+
numWarps /= batchWarp;
81+
}
8182

8283
// Try to find a proper tiling shape for the dot operation.
8384
// It doubles the warp number in col or row in each time based on column to
8485
// width ratio.
8586
// By this, we can minimize the duplication of the dot operands A and B.
87+
SmallVector<int64_t> shapePerWarp{dpasCap.repeatCount, dpasCap.executionSize};
8688
uint32_t rowColRatio =
8789
ceil<uint32_t>(dpasCap.repeatCount, dpasCap.executionSize);
8890
uint32_t colRowRatio =
8991
ceil<uint32_t>(dpasCap.executionSize, dpasCap.repeatCount);
9092

93+
int rowDim = rank - 2, colDim = rank - 1;
9194
do {
92-
if (ret[0] * ret[1] >= numWarps)
95+
if (ret[rowDim] * ret[colDim] >= numWarps)
9396
break;
94-
if (shape[0] / (shapePerWarp[0] * colRowRatio) / ret[0] >=
95-
shape[1] / (shapePerWarp[1] * rowColRatio) / ret[1]) {
96-
if (ret[0] < shape[0] / shapePerWarp[0])
97-
ret[0] *= 2;
97+
if (shape[rowDim] / (shapePerWarp[0] * colRowRatio) / ret[rowDim] >=
98+
shape[colDim] / (shapePerWarp[1] * rowColRatio) / ret[colDim]) {
99+
if (ret[rowDim] < shape[rowDim] / shapePerWarp[0])
100+
ret[rowDim] *= 2;
98101
else
99-
ret[1] *= 2;
102+
ret[colDim] *= 2;
100103
} else {
101-
ret[1] *= 2;
104+
ret[colDim] *= 2;
102105
}
103106
} while (true);
104107

0 commit comments

Comments
 (0)