Skip to content

Commit 2dd1d03

Browse files
committed
Fix LinearLayout and ConvertLayout to LLVM
1 parent 5b616f4 commit 2dd1d03

File tree

5 files changed

+84
-48
lines changed

5 files changed

+84
-48
lines changed

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ LinearLayout ensureLayoutNotSmallerThan(
285285
assert(kDim == "register" || kDim == "offset" && "unexpected kDim");
286286

287287
LinearLayout ret = layout;
288-
for (StringAttr outDimName : llvm::reverse(layout.getOutDimNames())) {
288+
for (StringAttr outDimName : layout.getOutDimNames()) {
289289
int32_t actualSize = layout.getOutDimSize(outDimName);
290290
int32_t desiredSize = shape.lookup(outDimName);
291291
assert(actualSize > desiredSize ||
@@ -548,7 +548,7 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
548548
auto laneBasesA =
549549
DPASLaneBasesA(opsPerChannel, threadsPerWarp, systolicDepth);
550550
tileLayout = LinearLayout({{kRegister, regBasesA}, {kLane, laneBasesA}},
551-
outDimNames);
551+
ArrayRef(outDimNames).take_back(2));
552552
// A only repeats by repCluster[rank - 2]
553553
dimNonK = rank - 2;
554554
dimK = rank - 1;
@@ -622,22 +622,33 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
622622
if (rank == 3)
623623
tileLayout *=
624624
LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
625-
// std::cout << (tileLayout.toString()) << std::endl;
625+
std::cout << (tileLayout.toString()) << std::endl;
626626
}
627627

628628
// Lastly, the layout repeats to match the shape.
629629
// Operand A/B repeats through the K-dimension first then repeats
630630
// through the non-K dimension.
631-
// SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
632-
// std::cout << "numReps: " << numReps[0] << ", " << numReps[1] << std::endl;
633-
// tileLayout *=
634-
// LinearLayout::identity1D(numReps[dimK], kRegister, outDimNames[dimK]);
635-
// tileLayout *= LinearLayout::identity1D(numReps[dimNonK], kRegister,
636-
// outDimNames[dimNonK]);
637-
// if (rank == 3)
638-
// tileLayout *=
639-
// LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);
640-
// std::cout << (tileLayout.toString()) << std::endl;
631+
SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
632+
633+
std::cout << "numReps: ";
634+
for (auto numRep : numReps) {
635+
std::cout << numRep << ", ";
636+
}
637+
std::cout << std::endl;
638+
639+
// numReps is always 3D, we should add 1 to dim id when rank is 2
640+
int repDimK = rank == 2 ? dimK + 1 : dimK;
641+
int repDimNonK = rank == 2 ? dimNonK + 1 : dimNonK;
642+
tileLayout *=
643+
LinearLayout::identity1D(numReps[repDimK], kRegister, outDimNames[dimK]);
644+
tileLayout *= LinearLayout::identity1D(numReps[repDimNonK], kRegister,
645+
outDimNames[dimNonK]);
646+
std::cout << "rank: " << rank << std::endl;
647+
if (rank == 3)
648+
tileLayout *=
649+
LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);
650+
std::cout << "\ntileLayout with DPASRepetition: " << (tileLayout.toString())
651+
<< std::endl;
641652

642653
return combineCtaCgaWithShape(std::move(tileLayout),
643654
CTALayoutAttr::getDefault(ctx, rank), shape);

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include "PatternTritonGPUOpToLLVM.h"
22
#include "TargetInfo.h"
33
#include "Utility.h"
4+
#include <iostream>
45

56
#include "intel/include/Analysis/Utility.h"
67
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
@@ -110,17 +111,22 @@ struct ConvertLayoutOpConversion
110111
return multiDimOffset;
111112
}
112113
if (auto dpasLayout = dyn_cast<DpasEncodingAttr>(layout)) {
113-
assert(rank == 2);
114+
assert(rank == 2 || rank == 3);
115+
std::cout << "!!!getMultiDimOffset: dpasLayout" << std::endl;
114116
auto multiDimBase = ::intel::emitBaseIndexForLayout(
115117
loc, rewriter, targetInfo, layout, type, false);
116118
SmallVector<SmallVector<unsigned>> offsets;
117119
::emitOffsetForDpasLayoutPerCTA(
118120
dpasLayout, offsets, multiDimCTAInRepId[0] * shapePerCTATile[0],
119121
multiDimCTAInRepId[1] * shapePerCTATile[1]);
120122

121-
SmallVector<Value> multiDimOffset = {
122-
add(multiDimBase[0], i32_val(offsets[elemId][0])),
123-
add(multiDimBase[1], i32_val(offsets[elemId][1]))};
123+
SmallVector<Value> multiDimOffset(rank);
124+
if (rank == 3)
125+
multiDimOffset[0] = multiDimBase[0];
126+
multiDimOffset[rank - 2] =
127+
add(multiDimBase[rank - 2], i32_val(offsets[elemId][rank - 2]));
128+
multiDimOffset[rank - 1] =
129+
add(multiDimBase[rank - 1], i32_val(offsets[elemId][rank - 1]));
124130

125131
return multiDimOffset;
126132
}

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ class DotOpDPASConversionHelper {
168168
"A and B precision enumerators do not match");
169169

170170
LLVM_DEBUG({
171-
llvm::dbgs() << "repB = " << repBatch << "\n";
171+
llvm::dbgs() << "repBatch = " << repBatch << "\n";
172172
llvm::dbgs() << "repM = " << repM << "\n";
173173
llvm::dbgs() << "repK = " << repK << "\n";
174174
llvm::dbgs() << "repN = " << repN << "\n";

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -609,11 +609,10 @@ struct LoadOpConversion
609609

610610
unsigned numOperandsPer2DLoadM, numOperandsPer2DloadN;
611611
if (!isTransposeRequired) {
612-
int dimRep = opIdx ? 1 : 2;
613612
numOperandsPer2DLoadM =
614-
isOperandA ? repCluster[dimOuter] : numReps[dimRep];
613+
isOperandA ? repCluster[dimOuter] : numReps[opIdx ? 1 : 2];
615614
numOperandsPer2DloadN =
616-
isOperandA ? numReps[dimRep] : repCluster[dimOuter];
615+
isOperandA ? numReps[opIdx ? 1 : 2] : repCluster[dimOuter];
617616
} else {
618617
if (isOperandA)
619618
return failure();
@@ -671,8 +670,8 @@ struct LoadOpConversion
671670
unsigned warpOuterStride = warpShape[dimOuter];
672671
unsigned repKStride = elemsPerDPASInst[dimInner];
673672

674-
unsigned numRepOuter = numReps[dimOuter];
675-
unsigned numRepInner = numReps[dimInner];
673+
unsigned numRepOuter = numReps[opIdx ? 2 : 1];
674+
unsigned numRepInner = numReps[opIdx ? 1 : 2];
676675

677676
Value pitch;
678677
if (memoryRowMajor) {

third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,9 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
164164
SmallVector<unsigned> sizePerThreads = getSizePerThread(dpasLayout);
165165
ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
166166
size_t rank = repCluster.size();
167-
SmallVector<unsigned> sizePerDPASInst = {sizePerThreads[0] / repCluster[0],
168-
sizePerThreads[1] / repCluster[1]};
167+
SmallVector<unsigned> sizePerDPASInst = {
168+
sizePerThreads[rank - 2] / repCluster[rank - 2],
169+
sizePerThreads[rank - 1] / repCluster[rank - 1]};
169170

170171
unsigned rowsPerElem = dpasLayout.getSubGroupSize() / instShapeC[1];
171172
unsigned colsPerElem = 1;
@@ -176,15 +177,19 @@ emitOffsetForDpasLayoutPerCTA(const DpasEncodingAttr &dpasLayout,
176177
for (unsigned elemId = 0; elemId < elemNumberPerRep; ++elemId) {
177178
// Follows the C++ order for the dpas layout.
178179
SmallVector<unsigned> repOffset = {
179-
(repId / repCluster[1]) * instShapeC[0],
180-
(repId % repCluster[1]) * instShapeC[1]};
180+
(repId / repCluster[rank - 1]) * instShapeC[0],
181+
(repId % repCluster[rank - 1]) * instShapeC[1]};
181182

182183
SmallVector<unsigned> elemOffset = {
183184
(elemId / sizePerDPASInst[1]) * rowsPerElem,
184185
(elemId % sizePerDPASInst[1]) * colsPerElem};
185186

186-
offsets.push_back({repOffset[0] + elemOffset[0] + ctaOffsetX,
187-
repOffset[1] + elemOffset[1] + ctaOffsetY});
187+
if (rank == 3)
188+
offsets.push_back({0, repOffset[0] + elemOffset[0] + ctaOffsetX,
189+
repOffset[1] + elemOffset[1] + ctaOffsetY});
190+
else
191+
offsets.push_back({repOffset[0] + elemOffset[0] + ctaOffsetX,
192+
repOffset[1] + elemOffset[1] + ctaOffsetY});
188193
}
189194
}
190195
}
@@ -289,9 +294,10 @@ emitOffsetForDpasLayout(const DpasEncodingAttr &dpasLayout,
289294
ArrayRef<int64_t> shape = type.getShape();
290295
SmallVector<SmallVector<unsigned>> offsets;
291296
SmallVector<unsigned> shapePerCTA = getShapePerCTATile(dpasLayout);
297+
size_t rank = shape.size();
292298

293-
for (unsigned i = 0; i < shape[0]; i += shapePerCTA[0]) {
294-
for (unsigned j = 0; j < shape[1]; j += shapePerCTA[1]) {
299+
for (unsigned i = 0; i < shape[rank - 2]; i += shapePerCTA[rank - 2]) {
300+
for (unsigned j = 0; j < shape[rank - 1]; j += shapePerCTA[rank - 1]) {
295301
emitOffsetForDpasLayoutPerCTA(dpasLayout, offsets, i, j);
296302
}
297303
}
@@ -333,13 +339,14 @@ emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter,
333339
size_t rank = warpShape.size();
334340
assert(rank == shapePerCTA.size() && "Rank mismatch");
335341
Value warpIndex =
336-
(opIdx == 0) ? urem(multiDimWarpId[0],
342+
(opIdx == 0) ? urem(multiDimWarpId[rank - 2],
337343
i32_val(mlir::ceil<unsigned>(shapePerCTA[rank - 2],
338344
warpShape[rank - 2])))
339-
: urem(multiDimWarpId[1],
345+
: urem(multiDimWarpId[rank - 1],
340346
i32_val(mlir::ceil<unsigned>(shapePerCTA[rank - 1],
341347
warpShape[rank - 1])));
342-
Value warpOffset = mul(warpIndex, i32_val(warpShape[opIdx]));
348+
Value warpOffset =
349+
mul(warpIndex, i32_val(warpShape[opIdx ? rank - 1 : rank - 2]));
343350

344351
// Compute the 2-dim coordinates of the first element in the warp operated
345352
// own by this thread.
@@ -355,7 +362,7 @@ emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter,
355362
// Unlike the operand B, to pack the value to i16 for scalar bit width
356363
// <=16.
357364
unsigned packedOpsPerLane = opsPerChannel == 4 ? 2 : 1;
358-
unsigned packedColNum = shapeA[1] / packedOpsPerLane;
365+
unsigned packedColNum = shapeA[rank - 1] / packedOpsPerLane;
359366
if (warpSize < packedColNum)
360367
llvm::report_fatal_error(
361368
"DpasEncodingAttr sub-group size could not "
@@ -375,12 +382,18 @@ emitBaseIndexForDotOpLayout(Location loc, RewriterBase &rewriter,
375382
laneRowIndex = mul(laneRowIndex, i32_val(opsPerChannel));
376383
laneColIndex = urem(laneId, i32_val(executionSize));
377384
} break;
385+
default: {
386+
llvm::report_fatal_error("Only support opIdx 1 or 0 for DotOpLayout.");
387+
}
378388
}
379389

380-
auto multiDimBase =
381-
(opIdx == 0)
382-
? SmallVector<Value>{add(laneRowIndex, warpOffset), laneColIndex}
383-
: SmallVector<Value>{laneRowIndex, add(laneColIndex, warpOffset)};
390+
SmallVector<Value> multiDimBase(rank);
391+
if (rank == 3)
392+
multiDimBase[0] = multiDimWarpId[0];
393+
multiDimBase[rank - 2] =
394+
(opIdx == 0) ? add(laneRowIndex, warpOffset) : laneRowIndex;
395+
multiDimBase[rank - 1] =
396+
(opIdx == 0) ? laneColIndex : add(laneColIndex, warpOffset);
384397

385398
return multiDimBase;
386399
}
@@ -394,6 +407,7 @@ emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter,
394407
Value warpId = udiv(threadId, warpSize);
395408
Value laneId = urem(threadId, warpSize);
396409

410+
unsigned rank = type.getShape().size();
397411
auto warpsPerCTA = dpasLayout.getWarpsPerCTA();
398412
ArrayRef<int64_t> shape = type.getShape();
399413

@@ -404,19 +418,25 @@ emitBaseIndexForDpasLayout(Location loc, RewriterBase &rewriter,
404418
// Compute the 2-dim coordinates of the warp containing the tensor element
405419
// operated on by this thread.
406420
SmallVector<unsigned> warpShape = dpasLayout.getShapeC();
407-
Value rowWarpId = urem(multiDimWarpId[0],
408-
i32_val(mlir::ceil<unsigned>(shape[0], warpShape[0])));
409-
Value colWarpId = urem(multiDimWarpId[1],
410-
i32_val(mlir::ceil<unsigned>(shape[1], warpShape[1])));
411-
Value rowWarpOffset = mul(rowWarpId, i32_val(warpShape[0]));
412-
Value colWarpOffset = mul(colWarpId, i32_val(warpShape[1]));
421+
Value rowWarpId =
422+
urem(multiDimWarpId[rank - 2],
423+
i32_val(mlir::ceil<unsigned>(shape[rank - 2], warpShape[rank - 2])));
424+
Value colWarpId =
425+
urem(multiDimWarpId[rank - 1],
426+
i32_val(mlir::ceil<unsigned>(shape[rank - 1], warpShape[rank - 1])));
427+
Value rowWarpOffset = mul(rowWarpId, i32_val(warpShape[rank - 2]));
428+
Value colWarpOffset = mul(colWarpId, i32_val(warpShape[rank - 1]));
413429

414430
// Compute the 2-dim coordinates of the first element in the warp operated
415431
// on by this thread.
416432
SmallVector<unsigned> threadsPerWarp = getThreadsPerWarp(dpasLayout);
417-
SmallVector<Value> multiDimBase = {
418-
add(udiv(laneId, i32_val(threadsPerWarp[1])), rowWarpOffset),
419-
add(urem(laneId, i32_val(threadsPerWarp[1])), colWarpOffset)};
433+
SmallVector<Value> multiDimBase(rank);
434+
if (rank == 3)
435+
multiDimBase[0] = multiDimWarpId[0];
436+
multiDimBase[rank - 2] =
437+
add(udiv(laneId, i32_val(threadsPerWarp[rank - 1])), rowWarpOffset);
438+
multiDimBase[rank - 1] =
439+
add(urem(laneId, i32_val(threadsPerWarp[rank - 1])), colWarpOffset);
420440
return multiDimBase;
421441
}
422442

0 commit comments

Comments
 (0)