Skip to content

Commit 60cc68a

Browse files
committed
Fix LinearLayout enlarge
1 parent e800858 commit 60cc68a

File tree

3 files changed

+89
-62
lines changed

3 files changed

+89
-62
lines changed

third_party/intel/lib/Dialect/TritonIntelGPU/IR/LinearLayoutConversions.cpp

Lines changed: 82 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -280,17 +280,26 @@ LinearLayout ensureLayoutNotSmallerThan(
280280
return layout;
281281
}
282282

283-
MLIRContext *ctx = shape.begin()->first.getContext();
283+
// MLIRContext *ctx = shape.begin()->first.getContext();
284284
StringAttr kDim = *layout.getInDimNames().begin();
285285
assert(kDim == "register" || kDim == "offset" && "unexpected kDim");
286286

287287
LinearLayout ret = layout;
288-
for (StringAttr outDimName : layout.getOutDimNames()) {
288+
for (StringAttr outDimName : llvm::reverse(layout.getOutDimNames())) {
289289
int32_t actualSize = layout.getOutDimSize(outDimName);
290290
int32_t desiredSize = shape.lookup(outDimName);
291291
assert(actualSize > desiredSize ||
292292
desiredSize % actualSize == 0 && "bad shape");
293293
ret *= LinearLayout::identity1D(desiredSize / actualSize, kDim, outDimName);
294+
std::cout << "actualSize: " << actualSize << " desiredSize: " << desiredSize
295+
<< std::endl;
296+
std::cout << "outDimName: " << outDimName.str() << std::endl;
297+
std::cout << "identity1D: "
298+
<< LinearLayout::identity1D(desiredSize / actualSize, kDim,
299+
outDimName)
300+
.toString()
301+
<< std::endl;
302+
std::cout << "ret: " << ret.toString() << std::endl;
294303
assert(ret.getOutDimSize(outDimName) >= desiredSize && "bad grow");
295304
}
296305
return ret;
@@ -314,6 +323,12 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
314323

315324
SmallVector<StringAttr> outDimNames = standardOutDimNames(ctx, rank);
316325

326+
std::cout << "shape: ";
327+
for (auto s : shape) {
328+
std::cout << s << ", ";
329+
}
330+
331+
std::cout << std::endl;
317332
llvm::SmallDenseMap<StringAttr, int64_t> labeledShape;
318333
for (auto [dim, size] : llvm::zip(outDimNames, shape)) {
319334
labeledShape[dim] = size;
@@ -322,27 +337,38 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
322337
LinearLayout cgaLayout =
323338
ensureLayoutNotLargerThan(makeCgaLayout(cgaLayoutAttr), labeledShape)
324339
.transposeOuts(llvm::to_vector(ctaLayout.getOutDimNames()));
340+
std::cout << "\ncgaLayout: " << cgaLayout.toString() << std::endl;
325341

326342
// Calculate the shape of the ctaLayout, which is `shape` divided by the
327343
// cgaLayout's size.
328344
llvm::SmallDenseMap<StringAttr, int64_t> ctaShape;
329345
assert(llvm::to_vector(ctaLayout.getOutDimNames()) ==
330346
llvm::to_vector(cgaLayout.getOutDimNames()) &&
331347
"bad layout");
348+
349+
std::cout << "ctaShape: ";
332350
for (auto dim : ctaLayout.getOutDimNames()) {
333351
ctaShape[dim] =
334352
std::max(int64_t{1}, labeledShape[dim] / cgaLayout.getOutDimSize(dim));
353+
std::cout << ctaShape[dim] << ", ";
335354
}
355+
std::cout << std::endl;
336356

337357
ctaLayout = ensureLayoutNotSmallerThan(ctaLayout, ctaShape);
358+
std::cout << "\nctaLayout not smaller than: " << ctaLayout.toString()
359+
<< std::endl;
338360
ctaLayout = ensureLayoutNotLargerThan(ctaLayout, ctaShape);
361+
std::cout << "\nctaLayout not larger than: " << ctaLayout.toString()
362+
<< std::endl;
339363

364+
std::cout << "\ncta * cga: " << (ctaLayout * cgaLayout).toString()
365+
<< std::endl;
340366
LinearLayout ret =
341367
(std::move(ctaLayout) * std::move(cgaLayout)).transposeOuts(outDimNames);
342368
for (auto dim : ret.getOutDimNames()) {
343369
assert(ret.getOutDimSize(dim) == labeledShape[dim] && "bad shape");
344370
}
345-
std::cout << "combineCtaCgaWithShape: \n" << ret.toString() << std::endl;
371+
std::cout << "\ncombineCtaCgaWithShape: " << ret.toString() << std::endl;
346372
return ret;
347373
}
348374

@@ -515,26 +541,28 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
515541
int systolicDepth = dpas.getSystolicDepth();
516542
int repeatCount = dpas.getRepeatCount();
517543
int executionSize = dpas.getExecutionSize();
518-
unsigned KDim = 0;
519-
unsigned nonKDim = 0;
544+
unsigned dimK, dimNonK;
520545
if (opIdx == 0) { // Operand A
521546
auto regBasesA = DPASRegBasesA(opsPerChannel, repeatCount, threadsPerWarp,
522547
systolicDepth);
523548
auto laneBasesA =
524549
DPASLaneBasesA(opsPerChannel, threadsPerWarp, systolicDepth);
525550
tileLayout = LinearLayout({{kRegister, regBasesA}, {kLane, laneBasesA}},
526551
outDimNames);
527-
// A only repeats by repCluster[rank-2]
528-
tileLayout *= LinearLayout::identity1D(repCluster[rank - 2], kRegister,
529-
outDimNames[rank - 2]);
552+
// A only repeats by repCluster[rank - 2]
553+
dimNonK = rank - 2;
554+
dimK = rank - 1;
555+
tileLayout *= LinearLayout::identity1D(repCluster[dimNonK], kRegister,
556+
outDimNames[dimNonK]);
530557

531-
nonKDim = rank - 2;
532-
KDim = rank - 1;
533558
// K-dimension is shared among warps
534-
tileLayout *= LinearLayout::zeros1D(warpsPerCTA[rank - 1], kWarp,
535-
outDimNames[rank - 1]);
536-
tileLayout *= LinearLayout::identity1D(warpsPerCTA[rank - 2], kWarp,
537-
outDimNames[rank - 2]);
559+
tileLayout *=
560+
LinearLayout::zeros1D(warpsPerCTA[dimK], kWarp, outDimNames[dimK]);
561+
tileLayout *= LinearLayout::identity1D(warpsPerCTA[dimNonK], kWarp,
562+
outDimNames[dimNonK]);
563+
if (rank == 3)
564+
tileLayout *=
565+
LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
538566

539567
} else if (opIdx == 1) { // Operand B
540568
std::cout << "\nOperand B" << std::endl;
@@ -544,67 +572,72 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
544572
DPASLaneBasesB(opsPerChannel, threadsPerWarp, executionSize);
545573
tileLayout = LinearLayout({{kRegister, regBasesB}, {kLane, laneBasesB}},
546574
ArrayRef(outDimNames).take_back(2));
547-
// std::cout << (tileLayout.toString()) << std::endl;
548-
// B only repeats by repCluster[rank-1]
549-
tileLayout *= LinearLayout::identity1D(repCluster[rank - 1], kRegister,
550-
outDimNames[rank - 1]);
551-
// std::cout << (tileLayout.toString()) << std::endl;
552-
553-
nonKDim = rank - 1;
554-
KDim = rank - 2;
575+
// B only repeats by repCluster[rank - 1]
576+
dimNonK = rank - 1;
577+
dimK = rank - 2;
578+
tileLayout *= LinearLayout::identity1D(repCluster[dimNonK], kRegister,
579+
outDimNames[dimNonK]);
555580

556581
// K-dimension is shared among warps
557-
tileLayout *= LinearLayout::identity1D(warpsPerCTA[rank - 1], kWarp,
558-
outDimNames[rank - 1]);
559-
tileLayout *= LinearLayout::zeros1D(warpsPerCTA[rank - 2], kWarp,
560-
outDimNames[rank - 2]);
561-
// std::cout << (tileLayout.toString()) << std::endl;
582+
tileLayout *= LinearLayout::identity1D(warpsPerCTA[dimNonK], kWarp,
583+
outDimNames[dimNonK]);
584+
tileLayout *=
585+
LinearLayout::zeros1D(warpsPerCTA[dimK], kWarp, outDimNames[dimK]);
586+
if (rank == 3)
587+
tileLayout *=
588+
LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
562589
} else { // opIdx=2 -> Operand C
563590
std::cout << "\nOperand C" << std::endl;
564591
auto regBasesC = DPASRegBasesC(repeatCount, executionSize, threadsPerWarp);
565592
auto laneBasesC =
566593
DPASLaneBasesC(repeatCount, executionSize, threadsPerWarp);
567594
tileLayout = LinearLayout({{kRegister, regBasesC}, {kLane, laneBasesC}},
568595
ArrayRef(outDimNames).take_back(2));
569-
// llvm::to_vector(llvm::reverse(ArrayRef(outDimNames).take_back(2))));
570-
// std::cout << (tileLayout.toString()) << std::endl;
596+
std::cout << tileLayout.toString() << std::endl;
571597
// The per-inst layout is repeated at each repCluster.
572598
// Hence, multiply with the identity layouts starting from the
573599
// least significant dimension.
574-
nonKDim = rank - 2;
575-
KDim = rank - 1;
576-
tileLayout *= LinearLayout::identity1D(repCluster[KDim], kRegister,
577-
outDimNames[KDim]);
578-
tileLayout *= LinearLayout::identity1D(repCluster[nonKDim], kRegister,
579-
outDimNames[nonKDim]);
600+
dimNonK = rank - 2;
601+
dimK = rank - 1;
602+
tileLayout *= LinearLayout::identity1D(repCluster[dimK], kRegister,
603+
outDimNames[dimK]);
604+
std::cout << (LinearLayout::identity1D(repCluster[dimK], kRegister,
605+
outDimNames[dimK])
606+
.toString())
607+
<< std::endl;
608+
std::cout << (tileLayout.toString()) << std::endl;
609+
tileLayout *= LinearLayout::identity1D(repCluster[dimNonK], kRegister,
610+
outDimNames[dimNonK]);
611+
std::cout << (LinearLayout::identity1D(repCluster[dimNonK], kRegister,
612+
outDimNames[dimNonK])
613+
.toString())
614+
<< std::endl;
580615
std::cout << (tileLayout.toString()) << std::endl;
581616

582617
// // The identical layout is repeated among warps
583618
tileLayout *=
584-
LinearLayout::identity1D(warpsPerCTA[KDim], kWarp, outDimNames[KDim]);
585-
tileLayout *= LinearLayout::identity1D(warpsPerCTA[nonKDim], kWarp,
586-
outDimNames[nonKDim]);
619+
LinearLayout::identity1D(warpsPerCTA[dimK], kWarp, outDimNames[dimK]);
620+
tileLayout *= LinearLayout::identity1D(warpsPerCTA[dimNonK], kWarp,
621+
outDimNames[dimNonK]);
587622
if (rank == 3)
588623
tileLayout *=
589624
LinearLayout::identity1D(warpsPerCTA[0], kWarp, outDimNames[0]);
590-
auto order =
591-
llvm::to_vector(llvm::reverse(triton::gpu::getWarpOrder(layout)));
592-
std::cout << "order: " << order[1] << ", " << order[0] << std::endl;
593-
// tileLayout *= identityND(kWarp, warpsPerCTA,
594-
// llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank))),
595-
// outDimNames);
596-
std::cout << (tileLayout.toString()) << std::endl;
625+
// std::cout << (tileLayout.toString()) << std::endl;
597626
}
598627

599628
// Lastly, the layout repeats to match the shape.
600629
// Operand A/B repeats through the K-dimension first then repeats
601630
// through the non-K dimension.
602631
// SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
632+
// std::cout << "numReps: " << numReps[0] << ", " << numReps[1] << std::endl;
603633
// tileLayout *=
604-
// LinearLayout::identity1D(numReps[KDim], kRegister, outDimNames[KDim]);
605-
// tileLayout *= LinearLayout::identity1D(numReps[nonKDim], kRegister,
606-
// outDimNames[nonKDim]);
607-
// // std::cout << (tileLayout.toString()) << std::endl;
634+
// LinearLayout::identity1D(numReps[dimK], kRegister, outDimNames[dimK]);
635+
// tileLayout *= LinearLayout::identity1D(numReps[dimNonK], kRegister,
636+
// outDimNames[dimNonK]);
637+
// if (rank == 3)
638+
// tileLayout *=
639+
// LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);
640+
// std::cout << (tileLayout.toString()) << std::endl;
608641

609642
return combineCtaCgaWithShape(std::move(tileLayout),
610643
CTALayoutAttr::getDefault(ctx, rank), shape);

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM/DPAS.cpp

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,8 @@ class DotOpDPASConversionHelper {
137137
ATensorTy.getShape(), AEncoding.getOpIdx());
138138
SmallVector<int64_t> repB = BDpasEncoding.getDPASRepetitions(
139139
BTensorTy.getShape(), BEncoding.getOpIdx());
140-
assert(repA.size() == repB.size() && "A and B rank should match");
141-
size_t rank = repA.size();
142-
assert(repA[rank - 1] == repB[rank - 2] &&
143-
"Unexpected rep for A and B operands");
144-
145-
assert(repA[2] == repB[1]);
146-
assert(repA[0] == repB[0]);
140+
assert(repA[0] == repB[0] && "A and B should have the same batch size");
141+
assert(repA[2] == repB[1] && "Unexpected rep for A and B operands");
147142
unsigned repM = repA[1], repN = repB[2], repK = repA[2];
148143
unsigned repBatch = repA[0];
149144

@@ -196,14 +191,15 @@ class DotOpDPASConversionHelper {
196191
};
197192

198193
ArrayRef<unsigned> repCluster = dpasEncoding.getRepCluster();
194+
unsigned rank = repCluster.size();
199195
for (int b = 0; b < repBatch; ++b)
200196
for (int k = 0; k < repK; ++k)
201197
for (int m = 0; m < repM; ++m)
202198
for (int n = 0; n < repN; ++n)
203-
for (int repRow = 0; repRow < repCluster[0]; ++repRow)
204-
for (int repCol = 0; repCol < repCluster[1]; ++repCol)
205-
generateDPASOp(b, m * repCluster[0] + repRow,
206-
n * repCluster[1] + repCol, k);
199+
for (int repRow = 0; repRow < repCluster[rank - 2]; ++repRow)
200+
for (int repCol = 0; repCol < repCluster[rank - 1]; ++repCol)
201+
generateDPASOp(b, m * repCluster[rank - 2] + repRow,
202+
n * repCluster[rank - 1] + repCol, k);
207203

208204
Value res = composeValuesToDotOperandLayoutStruct(fc, repBatch, repM, repN,
209205
resElemTy);

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,6 @@ class BlockedToDPAS : public RewritePattern {
176176
dpasEnc.getDPASRepetitions(oldBType.getShape(), 1);
177177
unsigned repClusterDimN =
178178
std::min(maxRepClusterN, static_cast<unsigned>(repB[2]));
179-
if (rank == 3)
180-
repCluster[0] = 1;
181179
repCluster[rank - 2] = repClusterDimM;
182180
repCluster[rank - 1] = repClusterDimN;
183181

0 commit comments

Comments
 (0)