
Commit 6a2c836

Support dynamic rank to dpas/dot layout conversion
1 parent fe45283 commit 6a2c836

File tree

11 files changed: +242 -107 lines changed

include/triton/Tools/LinearLayout.h

Lines changed: 1 addition & 0 deletions

@@ -330,6 +330,7 @@ class LinearLayout {
   // The 0-dimensional layout that maps everything to 0. This is useful as a
   // starting point when doing something like
+  // i
   //
   // LinearLayout ret = LinearLayout::empty();
   // for (...) ret *= ...;

python/src/ir.cc

Lines changed: 1 addition & 1 deletion

@@ -1622,7 +1622,7 @@ void init_triton_ir(py::module &&m) {
         if (haveDump) {
           auto printingFlags = OpPrintingFlags();
           printingFlags.elideLargeElementsAttrs(16);
-          printingFlags.enableDebugInfo();
+          // printingFlags.enableDebugInfo();
           auto printAlways = [funcToDump](Pass *, Operation *op) -> bool {
             if (funcToDump.empty())
               return true;
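
Both calls configure MLIR's IR printer for the pass-dump path; the change comments out location printing while keeping attribute elision. A minimal sketch of what the two (real MLIR) calls do, with the helper name and withLocs parameter being hypothetical:

    #include "mlir/IR/OperationSupport.h"

    // Hypothetical helper mirroring the flags toggled above.
    // elideLargeElementsAttrs(16) prints dense element attrs with more
    // than 16 elements as "dense<...>" so dumps stay readable;
    // enableDebugInfo() appends loc(...) to every printed operation.
    static mlir::OpPrintingFlags makeDumpFlags(bool withLocs) {
      mlir::OpPrintingFlags flags;
      flags.elideLargeElementsAttrs(16);
      if (withLocs)
        flags.enableDebugInfo(); // the line this commit disables
      return flags;
    }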

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 118 additions & 40 deletions
@@ -1,5 +1,6 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 
+#include <cstdint>
 #include <numeric>
 
 #include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"

@@ -12,7 +13,9 @@
 
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.cpp.inc"
 
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 using namespace mlir::triton;

@@ -104,49 +107,75 @@ SmallVector<unsigned> DpasEncodingAttr::getDPASInstShapeC() const {
 SmallVector<unsigned> DpasEncodingAttr::getShapeA() const {
   auto shapeA = getDPASInstShapeA();
   auto repCluster = getRepCluster();
-  return {shapeA[0] * repCluster[0], shapeA[1]};
+  size_t rank = repCluster.size();
+  SmallVector<unsigned> resShape(rank, 1);
+  resShape[rank - 2] = shapeA[0] * repCluster[rank - 2];
+  resShape[rank - 1] = shapeA[1];
+  return resShape;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getShapeB() const {
   auto shapeB = getDPASInstShapeB();
   auto repCluster = getRepCluster();
-  return {shapeB[0], shapeB[1] * repCluster[1]};
+  size_t rank = repCluster.size();
+  SmallVector<unsigned> resShape(rank, 1);
+  resShape[rank - 2] = shapeB[0];
+  resShape[rank - 1] = shapeB[1] * repCluster[rank - 1];
+  return resShape;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
   auto shapeC = getDPASInstShapeC();
   auto repCluster = getRepCluster();
-  return {shapeC[0] * repCluster[0], shapeC[1] * repCluster[1]};
+  size_t rank = repCluster.size();
+  SmallVector<unsigned> resShape(rank, 1);
+  resShape[rank - 2] = shapeC[0] * repCluster[rank - 2];
+  resShape[rank - 1] = shapeC[1] * repCluster[rank - 1];
+  return resShape;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> res(rank, 1);
   unsigned threadsPerWarp = getSubGroupSize();
   auto shapeC = getDPASInstShapeC();
   unsigned elemsNum = product<unsigned>(shapeC);
   unsigned elemsPerThread = elemsNum / threadsPerWarp;
   auto repCluster = getRepCluster();
   // The value is sharded across the lanes per DPAS instruction.
-  return {elemsPerThread * repCluster[0], repCluster[1]};
+  res[rank - 2] = elemsPerThread * repCluster[rank - 2];
+  res[rank - 1] = repCluster[rank - 1];
+  return res;
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
   auto shapeC = getShapeC();
-  return {shapeC[0] * getWarpsPerCTA()[0], shapeC[1] * getWarpsPerCTA()[1]};
+  SmallVector<unsigned> warpsPerCTA = getWarpsPerCTA();
+  size_t rank = shapeC.size();
+  assert(rank == warpsPerCTA.size() &&
+         "ShapeC and WarpsPerCTA must have the same rank");
+  SmallVector<unsigned> shapePerCTATile(rank);
+  for (size_t i = 0; i < rank; ++i) {
+    shapePerCTATile[i] = shapeC[i] * warpsPerCTA[i];
+  }
+  return shapePerCTATile;
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
   size_t rank = shape.size();
-  assert(rank == 2 && "Unexpected rank of mma layout");
+  assert((rank == 2 || rank == 3) && "Unexpected rank of mma layout");
 
   SmallVector<unsigned> elemsPerThread(rank);
   auto shapePerCTATile = getShapePerCTATile(shape);
-  unsigned tilesRow = ceil<unsigned>(shape[0], shapePerCTATile[0]);
-  unsigned tilesCol = ceil<unsigned>(shape[1], shapePerCTATile[1]);
+  unsigned tilesRow =
+      ceil<unsigned>(shape[rank - 2], shapePerCTATile[rank - 2]);
+  unsigned tilesCol =
+      ceil<unsigned>(shape[rank - 1], shapePerCTATile[rank - 1]);
   auto sizePerThread = getSizePerThread();
-  elemsPerThread[0] = sizePerThread[0] * tilesRow;
-  elemsPerThread[1] = sizePerThread[1] * tilesCol;
+  elemsPerThread[rank - 2] = sizePerThread[rank - 2] * tilesRow;
+  elemsPerThread[rank - 1] = sizePerThread[rank - 1] * tilesCol;
 
   return elemsPerThread;
 }
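
Every getter in this hunk follows the same recipe: size the result by the rank of repCluster, default all entries to 1 so a leading batch dimension stays untouched, and write only the last two (row and column) slots. A standalone sketch of the operand-A case, where instShape and repCluster are stand-ins for getDPASInstShapeA() and getRepCluster():

    #include <cstddef>
    #include <vector>

    // Sketch: rank-agnostic operand-A shape, mirroring getShapeA() above.
    std::vector<unsigned> shapeA(const std::vector<unsigned> &instShape, // {M, K}
                                 const std::vector<unsigned> &repCluster) {
      size_t rank = repCluster.size();    // 2, or 3 with a leading batch dim
      std::vector<unsigned> res(rank, 1); // the batch dim (if any) stays 1
      res[rank - 2] = instShape[0] * repCluster[rank - 2]; // replicated rows
      res[rank - 1] = instShape[1];                        // K is not replicated
      return res;
    }
    // e.g. instShape = {8, 16}, repCluster = {1, 4, 2} -> {1, 32, 16}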
@@ -174,48 +203,73 @@ SmallVector<unsigned> DpasEncodingAttr::getCTAsPerCGA() const {
 SmallVector<int64_t>
 DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
   auto warpsPerCTA = getWarpsPerCTA();
+  int rank = shape.size();
+  SmallVector<int64_t> res(rank);
   if (opIdx == 0) {
     auto shapePerWarp = getShapeA();
-    return {std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])),
-            std::max<int64_t>(1, shape[1] / shapePerWarp[1])};
+    if (rank == 3)
+      res[0] =
+          std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]));
+    res[rank - 2] = std::max<int64_t>(
+        1, shape[rank - 2] / (shapePerWarp[rank - 2] * warpsPerCTA[rank - 2]));
+    res[rank - 1] =
+        std::max<int64_t>(1, shape[rank - 1] / shapePerWarp[rank - 1]);
   } else if (opIdx == 1) {
     auto shapePerWarp = getShapeB();
-    return {
-        std::max<int64_t>(1, shape[0] / shapePerWarp[0]),
-        std::max<int64_t>(1, shape[1] / (shapePerWarp[1] * warpsPerCTA[1]))};
+    if (rank == 3)
+      res[0] =
+          std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]));
+    res[rank - 2] =
+        std::max<int64_t>(1, shape[rank - 2] / shapePerWarp[rank - 2]);
+    res[rank - 1] = std::max<int64_t>(
+        1, shape[rank - 1] / (shapePerWarp[rank - 1] * warpsPerCTA[rank - 1]));
   } else {
     assert(opIdx == 2 && "Unexpected operand id (valid ids are 0, 1 or 2)");
     auto shapePerWarp = getShapeC();
-    return {
-        std::max<int64_t>(1, mlir::ceil<unsigned>(
-                                 shape[0], shapePerWarp[0] * warpsPerCTA[0])),
-        std::max<int64_t>(1, mlir::ceil<unsigned>(
-                                 shape[1], shapePerWarp[1] * warpsPerCTA[1]))};
+    if (rank == 3)
+      res[0] =
+          std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]));
+    res[rank - 2] = std::max<int64_t>(
+        1, shape[rank - 2] / (shapePerWarp[rank - 2] * warpsPerCTA[rank - 2]));
+    res[rank - 1] = std::max<int64_t>(
+        1, shape[rank - 1] / (shapePerWarp[rank - 1] * warpsPerCTA[rank - 1]));
   }
+  return res;
 }
 
 unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(
     ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
   auto shapePerCTA = getShapePerCTA(*this, shape);
   auto rep = getDPASRepetitions(shapePerCTA, opIdx);
   auto threadsPerWar = getSubGroupSize();
+  int rank = shape.size();
   if (opIdx == 0) {
     auto shapeA = getShapeA();
     auto totalElem = product<unsigned>(shapeA);
     // DPAS operand scalars are evenly sharded across the work items.
-    return (totalElem / threadsPerWar) * rep[0] * rep[1];
-  } else { // if (opIdx == 1)
+    return (totalElem / threadsPerWar) * product<int64_t>(rep);
+  }
+  if (opIdx == 1) {
     auto shapeB = getShapeB();
     auto totalElem = product<unsigned>(shapeB);
     // DPAS operand scalars are evenly sharded across the work items.
-    return (totalElem / threadsPerWar) * rep[0] * rep[1];
+    return (totalElem / threadsPerWar) * product<int64_t>(rep);
   }
+  llvm_unreachable("DpasEncodingAttr opIdx must be 0 or 1");
 }
 
-SmallVector<unsigned> DpasEncodingAttr::getWarpOrder() const { return {1, 0}; }
+SmallVector<unsigned> DpasEncodingAttr::getWarpOrder() const {
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  return order;
+}
 
 SmallVector<unsigned> DpasEncodingAttr::getThreadOrder() const {
-  return {1, 0};
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  return order;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getWarpsPerCTA() const {
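
getDPASRepetitions() applies the same rank-relative scheme: each of the last two dimensions counts how many warp-level tiles are needed to cover the tensor, and a rank-3 batch dimension is distributed over warps as well. A standalone sketch of the accumulator case (opIdx == 2), all parameter names being stand-ins for the attribute getters:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Sketch: per-dimension repetition count for the C operand, i.e. how
    // many warp-level tiles each warp iterates to cover `shape`.
    std::vector<int64_t> repetitionsC(const std::vector<int64_t> &shape,
                                      const std::vector<unsigned> &shapePerWarp,
                                      const std::vector<unsigned> &warpsPerCTA) {
      size_t rank = shape.size();
      std::vector<int64_t> res(rank);
      if (rank == 3) // the batch dimension is also split across warps
        res[0] = std::max<int64_t>(
            1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]));
      res[rank - 2] = std::max<int64_t>(
          1, shape[rank - 2] / (shapePerWarp[rank - 2] * warpsPerCTA[rank - 2]));
      res[rank - 1] = std::max<int64_t>(
          1, shape[rank - 1] / (shapePerWarp[rank - 1] * warpsPerCTA[rank - 1]));
      return res;
    }
    // e.g. shape {4, 128, 128}, shapePerWarp {1, 16, 32},
    //      warpsPerCTA {4, 2, 2} -> {1, 4, 2}

The rewritten getWarpOrder() and getThreadOrder() use the same rank trick in miniature: std::iota over order.rbegin()..rend() fills {rank - 1, ..., 1, 0}, keeping the fastest-varying dimension last at any rank.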
@@ -224,33 +278,48 @@ SmallVector<unsigned> DpasEncodingAttr::getWarpsPerCTA() const {
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> res(rank, 1);
   auto executionSize = getExecutionSize();
   auto subGroupSize = getSubGroupSize();
   if (subGroupSize < executionSize) {
     llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
                              "smaller than the execution size");
   }
-  return {subGroupSize / executionSize, executionSize};
+  res[rank - 2] = subGroupSize / executionSize;
+  res[rank - 1] = executionSize;
+  return res;
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
                                                    int opIdx) const {
   auto parentShapePerCTATile = getShapePerCTATile(shape);
-  auto threadsPerWarp = getThreadsPerWarp();
+  // auto threadsPerWarp = getThreadsPerWarp();
+  size_t rank = parentShapePerCTATile.size();
   if (opIdx == 0) {
     auto shapeA = getShapeA();
-    return {parentShapePerCTATile[0], shapeA[1]};
+    if (rank == 2)
+      return {parentShapePerCTATile[0], shapeA[1]};
+    else
+      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2],
+              shapeA[rank - 1]};
   } else if (opIdx == 1) {
     auto shapeB = getShapeB();
-    return {shapeB[0], parentShapePerCTATile[1]};
+    if (rank == 2)
+      return {shapeB[0], parentShapePerCTATile[1]};
+    else
+      return {parentShapePerCTATile[0], shapeB[rank - 2],
+              parentShapePerCTATile[rank - 1]};
   } else {
     llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
   }
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
+  ArrayRef<unsigned> repCluster = getRepCluster();
+  size_t rank = repCluster.size();
   if (opIdx == 0) {
     SmallVector<unsigned> shapeA = getDPASInstShapeA();
     unsigned subGroupSize = getSubGroupSize();
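
getThreadsPerWarp() keeps the familiar two-dimensional lane grid and merely pads it with a unit batch dimension. A sketch with assumed hardware numbers (a 16-lane sub-group and execution size 8):

    #include <cstddef>
    #include <vector>

    // Sketch of getThreadsPerWarp() above: a sub-group of `subGroupSize`
    // lanes is arranged as a (subGroupSize / executionSize) x executionSize
    // grid; a leading batch dimension (rank 3) stays 1.
    std::vector<unsigned> threadsPerWarp(size_t rank, unsigned subGroupSize,
                                         unsigned executionSize) {
      std::vector<unsigned> res(rank, 1);
      res[rank - 2] = subGroupSize / executionSize; // lane rows
      res[rank - 1] = executionSize;                // lane columns
      return res; // rank 3, 16, 8 -> {1, 2, 8}
    }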
@@ -267,8 +336,7 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
                                "be smaller than the threads required per row.");
     }
     unsigned rowsPerWarp = mlir::ceil<unsigned>(subGroupSize, packedColNum);
-    auto repCluster = getRepCluster();
-    return {shapeA[0] / rowsPerWarp * repCluster[0], packedOpsPerLane};
+    return {shapeA[0] / rowsPerWarp * repCluster[rank - 2], packedOpsPerLane};
   } else if (opIdx == 1) {
     auto shapeB = getShapeB();
     auto subGroupSize = getSubGroupSize();

@@ -279,9 +347,8 @@
     }
     SmallVector<unsigned, 2> threadsPerWarp = {subGroupSize / executionSize,
                                                executionSize};
-    auto repCluster = getRepCluster();
-    return {shapeB[0] / threadsPerWarp[0],
-            shapeB[1] / threadsPerWarp[1] * repCluster[1]};
+    return {shapeB[rank - 2] / threadsPerWarp[0],
+            shapeB[rank - 1] / threadsPerWarp[1] * repCluster[rank - 1]};
   } else {
     llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
     return {};

@@ -293,20 +360,31 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
   SmallVector<unsigned> sizePerThread = getSizePerThreadForOperands(opIdx);
   SmallVector<int64_t> repetitions = getDPASRepetitions(shape, opIdx);
 
-  return {static_cast<unsigned>(sizePerThread[0] * repetitions[0]),
-          static_cast<unsigned>(sizePerThread[1] * repetitions[1])};
+  size_t rank = shape.size();
+  SmallVector<unsigned> elemsPerThread(rank);
+  if (rank == 3)
+    elemsPerThread[0] = repetitions[0];
+  elemsPerThread[rank - 2] = sizePerThread[rank - 2] * repetitions[rank - 2];
+  elemsPerThread[rank - 1] = sizePerThread[rank - 1] * repetitions[rank - 1];
+
+  return elemsPerThread;
 };
 
 SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
+  size_t rank = getWarpsPerCTA().size();
+  assert(rank == 2 || rank == 3);
+  SmallVector<unsigned> contigPerThread(rank, 1);
+
   unsigned threadsPerWarp = getSubGroupSize();
   auto shapeC = getDPASInstShapeC();
   // Software vectorization treats the value as a C array:
   // int a[N] -> int a[N][threadsPerWarp].
   if (threadsPerWarp > shapeC[1]) {
-    return {1, 1};
+    return contigPerThread;
   } else if (threadsPerWarp == shapeC[1]) {
     auto repCluster = getRepCluster();
-    return {shapeC[0] * repCluster[0], 1};
+    contigPerThread[rank - 2] = shapeC[0] * repCluster[rank - 2];
+    return contigPerThread;
   } else {
     // threadsPerWarp < shapeC[1]
     llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
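
getContigPerThread() now reports per-lane contiguity at any rank: the strided case stays all ones, and the matched case writes the contiguous row run into the second-to-last slot. A sketch under assumed shapes (instM/instN stand in for the per-instruction C tile, repClusterM for the row replication; the helper name is hypothetical):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Sketch of the three cases in getContigPerThread() above.
    std::vector<unsigned> contigPerThread(size_t rank, unsigned threadsPerWarp,
                                          unsigned instM, unsigned instN,
                                          unsigned repClusterM) {
      std::vector<unsigned> contig(rank, 1);
      if (threadsPerWarp > instN)
        return contig;                 // each lane's values are strided
      assert(threadsPerWarp == instN); // smaller sub-groups are rejected
      contig[rank - 2] = instM * repClusterM; // one contiguous column per lane
      return contig; // rank 3, 16 lanes, {8, 16}, repClusterM 4 -> {1, 32, 1}
    }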
@@ -333,8 +411,8 @@ LogicalResult DpasEncodingAttr::verify(
     return emitError() << "systolicDepth must be 8, but was:" << opsPerChan;
   }
 
-  if (repCluster.size() != 2) {
-    return emitError() << "expected rank 2 of repCluster, but the rank is:"
+  if (!(repCluster.size() == 2 || repCluster.size() == 3)) {
+    return emitError() << "expected rank 2 or 3 of repCluster, but the rank is:"
                        << repCluster.size();
   }
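
The relaxed verifier captures the commit's contract in one line: repCluster, and with it every layout derived above, may be rank 2, or rank 3 with a leading batch dimension. A trivial standalone restatement (hypothetical predicate; the real code reports through emitError()):

    #include <cstddef>

    // Hypothetical predicate mirroring the verify() check above; rank 3
    // corresponds to a batched dot whose dim 0 is the batch dimension.
    bool isValidRepClusterRank(size_t rank) { return rank == 2 || rank == 3; }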
