1- #include < iostream>
21#include < vector>
32
43#include " intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
87#include " triton/Dialect/TritonGPU/IR/Dialect.h"
98#include " triton/Tools/LinearLayout.h"
109#include " triton/Tools/StrUtil.h"
11- #include " llvm/ADT/ArrayRef.h"
1210#include " llvm/ADT/DenseMap.h"
13- #include " llvm/ADT/SmallVector.h"
1411#include " llvm/ADT/Twine.h"
1512#include " llvm/Support/ErrorHandling.h"
1613#include " llvm/Support/MathExtras.h"
@@ -56,8 +53,6 @@ LinearLayout identityND(StringAttr inDimName, ArrayRef<unsigned> shape,
5653 LinearLayout ret = LinearLayout::empty ();
5754 for (int i = 0 ; i < shape.size (); i++) {
5855 // Start with the most-minor dimension, which is order[0].
59- // std::cout << "i: " << i << " shape[i]: " << shape[i]
60- // << " order[i]: " << order[i] << std::endl;
6156 int dim = order[i];
6257 ret *= LinearLayout::identity1D (shape[dim], inDimName, outDimNames[dim]);
6358 }
@@ -280,7 +275,6 @@ LinearLayout ensureLayoutNotSmallerThan(
280275 return layout;
281276 }
282277
283- // MLIRContext *ctx = shape.begin()->first.getContext();
284278 StringAttr kDim = *layout.getInDimNames ().begin ();
285279 assert (kDim == " register" || kDim == " offset" && " unexpected kDim" );
286280
@@ -291,16 +285,6 @@ LinearLayout ensureLayoutNotSmallerThan(
291285 assert (actualSize > desiredSize ||
292286 desiredSize % actualSize == 0 && " bad shape" );
293287 ret *= LinearLayout::identity1D (desiredSize / actualSize, kDim , outDimName);
294- // std::cout << "actualSize: " << actualSize << " desiredSize: " <<
295- // desiredSize
296- // << std::endl;
297- // std::cout << "outDimName: " << outDimName.str() << std::endl;
298- // std::cout << "identity1D: "
299- // << LinearLayout::identity1D(desiredSize / actualSize, kDim,
300- // outDimName)
301- // .toString()
302- // << std::endl;
303- // std::cout << "ret: " << ret.toString() << std::endl;
304288 assert (ret.getOutDimSize (outDimName) >= desiredSize && " bad grow" );
305289 }
306290 return ret;
@@ -324,12 +308,6 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
324308
325309 SmallVector<StringAttr> outDimNames = standardOutDimNames (ctx, rank);
326310
327- std::cout << " shape: " ;
328- for (auto s : shape) {
329- std::cout << s << " , " ;
330- }
331- std::cout << std::endl;
332-
333311 llvm::SmallDenseMap<StringAttr, int64_t > labeledShape;
334312 for (auto [dim, size] : llvm::zip (outDimNames, shape)) {
335313 labeledShape[dim] = size;
@@ -338,41 +316,26 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
338316 LinearLayout cgaLayout =
339317 ensureLayoutNotLargerThan (makeCgaLayout (cgaLayoutAttr), labeledShape)
340318 .transposeOuts (llvm::to_vector (ctaLayout.getOutDimNames ()));
341- // std::cout << "\ncgaLayout: " << cgaLayout.toString() << std::endl;
342319
343320 // Calculate the shape of the ctaLayout, which is `shape` divided by the
344321 // cgaLayout's size.
345322 llvm::SmallDenseMap<StringAttr, int64_t > ctaShape;
346323 assert (llvm::to_vector (ctaLayout.getOutDimNames ()) ==
347324 llvm::to_vector (cgaLayout.getOutDimNames ()) &&
348325 " bad layout" );
349-
350- // std::cout << "ctaShape: ";
351326 for (auto dim : ctaLayout.getOutDimNames ()) {
352327 ctaShape[dim] =
353328 std::max (int64_t {1 }, labeledShape[dim] / cgaLayout.getOutDimSize (dim));
354- // std::cout << ctaShape[dim] << ", ";
355329 }
356- // std::cout << std::endl;
357330
358- std::cout << " ensureLayoutNotSmallerThan start" << std::endl;
359331 ctaLayout = ensureLayoutNotSmallerThan (ctaLayout, ctaShape);
360- // std::cout << "\nctaLayout not smaller than: " << ctaLayout.toString()
361- // << std::endl;
362- std::cout << " ensureLayoutNotLargerThan start" << std::endl;
363332 ctaLayout = ensureLayoutNotLargerThan (ctaLayout, ctaShape);
364- // std::cout << "\nctaLayout not larger than: " << ctaLayout.toString()
365- // << std::endl;
366333
367- // std::cout << "\ncta * cga: " << (ctaLayout * cgaLayout).toString()
368- // << std::endl;
369334 LinearLayout ret =
370335 (std::move (ctaLayout) * std::move (cgaLayout)).transposeOuts (outDimNames);
371336 for (auto dim : ret.getOutDimNames ()) {
372337 assert (ret.getOutDimSize (dim) == labeledShape[dim] && " bad shape" );
373338 }
374- // std::cout << "\ncombineCtaCgaWithShape: " << ret.toString() << std::endl;
375- std::cout << " combineCtaCgaWithShape end" << std::endl;
376339 return ret;
377340}
378341
@@ -569,7 +532,6 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
569532 LinearLayout::identity1D (warpsPerCTA[0 ], kWarp , outDimNames[0 ]);
570533
571534 } else if (opIdx == 1 ) { // Operand B
572- std::cout << " \n Operand B" << std::endl;
573535 auto regBasesB = DPASRegBasesB (opsPerChannel, executionSize, threadsPerWarp,
574536 systolicDepth);
575537 auto laneBasesB =
@@ -591,32 +553,20 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
591553 tileLayout *=
592554 LinearLayout::identity1D (warpsPerCTA[0 ], kWarp , outDimNames[0 ]);
593555 } else { // opIdx=2 -> Operand C
594- std::cout << " \n Operand C" << std::endl;
595556 auto regBasesC = DPASRegBasesC (repeatCount, executionSize, threadsPerWarp);
596557 auto laneBasesC =
597558 DPASLaneBasesC (repeatCount, executionSize, threadsPerWarp);
598559 tileLayout = LinearLayout ({{kRegister , regBasesC}, {kLane , laneBasesC}},
599560 ArrayRef (outDimNames).take_back (2 ));
600- // std::cout << tileLayout.toString() << std::endl;
601561 // The per-inst layout is repeated at each repCluster.
602562 // Hence, multiply with the identity layouts starting from the
603563 // least significant dimension.
604564 dimNonK = rank - 2 ;
605565 dimK = rank - 1 ;
606566 tileLayout *= LinearLayout::identity1D (repCluster[dimK], kRegister ,
607567 outDimNames[dimK]);
608- // std::cout << (LinearLayout::identity1D(repCluster[dimK], kRegister,
609- // outDimNames[dimK])
610- // .toString())
611- // << std::endl;
612- // std::cout << (tileLayout.toString()) << std::endl;
613568 tileLayout *= LinearLayout::identity1D (repCluster[dimNonK], kRegister ,
614569 outDimNames[dimNonK]);
615- // std::cout << (LinearLayout::identity1D(repCluster[dimNonK], kRegister,
616- // outDimNames[dimNonK])
617- // .toString())
618- // << std::endl;
619- // std::cout << (tileLayout.toString()) << std::endl;
620570
621571 // // The identical layout is repeated among warps
622572 tileLayout *=
@@ -626,34 +576,23 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
626576 if (rank == 3 )
627577 tileLayout *=
628578 LinearLayout::identity1D (warpsPerCTA[0 ], kWarp , outDimNames[0 ]);
629- // std::cout << (tileLayout.toString()) << std::endl;
630579 }
631580
632581 // Lastly, the layout repeats to match the shape.
633582 // Operand A/B repeats through the K-dimension first then repeats
634583 // through the non-K dimension.
635584 SmallVector<int64_t > numReps = dpas.getDPASRepetitions (shape, opIdx);
636585
637- std::cout << " numReps: " ;
638- for (auto numRep : numReps) {
639- std::cout << numRep << " , " ;
640- }
641- std::cout << std::endl;
642-
643586 // numReps is always 3D, we should add 1 to dim id when rank is 2
644587 int repDimK = rank == 2 ? dimK + 1 : dimK;
645588 int repDimNonK = rank == 2 ? dimNonK + 1 : dimNonK;
646589 tileLayout *=
647590 LinearLayout::identity1D (numReps[repDimK], kRegister , outDimNames[dimK]);
648591 tileLayout *= LinearLayout::identity1D (numReps[repDimNonK], kRegister ,
649592 outDimNames[dimNonK]);
650- std::cout << " rank: " << rank << std::endl;
651593 if (rank == 3 )
652594 tileLayout *=
653595 LinearLayout::identity1D (numReps[0 ], kRegister , outDimNames[0 ]);
654- // std::cout << "\ntileLayout with DPASRepetition: " <<
655- // (tileLayout.toString())
656- // << std::endl;
657596
658597 return combineCtaCgaWithShape (std::move (tileLayout),
659598 CTALayoutAttr::getDefault (ctx, rank), shape);
0 commit comments