Commit b1e46f8

Fix 3d dot layout to llvm

1 parent 6a2c836

File tree

8 files changed: +195 -152 lines changed

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 31 additions & 25 deletions

@@ -202,39 +202,45 @@ SmallVector<unsigned> DpasEncodingAttr::getCTAsPerCGA() const {
 
 SmallVector<int64_t>
 DpasEncodingAttr::getDPASRepetitions(ArrayRef<int64_t> shape, int opIdx) const {
+  // Always return a 3D shape repetitions for the ease of value handling, same
+  // to mma.
   auto warpsPerCTA = getWarpsPerCTA();
   int rank = shape.size();
-  SmallVector<int64_t> res(rank);
+  SmallVector<int64_t> rep(3, 1);
   if (opIdx == 0) {
     auto shapePerWarp = getShapeA();
-    if (rank == 3)
-      res[0] =
-          std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]));
-    res[rank - 2] = std::max<int64_t>(
-        1, shape[rank - 2] / (shapePerWarp[rank - 2] * warpsPerCTA[rank - 2]));
-    res[rank - 1] =
-        std::max<int64_t>(1, shape[rank - 1] / shapePerWarp[rank - 1]);
+    int64_t numRepBatch =
+        rank == 3 ? std::max<int64_t>(1, shape[0] /
+                                             (shapePerWarp[0] * warpsPerCTA[0]))
+                  : 1;
+    return {numRepBatch,
+            std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
+                                                    warpsPerCTA[rank - 2])),
+            std::max<int64_t>(1, shape[rank - 1] / shapePerWarp[rank - 1])};
   } else if (opIdx == 1) {
     auto shapePerWarp = getShapeB();
-    if (rank == 3)
-      res[0] =
-          std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]));
-    res[rank - 2] =
-        std::max<int64_t>(1, shape[rank - 2] / shapePerWarp[rank - 2]);
-    res[rank - 1] = std::max<int64_t>(
-        1, shape[rank - 1] / (shapePerWarp[rank - 1] * warpsPerCTA[rank - 1]));
+    int64_t numRepBatch =
+        rank == 3 ? std::max<int64_t>(1, shape[0] /
+                                             (shapePerWarp[0] * warpsPerCTA[0]))
+                  : 1;
+    return {numRepBatch,
+            std::max<int64_t>(1, shape[rank - 2] / shapePerWarp[rank - 2]),
+            std::max<int64_t>(1, shape[rank - 1] / (shapePerWarp[rank - 1] *
+                                                    warpsPerCTA[rank - 1]))};
  } else {
     assert(opIdx == 2 && "Unexpected operand id (valid ids are 0, 1 or 2)");
     auto shapePerWarp = getShapeC();
-    if (rank == 3)
-      res[0] =
-          std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]));
-    res[rank - 2] = std::max<int64_t>(
-        1, shape[rank - 2] / (shapePerWarp[rank - 2] * warpsPerCTA[rank - 2]));
-    res[rank - 1] = std::max<int64_t>(
-        1, shape[rank - 1] / (shapePerWarp[rank - 1] * warpsPerCTA[rank - 1]));
+    int64_t numRepBatch =
+        rank == 3 ? std::max<int64_t>(1, shape[0] /
+                                             (shapePerWarp[0] * warpsPerCTA[0]))
+                  : 1;
+    return {numRepBatch,
+            std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
+                                                    warpsPerCTA[rank - 2])),
+            std::max<int64_t>(1, shape[rank - 1] / (shapePerWarp[rank - 1] *
+                                                    warpsPerCTA[rank - 1]))};
   }
-  return res;
+  return rep;
 }
 
 unsigned DpasEncodingAttr::getTotalElemsPerThreadForOperands(

@@ -364,8 +370,8 @@ SmallVector<unsigned> DpasEncodingAttr::getElemsPerThreadForOperands(
   SmallVector<unsigned> elemsPerThread(rank);
   if (rank == 3)
     elemsPerThread[0] = repetitions[0];
-  elemsPerThread[rank - 2] = sizePerThread[rank - 2] * repetitions[rank - 2];
-  elemsPerThread[rank - 1] = sizePerThread[rank - 1] * repetitions[rank - 1];
+  elemsPerThread[rank - 2] = sizePerThread[rank - 2] * repetitions[1];
+  elemsPerThread[rank - 1] = sizePerThread[rank - 1] * repetitions[2];
 
   return elemsPerThread;
 };
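Taken together, the two hunks fix the contract of getDPASRepetitions: it now always returns three repetition counts in {batch, outer, inner} order, so downstream code can index with fixed positions 0/1/2 instead of rank-relative ones. Below is a minimal standalone sketch of the operand-A path; the tile and warp sizes (shapePerWarp = {1, 8, 16}, warpsPerCTA = {2, 4, 1}) are made-up illustrative values, not numbers from this commit.

// repetitions_sketch.cpp -- illustrative only; mirrors the opIdx == 0 branch.
#include <algorithm>
#include <array>
#include <cstdint>
#include <iostream>
#include <vector>

std::array<int64_t, 3> repetitionsForA(const std::vector<int64_t> &shape,
                                       const std::vector<int64_t> &shapePerWarp,
                                       const std::vector<int64_t> &warpsPerCTA) {
  int rank = shape.size();
  // Batch repetitions only exist for rank-3 (batched) dots.
  int64_t numRepBatch =
      rank == 3
          ? std::max<int64_t>(1, shape[0] / (shapePerWarp[0] * warpsPerCTA[0]))
          : 1;
  return {numRepBatch,
          std::max<int64_t>(1, shape[rank - 2] / (shapePerWarp[rank - 2] *
                                                  warpsPerCTA[rank - 2])),
          std::max<int64_t>(1, shape[rank - 1] / shapePerWarp[rank - 1])};
}

int main() {
  // Hypothetical batched operand A of shape [batch=8, M=64, K=64].
  auto rep = repetitionsForA({8, 64, 64}, {1, 8, 16}, {2, 4, 1});
  std::cout << rep[0] << ' ' << rep[1] << ' ' << rep[2] << '\n'; // 4 2 4
}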

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp

Lines changed: 17 additions & 11 deletions

@@ -10,6 +10,7 @@
 #include "triton/Tools/StrUtil.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/MathExtras.h"

@@ -565,6 +566,7 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
       DPASLaneBasesC(repeatCount, executionSize, threadsPerWarp);
   tileLayout = LinearLayout({{kRegister, regBasesC}, {kLane, laneBasesC}},
                             ArrayRef(outDimNames).take_back(2));
+  // llvm::to_vector(llvm::reverse(ArrayRef(outDimNames).take_back(2))));
   // std::cout << (tileLayout.toString()) << std::endl;
   // The per-inst layout is repeated at each repCluster.
   // Hence, multiply with the identity layouts starting from the

@@ -575,30 +577,34 @@
                                            outDimNames[KDim]);
     tileLayout *= LinearLayout::identity1D(repCluster[nonKDim], kRegister,
                                            outDimNames[nonKDim]);
-    // std::cout << (tileLayout.toString()) << std::endl;
+    std::cout << (tileLayout.toString()) << std::endl;
 
     // // The identical layout is repeated among warps
     tileLayout *=
         LinearLayout::identity1D(warpsPerCTA[KDim], kWarp, outDimNames[KDim]);
     tileLayout *= LinearLayout::identity1D(warpsPerCTA[nonKDim], kWarp,
                                            outDimNames[nonKDim]);
-    // tileLayout *=
-    //     LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
+    if (rank == 3)
+      tileLayout *=
+          LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
+    auto order =
+        llvm::to_vector(llvm::reverse(triton::gpu::getWarpOrder(layout)));
+    std::cout << "order: " << order[1] << ", " << order[0] << std::endl;
     // tileLayout *= identityND(kWarp, warpsPerCTA,
-    //     llvm::to_vector(llvm::reverse(triton::gpu::getWarpOrder(layout))),
+    //     llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank))),
     //     outDimNames);
-    // std::cout << (tileLayout.toString()) << std::endl;
+    std::cout << (tileLayout.toString()) << std::endl;
   }
 
   // Lastly, the layout repeats to match the shape.
   // Operand A/B repeats through the K-dimension first then repeats
   // through the non-K dimension.
-  SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
-  tileLayout *=
-      LinearLayout::identity1D(numReps[KDim], kRegister, outDimNames[KDim]);
-  tileLayout *= LinearLayout::identity1D(numReps[nonKDim], kRegister,
-                                         outDimNames[nonKDim]);
-  // std::cout << (tileLayout.toString()) << std::endl;
+  // SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
+  // tileLayout *=
+  //     LinearLayout::identity1D(numReps[KDim], kRegister, outDimNames[KDim]);
+  // tileLayout *= LinearLayout::identity1D(numReps[nonKDim], kRegister,
+  //                                        outDimNames[nonKDim]);
+  // // std::cout << (tileLayout.toString()) << std::endl;
 
   return combineCtaCgaWithShape(std::move(tileLayout),
                                 CTALayoutAttr::getDefault(ctx, rank), shape);
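The rank-3 branch above leans on the multiplicative structure of linear layouts: each LinearLayout::identity1D(n, inDim, outDim) factor tiles the existing layout n times along outDim, so the batch dimension only needs one extra factor for warpsPerCTA[0]. The following is a toy model of just the size bookkeeping, not the real LinearLayout class (which tracks power-of-two basis vectors); all names and sizes are invented for illustration.

// layout_tiling_sketch.cpp -- toy model of identity1D composition.
#include <cstdint>
#include <iostream>
#include <map>
#include <string>

struct ToySizes {
  std::map<std::string, int64_t> in;  // register / lane / warp sizes
  std::map<std::string, int64_t> out; // dim0 / dim1 / dim2 sizes

  // Models tileLayout *= identity1D(n, inDim, outDim): the layout is
  // repeated n times along outDim by growing inDim by a factor of n.
  void timesIdentity1D(int64_t n, const std::string &inDim,
                       const std::string &outDim) {
    in.try_emplace(inDim, 1).first->second *= n;
    out.try_emplace(outDim, 1).first->second *= n;
  }
};

int main() {
  ToySizes c; // hypothetical DPAS C tile: 8x16 elements per warp
  c.timesIdentity1D(8, "register", "dim1");
  c.timesIdentity1D(16, "lane", "dim2");
  // Warp tiling over the matrix dims, then -- the change in this hunk --
  // over the leading batch dim when rank == 3.
  c.timesIdentity1D(4, "warp", "dim1");
  c.timesIdentity1D(1, "warp", "dim2");
  c.timesIdentity1D(2, "warp", "dim0");
  std::cout << "warps: " << c.in["warp"] << '\n'; // warps: 8 (4 * 1 * 2)
}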

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 34 additions & 27 deletions

@@ -313,7 +313,8 @@ struct ConvertLayoutOpConversion
     return success();
   }
 
-  using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
+  // using ValueTable = std::map<std::pair<unsigned, unsigned>, Value>;
+  using ValueTable = std::map<std::array<unsigned, 3>, Value>;
 
   ValueTable getValuesFromDpasLayoutStruct(Location loc,
                                            ConversionPatternRewriter &rewriter,
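Switching the ValueTable key from std::pair to std::array<unsigned, 3> needs no custom comparator: std::array has lexicographic relational operators, so the map also iterates batch-major, matching the consuming loops below. A standalone sketch (the string values merely stand in for LLVM values):

#include <array>
#include <iostream>
#include <map>
#include <string>

int main() {
  // std::array compares lexicographically, so {b, outer, inner} keys
  // order batch-major, then outer, then inner.
  std::map<std::array<unsigned, 3>, std::string> vals;
  vals[{1, 0, 0}] = "batch1";
  vals[{0, 1, 0}] = "outer1";
  vals[{0, 0, 1}] = "inner1";
  for (const auto &[key, v] : vals) // prints inner1, outer1, batch1
    std::cout << key[0] << key[1] << key[2] << ": " << v << '\n';
}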
@@ -338,17 +339,20 @@ struct ConvertLayoutOpConversion
 
     int offset = 0;
     ValueTable result;
-    for (int i = 0; i < repetitions[outerDim]; ++i) {
-      for (int j = 0; j < repetitions[innerDim]; ++j) {
-        for (int repOuter = 0; repOuter < repCluster[outerDim]; ++repOuter) {
-          for (int repInner = 0; repInner < repCluster[innerDim]; ++repInner) {
-            Value matVal = rewriter.create<LLVM::UndefOp>(loc, dotOpTy);
-            for (int k = 0; k < numElemsPerOperand; ++k) {
-              matVal =
-                  insert_element(dotOpTy, matVal, elems[offset++], i32_val(k));
+    for (unsigned b = 0; b < repetitions[0]; ++b) {
+      for (int i = 0; i < repetitions[1]; ++i) {
+        for (int j = 0; j < repetitions[2]; ++j) {
+          for (int repOuter = 0; repOuter < repCluster[outerDim]; ++repOuter) {
+            for (int repInner = 0; repInner < repCluster[innerDim];
+                 ++repInner) {
+              Value matVal = rewriter.create<LLVM::UndefOp>(loc, dotOpTy);
+              for (int k = 0; k < numElemsPerOperand; ++k) {
+                matVal = insert_element(dotOpTy, matVal, elems[offset++],
+                                        i32_val(k));
+              }
+              result[{b, i * repCluster[outerDim] + repOuter,
+                      j * repCluster[innerDim] + repInner}] = matVal;
             }
-            result[{i * repCluster[outerDim] + repOuter,
-                    j * repCluster[innerDim] + repInner}] = matVal;
           }
         }
       }

@@ -367,35 +371,38 @@
         dpasLayout.getDPASRepetitions(dstType.getShape(), opIdx);
     ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
     size_t rank = repCluster.size();
+    unsigned repBatch = repetitions[0];
     unsigned repOuter = 0u;
     unsigned repInner = 0u;
     unsigned repClusterOuter = 0u;
     if (opIdx == 0) {
       // operand A
-      repOuter = repetitions[rank - 2];
-      repInner = repetitions[rank - 1];
+      repOuter = repetitions[1];
+      repInner = repetitions[2];
       repClusterOuter = repCluster[rank - 2];
     } else {
       // operand B
-      repOuter = repetitions[rank - 1];
-      repInner = repetitions[rank - 2];
+      repOuter = repetitions[2];
+      repInner = repetitions[1];
       repClusterOuter = repCluster[rank - 1];
     }
 
     // TODO: Operands B requires extra steps to combine [8, 16] to [16, 16].
     SmallVector<Value> elems;
-    for (int m = 0; m < repOuter; ++m) {
-      for (int k = 0; k < repInner; ++k) {
-        for (int repOuterIdx = 0; repOuterIdx < repClusterOuter;
-             ++repOuterIdx) {
-          unsigned offsetM = m * repClusterOuter + repOuterIdx;
-          unsigned offsetN = k;
-          Value matVal = vals.at({offsetM, offsetN});
-          VectorType vecType = cast<mlir::VectorType>(matVal.getType());
-          Type valTy = vecType.getElementType();
-          for (int i = 0; i < vecType.getNumElements(); ++i) {
-            Value val = extract_element(valTy, matVal, i32_val(i));
-            elems.push_back(val);
+    for (unsigned b = 0; b < repBatch; ++b) {
+      for (int m = 0; m < repOuter; ++m) {
+        for (int k = 0; k < repInner; ++k) {
+          for (int repOuterIdx = 0; repOuterIdx < repClusterOuter;
+               ++repOuterIdx) {
+            unsigned offsetM = m * repClusterOuter + repOuterIdx;
+            unsigned offsetN = k;
+            Value matVal = vals.at({b, offsetM, offsetN});
+            VectorType vecType = cast<mlir::VectorType>(matVal.getType());
+            Type valTy = vecType.getElementType();
+            for (int i = 0; i < vecType.getNumElements(); ++i) {
+              Value val = extract_element(valTy, matVal, i32_val(i));
+              elems.push_back(val);
+            }
           }
         }
       }
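The loops above define the packed register order: batch-major, then repetition indices, then rep-cluster indices, with each map key encoding {batch, composed outer, composed inner}. This sketch enumerates the flat offsets for invented counts (repetitions = {2, 2, 2}, repCluster = {2, 2}, one scalar per operand) to show the correspondence:

#include <array>
#include <iostream>
#include <map>

int main() {
  // Hypothetical counts, chosen only to make the arithmetic visible.
  const unsigned repetitions[3] = {2, 2, 2}; // {batch, outer, inner}
  const unsigned repCluster[2] = {2, 2};
  std::map<std::array<unsigned, 3>, unsigned> result;
  unsigned offset = 0;
  for (unsigned b = 0; b < repetitions[0]; ++b)
    for (unsigned i = 0; i < repetitions[1]; ++i)
      for (unsigned j = 0; j < repetitions[2]; ++j)
        for (unsigned ro = 0; ro < repCluster[0]; ++ro)
          for (unsigned ri = 0; ri < repCluster[1]; ++ri)
            result[{b, i * repCluster[0] + ro, j * repCluster[1] + ri}] =
                offset++;
  // Each batch covers 2*2*2*2 = 16 packed values, so {1, 0, 0} -> 16.
  std::cout << result.at({1, 0, 0}) << '\n'; // 16
}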

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandDPAS.cpp

Lines changed: 42 additions & 32 deletions

@@ -3,7 +3,7 @@
 #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
 #include "llvm/Support/ErrorHandling.h"
 
-using ValueTable = std::map<std::pair<int, int>, Value>;
+using ValueTable = std::map<std::array<int, 3>, Value>;
 using mlir::triton::gpu::getShapePerCTA;
 using mlir::triton::gpu::SharedEncodingAttr;
 using mlir::triton::gpu::intel::DpasEncodingAttr;

@@ -44,9 +44,9 @@ template <unsigned opIdx> class DpasMatmulLoader {
   SmallVector<Value> computeLdsMatOffs(Value warpOff, Value lane,
                                        Value cSwizzleOffset);
   // Load the matrix value.
-  Value loadMatrix(int repOuter, int repInner, const ArrayRef<Value> ptrs,
-                   LLVM::LLVMStructType structTy, Type smemTy,
-                   Value cSwizzleOffset) const;
+  Value loadMatrix(int repBatch, int repOuter, int repInner,
+                   const ArrayRef<Value> ptrs, LLVM::LLVMStructType structTy,
+                   Type smemTy, Value cSwizzleOffset) const;
 
 private:
   unsigned getThreadsPerWarp() const {

@@ -57,6 +57,7 @@ template <unsigned opIdx> class DpasMatmulLoader {
   MemDescType descTy;
 
   SmallVector<Value> smemStrides;
+  Value repBatchDimStride;
   Value repNonKDimStride;
   Value repKDimStride;
 

@@ -176,19 +177,19 @@ DpasMatmulLoader<opIdx>::computeLdsMatOffs(Value warpId, Value laneId,
 }
 
 template <unsigned opIdx>
-Value DpasMatmulLoader<opIdx>::loadMatrix(int repOuter, int repInner,
-                                          const ArrayRef<Value> ptrs,
-                                          LLVM::LLVMStructType structTy,
-                                          Type smemTy,
-                                          Value cSwizzleOffset) const {
+Value DpasMatmulLoader<opIdx>::loadMatrix(
+    int repBatch, int repOuter, int repInner, const ArrayRef<Value> ptrs,
+    LLVM::LLVMStructType structTy, Type smemTy, Value cSwizzleOffset) const {
   Type elemTy = structTy.getBody()[0];
   assert(
      llvm::any_of(structTy.getBody(), [&](Type ty) { return ty == elemTy; }) &&
      "The struct should have the same element types.");
 
+  Value offsetBatch = mul(i32_val(repBatch), repBatchDimStride);
   Value offsetOuter = mul(i32_val(repOuter), repNonKDimStride);
   Value offsetInner = mul(i32_val(repInner), repKDimStride);
   Value offset = add(offsetOuter, offsetInner);
+  offset = add(offset, offsetBatch);
 
   Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
   size_t elemNum = structTy.getBody().size();
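The new batch term in loadMatrix is ordinary strided addressing, one stride per repetition dimension; with invented element-count strides, the arithmetic it emits as LLVM IR reduces to:

#include <cstdint>
#include <iostream>

// Sketch of the offset loadMatrix now computes: batch term added on top
// of the existing outer/inner terms. Stride values are hypothetical.
int32_t matrixOffset(int32_t repBatch, int32_t repOuter, int32_t repInner,
                     int32_t batchStride, int32_t nonKStride,
                     int32_t kStride) {
  int32_t offsetBatch = repBatch * batchStride;
  int32_t offsetOuter = repOuter * nonKStride;
  int32_t offsetInner = repInner * kStride;
  return offsetOuter + offsetInner + offsetBatch;
}

int main() {
  // e.g. a 64x32 batch slice, 8-row outer tiles, 16-element inner tiles.
  std::cout << matrixOffset(1, 2, 3, 64 * 32, 8 * 32, 16) << '\n'; // 2608
}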
@@ -203,18 +204,20 @@ Value DpasMatmulLoader<opIdx>::loadMatrix(int repOuter, int repInner,
 }
 
 Value composeValuesToDotOperandLayoutStruct(
-    const ValueTable &vals, int n0, int n1,
+    const ValueTable &vals, int batch, int n0, int n1,
     const LLVMTypeConverter *typeConverter, Location loc,
     ConversionPatternRewriter &rewriter) {
   std::vector<Value> elems;
-  for (int m = 0; m < n0; ++m) {
-    for (int k = 0; k < n1; ++k) {
-      Value matVal = vals.at({m, k});
-      auto matType = cast<LLVM::LLVMStructType>(matVal.getType());
-      Type valTy = matType.getBody()[0];
-      for (int i = 0; i < matType.getBody().size(); ++i) {
-        auto val = extract_val(valTy, matVal, i);
-        elems.push_back(val);
+  for (int b = 0; b < batch; ++b) {
+    for (int m = 0; m < n0; ++m) {
+      for (int k = 0; k < n1; ++k) {
+        Value matVal = vals.at({b, m, k});
+        auto matType = cast<LLVM::LLVMStructType>(matVal.getType());
+        Type valTy = matType.getBody()[0];
+        for (int i = 0; i < matType.getBody().size(); ++i) {
+          auto val = extract_val(valTy, matVal, i);
+          elems.push_back(val);
+        }
       }
     }
   }

@@ -245,7 +248,7 @@ Type getSharedMemTy(Type argType) {
 }
 
 template <unsigned opIdx>
-std::function<void(int, int)>
+std::function<void(int, int, int)>
 getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
                 DpasEncodingAttr dpasLayout, unsigned warpsPerTile,
                 SmallVector<unsigned> instrShape, Value warpId,

@@ -261,7 +264,8 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
   ArrayRef<unsigned> order = sharedLayout.getOrder();
 
   // (a, b) is the coordinate.
-  auto load = [=, &rewriter, &smemObj, &instrShape, &vals](int a, int b) {
+  auto load = [=, &rewriter, &smemObj, &instrShape, &vals](int batch, int outer,
+                                                           int inner) {
     DpasMatmulLoader<opIdx> loader(dpasLayout, descTy, warpsPerTile,
                                    smemObj.strides, instrShape, rewriter,
                                    typeConverter, loc);

@@ -289,7 +293,8 @@ getLoadMatrixFn(MemDescType descTy, const SharedMemoryObject &smemObj,
                      SmallVector<Type>(totalElem / threadsPerWarp,
                                        typeConverter->convertType(eltTy)));
 
-    vals[{a, b}] = loader.loadMatrix(a, b, ptrs, matTy, smemTy, cSwizzleOffset);
+    vals[{batch, outer, inner}] = loader.loadMatrix(
+        batch, outer, inner, ptrs, matTy, smemTy, cSwizzleOffset);
   };
 
   return load;

@@ -325,27 +330,32 @@ Value loadOperand(ConversionPatternRewriter &rewriter, Location loc,
       LLVM::delinearize(rewriter, loc, warpId, warpsPerCTA, order);
 
   // FIXME: Using opIdx as the dimIdx will be incorrect in 3D case.
-  unsigned ceilRes = mlir::ceil<unsigned>(shapePerCTA[opIdx], shape[opIdx]);
-  Value outerWarpDim = urem(multiDimWarpId[opIdx], i32_val(ceilRes));
-  unsigned warpsPerTile = std::min<unsigned>(warpsPerCTA[opIdx], ceilRes);
+  unsigned rank = shape.size();
+  unsigned dimOuter = opIdx ? (rank - 1) : (rank - 2);
+  unsigned ceilRes =
+      mlir::ceil<unsigned>(shapePerCTA[dimOuter], shape[dimOuter]);
+  Value outerWarpDim = urem(multiDimWarpId[dimOuter], i32_val(ceilRes));
+  unsigned warpsPerTile = std::min<unsigned>(warpsPerCTA[dimOuter], ceilRes);
 
   // Get the function to use to load the operand.
   ValueTable vals;
-  std::function<void(int, int)> loadFn = getLoadMatrixFn<opIdx>(
+  std::function<void(int, int, int)> loadFn = getLoadMatrixFn<opIdx>(
      descTy, smemObj, dpasLayout, warpsPerTile, std::move(shape), warpId,
      outerWarpDim, laneId, vals, typeConverter, rewriter, loc);
 
   // Load the operand.
-  int64_t numRepOuter = numReps[opIdx];
-  int64_t numRepK = numReps[(opIdx == 0) ? 1 : 0];
+  int64_t numRepBatch = numReps[0];
+  int64_t numRepOuter = numReps[opIdx ? 2 : 1];
+  int64_t numRepK = numReps[opIdx ? 1 : 2];
 
-  for (int m = 0; m < numRepOuter; ++m)
-    for (int k = 0; k < numRepK; ++k)
-      loadFn(m, k);
+  for (int b = 0; b < numRepBatch; ++b)
+    for (int m = 0; m < numRepOuter; ++m)
+      for (int k = 0; k < numRepK; ++k)
+        loadFn(b, m, k);
 
   // Format the values into an LLVM::Struct.
-  return composeValuesToDotOperandLayoutStruct(vals, numRepOuter, numRepK,
-                                               typeConverter, loc, rewriter);
+  return composeValuesToDotOperandLayoutStruct(
+      vals, numRepBatch, numRepOuter, numRepK, typeConverter, loc, rewriter);
 }
 
 } // namespace
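Because getDPASRepetitions now always returns {batch, M-reps, K-reps} for operand A and {batch, K-reps, N-reps} for operand B, the loop-bound selection in loadOperand reduces to a fixed index swap. A sketch with invented repetition counts:

#include <cstdint>
#include <iostream>

// How loadOperand picks its loop bounds from the 3D repetitions after
// this commit: operand A (opIdx 0) takes outer from index 1 and K from
// index 2; operand B (opIdx 1) takes outer from index 2 and K from index 1.
void printLoopBounds(unsigned opIdx, const int64_t numReps[3]) {
  int64_t numRepBatch = numReps[0];
  int64_t numRepOuter = numReps[opIdx ? 2 : 1];
  int64_t numRepK = numReps[opIdx ? 1 : 2];
  std::cout << "opIdx " << opIdx << ": batch=" << numRepBatch
            << " outer=" << numRepOuter << " K=" << numRepK << '\n';
}

int main() {
  const int64_t repsA[3] = {4, 2, 8}; // hypothetical A reps: {batch, M, K}
  const int64_t repsB[3] = {4, 8, 2}; // hypothetical B reps: {batch, K, N}
  printLoopBounds(0, repsA); // opIdx 0: batch=4 outer=2 K=8
  printLoopBounds(1, repsB); // opIdx 1: batch=4 outer=2 K=8
}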
