
Commit e800858

Support rank > 2 for the dpas/dot layout conversion
Fix the 3D dot layout lowering to LLVM
Fix opIdx to dimIdx

1 parent fe45283 · commit e800858
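
The pattern that repeats throughout this commit: DPAS layout helpers that previously hard-coded indices 0 and 1 now address the last two dimensions as `rank - 2` and `rank - 1` and leave any leading (batch) dimension at 1, so a rank-3 (batched) dot reuses the rank-2 logic unchanged. A minimal standalone sketch of that convention, assuming made-up instruction-shape and repCluster values (plain C++, not the actual DpasEncodingAttr API):

```cpp
// Illustrative sketch only: shows the rank - 2 / rank - 1 indexing convention
// this commit applies throughout DpasEncodingAttr. The DPAS instruction shape
// and repCluster values below are hypothetical.
#include <iostream>
#include <vector>

// Expand a per-instruction tile shape {m, n} into a per-warp shape of the
// requested rank: leading (batch) dims stay 1, the last two dims are scaled
// by the replication cluster, mirroring the getShapeC-style helpers.
std::vector<unsigned> shapePerWarp(std::vector<unsigned> instShape,
                                   std::vector<unsigned> repCluster) {
  size_t rank = repCluster.size(); // 2 for a plain dot, 3 for a batched dot
  std::vector<unsigned> res(rank, 1);
  res[rank - 2] = instShape[0] * repCluster[rank - 2];
  res[rank - 1] = instShape[1] * repCluster[rank - 1];
  return res;
}

int main() {
  // Hypothetical DPAS C-tile of 8x16 replicated 4x2 within a warp.
  for (std::vector<unsigned> repCluster : {std::vector<unsigned>{4, 2},
                                           std::vector<unsigned>{1, 4, 2}}) {
    for (unsigned d : shapePerWarp({8, 16}, repCluster))
      std::cout << d << ' ';
    std::cout << '\n'; // prints "32 32" (rank 2) and "1 32 32" (rank 3)
  }
}
```

With a rank-2 repCluster the leading batch dimension never appears, so existing 2D call sites keep their behavior.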

File tree

10 files changed, +424 -238 lines changed

python/src/ir.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -1622,7 +1622,7 @@ void init_triton_ir(py::module &&m) {
         if (haveDump) {
           auto printingFlags = OpPrintingFlags();
           printingFlags.elideLargeElementsAttrs(16);
-          printingFlags.enableDebugInfo();
+          // printingFlags.enableDebugInfo();
           auto printAlways = [funcToDump](Pass *, Operation *op) -> bool {
             if (funcToDump.empty())
               return true;
```

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 124 additions & 40 deletions
```diff
@@ -1,5 +1,6 @@
 #include "triton/Dialect/Triton/IR/Dialect.h"
 
+#include <cstdint>
 #include <numeric>
 
 #include "intel/include/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.h"
@@ -12,7 +13,9 @@
 
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.cpp.inc"
 
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/ErrorHandling.h"
 
 using namespace mlir;
 using namespace mlir::triton;
@@ -104,49 +107,75 @@ SmallVector<unsigned> DpasEncodingAttr::getDPASInstShapeC() const {
 SmallVector<unsigned> DpasEncodingAttr::getShapeA() const {
   auto shapeA = getDPASInstShapeA();
   auto repCluster = getRepCluster();
-  return {shapeA[0] * repCluster[0], shapeA[1]};
+  size_t rank = repCluster.size();
+  SmallVector<unsigned> resShape(rank, 1);
+  resShape[rank - 2] = shapeA[0] * repCluster[rank - 2];
+  resShape[rank - 1] = shapeA[1];
+  return resShape;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getShapeB() const {
   auto shapeB = getDPASInstShapeB();
   auto repCluster = getRepCluster();
-  return {shapeB[0], shapeB[1] * repCluster[1]};
+  size_t rank = repCluster.size();
+  SmallVector<unsigned> resShape(rank, 1);
+  resShape[rank - 2] = shapeB[0];
+  resShape[rank - 1] = shapeB[1] * repCluster[rank - 1];
+  return resShape;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getShapeC() const {
   auto shapeC = getDPASInstShapeC();
   auto repCluster = getRepCluster();
-  return {shapeC[0] * repCluster[0], shapeC[1] * repCluster[1]};
+  size_t rank = repCluster.size();
+  SmallVector<unsigned> resShape(rank, 1);
+  resShape[rank - 2] = shapeC[0] * repCluster[rank - 2];
+  resShape[rank - 1] = shapeC[1] * repCluster[rank - 1];
+  return resShape;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getSizePerThread() const {
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> res(rank, 1);
   unsigned threadsPerWarp = getSubGroupSize();
   auto shapeC = getDPASInstShapeC();
   unsigned elemsNum = product<unsigned>(shapeC);
   unsigned elemsPerThread = elemsNum / threadsPerWarp;
   auto repCluster = getRepCluster();
   // The Value is shard to lanes to threads per DPAS instruction.
-  return {elemsPerThread * repCluster[0], repCluster[1]};
+  res[rank - 2] = elemsPerThread * repCluster[rank - 2];
+  res[rank - 1] = repCluster[rank - 1];
+  return res;
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getShapePerCTATile(ArrayRef<int64_t> tensorShape) const {
   auto shapeC = getShapeC();
-  return {shapeC[0] * getWarpsPerCTA()[0], shapeC[1] * getWarpsPerCTA()[1]};
+  SmallVector<unsigned> warpsPerCTA = getWarpsPerCTA();
+  size_t rank = shapeC.size();
+  assert(rank == shapeC.size() &&
+         "ShapeC and WarpsPerCTA must have the same rank");
+  SmallVector<unsigned> shapePerCTATile(rank);
+  for (size_t i = 0; i < rank; ++i) {
+    shapePerCTATile[i] = shapeC[i] * warpsPerCTA[i];
+  }
+  return shapePerCTATile;
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const {
   size_t rank = shape.size();
-  assert(rank == 2 && "Unexpected rank of mma layout");
+  assert((rank == 2 || rank == 3) && "Unexpected rank of mma layout");
 
   SmallVector<unsigned> elemsPerThread(rank);
   auto shapePerCTATile = getShapePerCTATile(shape);
-  unsigned tilesRow = ceil<unsigned>(shape[0], shapePerCTATile[0]);
-  unsigned tilesCol = ceil<unsigned>(shape[1], shapePerCTATile[1]);
+  unsigned tilesRow =
+      ceil<unsigned>(shape[rank - 2], shapePerCTATile[rank - 2]);
+  unsigned tilesCol =
+      ceil<unsigned>(shape[rank - 1], shapePerCTATile[rank - 1]);
   auto sizePerThread = getSizePerThread();
-  elemsPerThread[0] = sizePerThread[0] * tilesRow;
-  elemsPerThread[1] = sizePerThread[1] * tilesCol;
+  elemsPerThread[rank - 2] = sizePerThread[rank - 2] * tilesRow;
+  elemsPerThread[rank - 1] = sizePerThread[rank - 1] * tilesCol;
 
   return elemsPerThread;
 }
```
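
As a reference for the hunk above, a small self-contained sketch of the rank-agnostic tiling now done in getElemsPerThread, with hypothetical shapes and a ceilDiv stand-in for mlir::ceil<unsigned>:

```cpp
#include <iostream>
#include <vector>

// Ceiling division, standing in for mlir::ceil<unsigned>.
unsigned ceilDiv(unsigned a, unsigned b) { return (a + b - 1) / b; }

// Tile only the last two dims by the per-CTA tile and scale by the per-thread
// size; a leading batch dim is left at 1 in this sketch.
std::vector<unsigned> elemsPerThread(const std::vector<unsigned> &shape,
                                     const std::vector<unsigned> &shapePerCTATile,
                                     const std::vector<unsigned> &sizePerThread) {
  size_t rank = shape.size();
  std::vector<unsigned> res(rank, 1);
  unsigned tilesRow = ceilDiv(shape[rank - 2], shapePerCTATile[rank - 2]);
  unsigned tilesCol = ceilDiv(shape[rank - 1], shapePerCTATile[rank - 1]);
  res[rank - 2] = sizePerThread[rank - 2] * tilesRow;
  res[rank - 1] = sizePerThread[rank - 1] * tilesCol;
  return res;
}

int main() {
  // Hypothetical rank-3 tensor 4x128x64, CTA tile 1x64x32, 8x2 elems/thread.
  for (unsigned d : elemsPerThread({4, 128, 64}, {1, 64, 32}, {1, 8, 2}))
    std::cout << d << ' '; // prints "1 16 4"
  std::cout << '\n';
}
```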
```diff
@@ -173,49 +202,80 @@ SmallVector<unsigned> DpasEncodingAttr::getCTAsPerCGA() const {
 
 SmallVector<int64_t>
 DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
+  // Always return a 3D shape repetitions for the ease of value handling, same
+  // to mma.
   auto warpsPerCTA = getWarpsPerCTA();
+  int rank = shape.size();
+  SmallVector<int64_t> rep(3, 1);
   if (opIdx == 0) {
     auto shapePerWarp = getShapeA();
-    return {std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0])),
-            std::max<int64_t>(1, shape[1] / shapePerWarp[1])};
+    int64_t numRepBatch =
+        rank == 3 ? std::max<int64_t>(1, shape[0] /
+                                             (shapePerWarp[0] * warpsPerCTA[0]))
+                  : 1;
+    return {numRepBatch,
+            std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
+                                                    warpsPerCTA[rank - 2])),
+            std::max<int64_t>(1, shape[rank - 1] / shapePerWarp[rank - 1])};
   } else if (opIdx == 1) {
     auto shapePerWarp = getShapeB();
-    return {
-        std::max<int64_t>(1, shape[0] / shapePerWarp[0]),
-        std::max<int64_t>(1, shape[1] / (shapePerWarp[1] * warpsPerCTA[1]))};
+    int64_t numRepBatch =
+        rank == 3 ? std::max<int64_t>(1, shape[0] /
+                                             (shapePerWarp[0] * warpsPerCTA[0]))
+                  : 1;
+    return {numRepBatch,
+            std::max<int64_t>(1, shape[rank - 2] / shapePerWarp[rank - 2]),
+            std::max<int64_t>(1, shape[rank - 1] / (shapePerWarp[rank - 1] *
+                                                    warpsPerCTA[rank - 1]))};
   } else {
     assert(opIdx == 2 && "Unexpected operand id (valid ids are 0, 1 or 2)");
     auto shapePerWarp = getShapeC();
-    return {
-        std::max<int64_t>(1, mlir::ceil<unsigned>(
-                                 shape[0], shapePerWarp[0] * warpsPerCTA[0])),
-        std::max<int64_t>(1, mlir::ceil<unsigned>(
-                                 shape[1], shapePerWarp[1] * warpsPerCTA[1]))};
+    int64_t numRepBatch =
+        rank == 3 ? std::max<int64_t>(1, shape[0] /
+                                             (shapePerWarp[0] * warpsPerCTA[0]))
+                  : 1;
+    return {numRepBatch,
+            std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
+                                                    warpsPerCTA[rank - 2])),
+            std::max<int64_t>(1, shape[rank - 1] / (shapePerWarp[rank - 1] *
+                                                    warpsPerCTA[rank - 1]))};
   }
+  return rep;
 }
 
 unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(
     ArrayRef<int64_t> shape, mlir::Type eltTy, int kWidth, int opIdx) const {
   auto shapePerCTA = getShapePerCTA(*this, shape);
   auto rep = getDPASRepetitions(shapePerCTA, opIdx);
   auto threadsPerWar = getSubGroupSize();
+  int rank = shape.size();
   if (opIdx == 0) {
     auto shapeA = getShapeA();
     auto totalElem = product<unsigned>(shapeA);
     // dpas operands scalar are evenly sharded to each work item.
-    return (totalElem / threadsPerWar) * rep[0] * rep[1];
-  } else { // if (opIdx == 1)
+    return (totalElem / threadsPerWar) * product<int64_t>(rep);
+  }
+  if (opIdx == 1) {
     auto shapeB = getShapeB();
     auto totalElem = product<unsigned>(shapeB);
     // dpas operands scalar are evenly sharded to each work item.
-    return (totalElem / threadsPerWar) * rep[0] * rep[1];
+    return (totalElem / threadsPerWar) * product<int64_t>(rep);
   }
+  llvm_unreachable("DpasEncodingAttr opIdx must be 0 or 1");
 }
 
-SmallVector<unsigned> DpasEncodingAttr::getWarpOrder() const { return {1, 0}; }
+SmallVector<unsigned> DpasEncodingAttr::getWarpOrder() const {
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  return order;
+}
 
 SmallVector<unsigned> DpasEncodingAttr::getThreadOrder() const {
-  return {1, 0};
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> order(rank);
+  std::iota(order.rbegin(), order.rend(), 0);
+  return order;
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getWarpsPerCTA() const {
```
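
Two idioms from the hunk above, sketched standalone with made-up shapes: getDPASRepetitions now always reports three counts {batch, M, N} (batch forced to 1 for rank-2 shapes), and getWarpOrder/getThreadOrder build a reversed identity permutation with std::iota over reverse iterators:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Repetitions of the per-warp tile needed to cover `shape`, always reported
// as {batch, M, N}; rank-2 shapes get a batch count of 1.
std::vector<int64_t> dpasRepetitions(const std::vector<int64_t> &shape,
                                     const std::vector<int64_t> &shapePerWarp,
                                     const std::vector<int64_t> &warpsPerCTA) {
  int rank = shape.size();
  int64_t batch =
      rank == 3
          ? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
          : 1;
  return {batch,
          std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
                                                  warpsPerCTA[rank - 2])),
          std::max<int64_t>(1, shape[rank - 1] / (shapePerWarp[rank - 1] *
                                                  warpsPerCTA[rank - 1]))};
}

// Row-major warp/thread order for any rank: {1, 0} at rank 2, {2, 1, 0} at rank 3.
std::vector<unsigned> reversedOrder(size_t rank) {
  std::vector<unsigned> order(rank);
  std::iota(order.rbegin(), order.rend(), 0);
  return order;
}

int main() {
  // Hypothetical batched GEMM: shape 2x256x128, per-warp C tile 1x32x32,
  // warps laid out 1x4x2 per CTA.
  for (int64_t r : dpasRepetitions({2, 256, 128}, {1, 32, 32}, {1, 4, 2}))
    std::cout << r << ' '; // prints "2 2 2"
  std::cout << '\n';
  for (unsigned d : reversedOrder(3))
    std::cout << d << ' '; // prints "2 1 0"
  std::cout << '\n';
}
```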
```diff
@@ -224,33 +284,48 @@ SmallVector<unsigned> DpasEncodingAttr::getWarpsPerCTA() const {
 }
 
 SmallVector<unsigned> DpasEncodingAttr::getThreadsPerWarp() const {
+  size_t rank = getWarpsPerCTA().size();
+  SmallVector<unsigned> res(rank, 1);
   auto executionSize = getExecutionSize();
   auto subGroupSize = getSubGroupSize();
   if (subGroupSize < executionSize) {
     llvm::report_fatal_error("DpasEncodingAttr sub-group size could not be "
                              "smaller than the execution size");
   }
-  return {subGroupSize / executionSize, executionSize};
+  res[rank - 2] = subGroupSize / executionSize;
+  res[rank - 1] = executionSize;
+  return res;
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getShapePerCTATileForDotOperands(ArrayRef<int64_t> shape,
                                                    int opIdx) const {
   auto parentShapePerCTATile = getShapePerCTATile(shape);
-  auto threadsPerWarp = getThreadsPerWarp();
+  // auto threadsPerWarp = getThreadsPerWarp();
+  size_t rank = parentShapePerCTATile.size();
   if (opIdx == 0) {
     auto shapeA = getShapeA();
-    return {parentShapePerCTATile[0], shapeA[1]};
+    if (rank == 2)
+      return {parentShapePerCTATile[0], shapeA[1]};
+    else
+      return {parentShapePerCTATile[0], parentShapePerCTATile[rank - 2],
+              shapeA[rank - 1]};
   } else if (opIdx == 1) {
     auto shapeB = getShapeB();
-    return {shapeB[0], parentShapePerCTATile[1]};
+    if (rank == 2)
+      return {shapeB[0], parentShapePerCTATile[1]};
+    else
+      return {parentShapePerCTATile[0], shapeB[rank - 2],
+              parentShapePerCTATile[rank - 1]};
   } else {
     llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
   }
 }
 
 SmallVector<unsigned>
 DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
+  ArrayRef<unsigned> repCluster = getRepCluster();
+  size_t rank = repCluster.size();
   if (opIdx == 0) {
     SmallVector<unsigned> shapeA = getDPASInstShapeA();
     unsigned subGroupSize = getSubGroupSize();
@@ -267,8 +342,7 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
                                "be smaller than the threads required per row.");
     }
     unsigned rowsPerWarp = mlir::ceil<unsigned>(subGroupSize, packedColNum);
-    auto repCluster = getRepCluster();
-    return {shapeA[0] / rowsPerWarp * repCluster[0], packedOpsPerLane};
+    return {shapeA[0] / rowsPerWarp * repCluster[rank - 2], packedOpsPerLane};
   } else if (opIdx == 1) {
     auto shapeB = getShapeB();
     auto subGroupSize = getSubGroupSize();
@@ -279,9 +353,8 @@ DpasEncodingAttr::getSizePerThreadForOperands(unsigned opIdx) const {
     }
     SmallVector<unsigned, 2> threadsPerWarp = {subGroupSize / executionSize,
                                                executionSize};
-    auto repCluster = getRepCluster();
-    return {shapeB[0] / threadsPerWarp[0],
-            shapeB[1] / threadsPerWarp[1] * repCluster[1]};
+    return {shapeB[rank - 2] / threadsPerWarp[0],
+            shapeB[rank - 1] / threadsPerWarp[1] * repCluster[rank - 1]};
   } else {
     llvm::report_fatal_error("DotOperandEncodingAttr opIdx must be 0 or 1");
     return {};
```
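
A standalone sketch of the arithmetic in the opIdx == 1 branch of getSizePerThreadForOperands, with hypothetical sub-group and execution sizes (the real values come from the DPAS attribute):

```cpp
#include <iostream>
#include <vector>

// Per-thread size for dot operand B, following the formula in the hunk above:
// rows are split across subGroupSize / executionSize threads, columns across
// executionSize threads and then scaled by repCluster[rank - 1].
std::vector<unsigned> sizePerThreadOperandB(const std::vector<unsigned> &shapeB,
                                            const std::vector<unsigned> &repCluster,
                                            unsigned subGroupSize,
                                            unsigned executionSize) {
  size_t rank = repCluster.size();
  unsigned threadsPerWarp[2] = {subGroupSize / executionSize, executionSize};
  return {shapeB[rank - 2] / threadsPerWarp[0],
          shapeB[rank - 1] / threadsPerWarp[1] * repCluster[rank - 1]};
}

int main() {
  // Hypothetical per-warp B tile 1x16x32 (batch dim 1), repCluster {1, 4, 2},
  // sub-group size 16, execution size 16.
  for (unsigned d : sizePerThreadOperandB({1, 16, 32}, {1, 4, 2}, 16, 16))
    std::cout << d << ' '; // prints "16 4"
  std::cout << '\n';
}
```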
```diff
@@ -293,20 +366,31 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
   SmallVector<unsigned> sizePerThread = getSizePerThreadForOperands(opIdx);
   SmallVector<int64_t> repetitions = getDPASRepetitions(shape, opIdx);
 
-  return {static_cast<unsigned>(sizePerThread[0] * repetitions[0]),
-          static_cast<unsigned>(sizePerThread[1] * repetitions[1])};
+  size_t rank = shape.size();
+  SmallVector<unsigned> elemsPerThread(rank);
+  if (rank == 3)
+    elemsPerThread[0] = repetitions[0];
+  elemsPerThread[rank - 2] = sizePerThread[rank - 2] * repetitions[1];
+  elemsPerThread[rank - 1] = sizePerThread[rank - 1] * repetitions[2];
+
+  return elemsPerThread;
 };
 
 SmallVector<unsigned> DpasEncodingAttr::getContigPerThread() {
+  size_t rank = getWarpsPerCTA().size();
+  assert(rank == 2 || rank == 3);
+  SmallVector<unsigned> contigPerThread(rank, 1);
+
   unsigned threadsPerWarp = getSubGroupSize();
   auto shapeC = getDPASInstShapeC();
   // The software vectorization vectorized the value as C array: int a[N] -> int
   // a[N][threadsPerWarp]
   if (threadsPerWarp > shapeC[1]) {
-    return {1, 1};
+    return contigPerThread;
   } else if (threadsPerWarp == shapeC[1]) {
     auto repCluster = getRepCluster();
-    return {shapeC[0] * repCluster[0], 1};
+    contigPerThread[rank - 2] = shapeC[0] * repCluster[rank - 2];
+    return contigPerThread;
   } else {
     // threadsPerWarp < shapeC[1]
     llvm::report_fatal_error("DpasEncodingAttr sub-group size could not "
```
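
And a sketch of how getElemsPerThreadForOperands now combines the per-thread operand size with the always-3D repetition counts; the sample values are made up:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Combine a rank-2/3 per-thread size (indexed with rank - 2 / rank - 1) with
// the 3-entry repetition counts {batch, M, N}, as in the hunk above.
std::vector<unsigned> elemsPerThreadForOperand(
    const std::vector<unsigned> &sizePerThread,
    const std::vector<int64_t> &repetitions, size_t rank) {
  std::vector<unsigned> res(rank, 1);
  if (rank == 3)
    res[0] = static_cast<unsigned>(repetitions[0]);
  res[rank - 2] = static_cast<unsigned>(sizePerThread[rank - 2] * repetitions[1]);
  res[rank - 1] = static_cast<unsigned>(sizePerThread[rank - 1] * repetitions[2]);
  return res;
}

int main() {
  // Hypothetical values: per-thread tile {1, 8, 2}, repetitions {2, 4, 2}.
  for (unsigned d : elemsPerThreadForOperand({1, 8, 2}, {2, 4, 2}, 3))
    std::cout << d << ' '; // prints "2 32 4"
  std::cout << '\n';
}
```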
```diff
@@ -333,8 +417,8 @@ LogicalResult DpasEncodingAttr::verify(
     return emitError() << "systolicDepth must be 8, but was:" << opsPerChan;
   }
 
-  if (repCluster.size() != 2) {
-    return emitError() << "expected rank 2 of repCluster, but the rank is:"
+  if (!(repCluster.size() == 2 || repCluster.size() == 3)) {
+    return emitError() << "expected rank 2 or 3 of repCluster, but the rank is:"
                        << repCluster.size();
   }
 
```