@@ -280,17 +280,26 @@ LinearLayout ensureLayoutNotSmallerThan(
280280 return layout;
281281 }
282282
283- MLIRContext *ctx = shape.begin ()->first .getContext ();
283+ // MLIRContext *ctx = shape.begin()->first.getContext();
284284 StringAttr kDim = *layout.getInDimNames ().begin ();
285285 assert (kDim == " register" || kDim == " offset" && " unexpected kDim" );
286286
287287 LinearLayout ret = layout;
288- for (StringAttr outDimName : layout.getOutDimNames ()) {
288+ for (StringAttr outDimName : llvm::reverse ( layout.getOutDimNames () )) {
289289 int32_t actualSize = layout.getOutDimSize (outDimName);
290290 int32_t desiredSize = shape.lookup (outDimName);
291291 assert (actualSize > desiredSize ||
292292 desiredSize % actualSize == 0 && " bad shape" );
293293 ret *= LinearLayout::identity1D (desiredSize / actualSize, kDim , outDimName);
294+ std::cout << " actualSize: " << actualSize << " desiredSize: " << desiredSize
295+ << std::endl;
296+ std::cout << " outDimName: " << outDimName.str () << std::endl;
297+ std::cout << " identity1D: "
298+ << LinearLayout::identity1D (desiredSize / actualSize, kDim ,
299+ outDimName)
300+ .toString ()
301+ << std::endl;
302+ std::cout << " ret: " << ret.toString () << std::endl;
294303 assert (ret.getOutDimSize (outDimName) >= desiredSize && " bad grow" );
295304 }
296305 return ret;
@@ -314,6 +323,12 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
314323
315324 SmallVector<StringAttr> outDimNames = standardOutDimNames (ctx, rank);
316325
326+ std::cout << " shape: " ;
327+ for (auto s : shape) {
328+ std::cout << s << " , " ;
329+ }
330+
331+ std::cout << std::endl;
317332 llvm::SmallDenseMap<StringAttr, int64_t > labeledShape;
318333 for (auto [dim, size] : llvm::zip (outDimNames, shape)) {
319334 labeledShape[dim] = size;
@@ -322,27 +337,38 @@ LinearLayout combineCtaCgaWithShape(LinearLayout ctaLayout,
322337 LinearLayout cgaLayout =
323338 ensureLayoutNotLargerThan (makeCgaLayout (cgaLayoutAttr), labeledShape)
324339 .transposeOuts (llvm::to_vector (ctaLayout.getOutDimNames ()));
340+ std::cout << " \n cgaLayout: " << cgaLayout.toString () << std::endl;
325341
326342 // Calculate the shape of the ctaLayout, which is `shape` divided by the
327343 // cgaLayout's size.
328344 llvm::SmallDenseMap<StringAttr, int64_t > ctaShape;
329345 assert (llvm::to_vector (ctaLayout.getOutDimNames ()) ==
330346 llvm::to_vector (cgaLayout.getOutDimNames ()) &&
331347 " bad layout" );
348+
349+ std::cout << " ctaShape: " ;
332350 for (auto dim : ctaLayout.getOutDimNames ()) {
333351 ctaShape[dim] =
334352 std::max (int64_t {1 }, labeledShape[dim] / cgaLayout.getOutDimSize (dim));
353+ std::cout << ctaShape[dim] << " , " ;
335354 }
355+ std::cout << std::endl;
336356
337357 ctaLayout = ensureLayoutNotSmallerThan (ctaLayout, ctaShape);
358+ std::cout << " \n ctaLayout not smaller than: " << ctaLayout.toString ()
359+ << std::endl;
338360 ctaLayout = ensureLayoutNotLargerThan (ctaLayout, ctaShape);
361+ std::cout << " \n ctaLayout not larger than: " << ctaLayout.toString ()
362+ << std::endl;
339363
364+ std::cout << " \n cta * cga: " << (ctaLayout * cgaLayout).toString ()
365+ << std::endl;
340366 LinearLayout ret =
341367 (std::move (ctaLayout) * std::move (cgaLayout)).transposeOuts (outDimNames);
342368 for (auto dim : ret.getOutDimNames ()) {
343369 assert (ret.getOutDimSize (dim) == labeledShape[dim] && " bad shape" );
344370 }
345- std::cout << " combineCtaCgaWithShape: \n " << ret.toString () << std::endl;
371+ std::cout << " \n combineCtaCgaWithShape: " << ret.toString () << std::endl;
346372 return ret;
347373}
348374
@@ -515,26 +541,28 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
515541 int systolicDepth = dpas.getSystolicDepth ();
516542 int repeatCount = dpas.getRepeatCount ();
517543 int executionSize = dpas.getExecutionSize ();
518- unsigned KDim = 0 ;
519- unsigned nonKDim = 0 ;
544+ unsigned dimK, dimNonK;
520545 if (opIdx == 0 ) { // Operand A
521546 auto regBasesA = DPASRegBasesA (opsPerChannel, repeatCount, threadsPerWarp,
522547 systolicDepth);
523548 auto laneBasesA =
524549 DPASLaneBasesA (opsPerChannel, threadsPerWarp, systolicDepth);
525550 tileLayout = LinearLayout ({{kRegister , regBasesA}, {kLane , laneBasesA}},
526551 outDimNames);
527- // A only repeats by repCluster[rank-2]
528- tileLayout *= LinearLayout::identity1D (repCluster[rank - 2 ], kRegister ,
529- outDimNames[rank - 2 ]);
552+ // A only repeats by repCluster[rank - 2]
553+ dimNonK = rank - 2 ;
554+ dimK = rank - 1 ;
555+ tileLayout *= LinearLayout::identity1D (repCluster[dimNonK], kRegister ,
556+ outDimNames[dimNonK]);
530557
531- nonKDim = rank - 2 ;
532- KDim = rank - 1 ;
533558 // K-dimension is shared among warps
534- tileLayout *= LinearLayout::zeros1D (warpsPerCTA[rank - 1 ], kWarp ,
535- outDimNames[rank - 1 ]);
536- tileLayout *= LinearLayout::identity1D (warpsPerCTA[rank - 2 ], kWarp ,
537- outDimNames[rank - 2 ]);
559+ tileLayout *=
560+ LinearLayout::zeros1D (warpsPerCTA[dimK], kWarp , outDimNames[dimK]);
561+ tileLayout *= LinearLayout::identity1D (warpsPerCTA[dimNonK], kWarp ,
562+ outDimNames[dimNonK]);
563+ if (rank == 3 )
564+ tileLayout *=
565+ LinearLayout::identity1D (warpsPerCTA[0 ], kWarp , outDimNames[0 ]);
538566
539567 } else if (opIdx == 1 ) { // Operand B
540568 std::cout << " \n Operand B" << std::endl;
@@ -544,67 +572,72 @@ LinearLayout DPAStoLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
544572 DPASLaneBasesB (opsPerChannel, threadsPerWarp, executionSize);
545573 tileLayout = LinearLayout ({{kRegister , regBasesB}, {kLane , laneBasesB}},
546574 ArrayRef (outDimNames).take_back (2 ));
547- // std::cout << (tileLayout.toString()) << std::endl;
548- // B only repeats by repCluster[rank-1]
549- tileLayout *= LinearLayout::identity1D (repCluster[rank - 1 ], kRegister ,
550- outDimNames[rank - 1 ]);
551- // std::cout << (tileLayout.toString()) << std::endl;
552-
553- nonKDim = rank - 1 ;
554- KDim = rank - 2 ;
575+ // B only repeats by repCluster[rank - 1]
576+ dimNonK = rank - 1 ;
577+ dimK = rank - 2 ;
578+ tileLayout *= LinearLayout::identity1D (repCluster[dimNonK], kRegister ,
579+ outDimNames[dimNonK]);
555580
556581 // K-dimension is shared among warps
557- tileLayout *= LinearLayout::identity1D (warpsPerCTA[rank - 1 ], kWarp ,
558- outDimNames[rank - 1 ]);
559- tileLayout *= LinearLayout::zeros1D (warpsPerCTA[rank - 2 ], kWarp ,
560- outDimNames[rank - 2 ]);
561- // std::cout << (tileLayout.toString()) << std::endl;
582+ tileLayout *= LinearLayout::identity1D (warpsPerCTA[dimNonK], kWarp ,
583+ outDimNames[dimNonK]);
584+ tileLayout *=
585+ LinearLayout::zeros1D (warpsPerCTA[dimK], kWarp , outDimNames[dimK]);
586+ if (rank == 3 )
587+ tileLayout *=
588+ LinearLayout::identity1D (warpsPerCTA[0 ], kWarp , outDimNames[0 ]);
562589 } else { // opIdx=2 -> Operand C
563590 std::cout << " \n Operand C" << std::endl;
564591 auto regBasesC = DPASRegBasesC (repeatCount, executionSize, threadsPerWarp);
565592 auto laneBasesC =
566593 DPASLaneBasesC (repeatCount, executionSize, threadsPerWarp);
567594 tileLayout = LinearLayout ({{kRegister , regBasesC}, {kLane , laneBasesC}},
568595 ArrayRef (outDimNames).take_back (2 ));
569- // llvm::to_vector(llvm::reverse(ArrayRef(outDimNames).take_back(2))));
570- // std::cout << (tileLayout.toString()) << std::endl;
596+ std::cout << tileLayout.toString () << std::endl;
571597 // The per-inst layout is repeated at each repCluster.
572598 // Hence, multiply with the identity layouts starting from the
573599 // least significant dimension.
574- nonKDim = rank - 2 ;
575- KDim = rank - 1 ;
576- tileLayout *= LinearLayout::identity1D (repCluster[KDim], kRegister ,
577- outDimNames[KDim]);
578- tileLayout *= LinearLayout::identity1D (repCluster[nonKDim], kRegister ,
579- outDimNames[nonKDim]);
600+ dimNonK = rank - 2 ;
601+ dimK = rank - 1 ;
602+ tileLayout *= LinearLayout::identity1D (repCluster[dimK], kRegister ,
603+ outDimNames[dimK]);
604+ std::cout << (LinearLayout::identity1D (repCluster[dimK], kRegister ,
605+ outDimNames[dimK])
606+ .toString ())
607+ << std::endl;
608+ std::cout << (tileLayout.toString ()) << std::endl;
609+ tileLayout *= LinearLayout::identity1D (repCluster[dimNonK], kRegister ,
610+ outDimNames[dimNonK]);
611+ std::cout << (LinearLayout::identity1D (repCluster[dimNonK], kRegister ,
612+ outDimNames[dimNonK])
613+ .toString ())
614+ << std::endl;
580615 std::cout << (tileLayout.toString ()) << std::endl;
581616
582617 // // The identical layout is repeated among warps
583618 tileLayout *=
584- LinearLayout::identity1D (warpsPerCTA[KDim ], kWarp , outDimNames[KDim ]);
585- tileLayout *= LinearLayout::identity1D (warpsPerCTA[nonKDim ], kWarp ,
586- outDimNames[nonKDim ]);
619+ LinearLayout::identity1D (warpsPerCTA[dimK ], kWarp , outDimNames[dimK ]);
620+ tileLayout *= LinearLayout::identity1D (warpsPerCTA[dimNonK ], kWarp ,
621+ outDimNames[dimNonK ]);
587622 if (rank == 3 )
588623 tileLayout *=
589624 LinearLayout::identity1D (warpsPerCTA[0 ], kWarp , outDimNames[0 ]);
590- auto order =
591- llvm::to_vector (llvm::reverse (triton::gpu::getWarpOrder (layout)));
592- std::cout << " order: " << order[1 ] << " , " << order[0 ] << std::endl;
593- // tileLayout *= identityND(kWarp, warpsPerCTA,
594- // llvm::to_vector(llvm::reverse(llvm::seq<unsigned>(rank))),
595- // outDimNames);
596- std::cout << (tileLayout.toString ()) << std::endl;
625+ // std::cout << (tileLayout.toString()) << std::endl;
597626 }
598627
599628 // Lastly, the layout repeats to match the shape.
600629 // Operand A/B repeats through the K-dimension first then repeats
601630 // through the non-K dimension.
602631 // SmallVector<int64_t> numReps = dpas.getDPASRepetitions(shape, opIdx);
632+ // std::cout << "numReps: " << numReps[0] << ", " << numReps[1] << std::endl;
603633 // tileLayout *=
604- // LinearLayout::identity1D(numReps[KDim], kRegister, outDimNames[KDim]);
605- // tileLayout *= LinearLayout::identity1D(numReps[nonKDim], kRegister,
606- // outDimNames[nonKDim]);
607- // // std::cout << (tileLayout.toString()) << std::endl;
634+ // LinearLayout::identity1D(numReps[dimK], kRegister, outDimNames[dimK]);
635+ // tileLayout *= LinearLayout::identity1D(numReps[dimNonK], kRegister,
636+ // outDimNames[dimNonK]);
637+ // if (rank == 3)
638+ // tileLayout *=
639+ // LinearLayout::identity1D(numReps[0], kRegister, outDimNames[0]);
640+ // std::cout << (tileLayout.toString()) << std::endl;
608641
609642 return combineCtaCgaWithShape (std::move (tileLayout),
610643 CTALayoutAttr::getDefault (ctx, rank), shape);
0 commit comments