@@ -85,18 +85,8 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout,
8585 return toLinearEncoding (layout, shape).getThreadsPerWarp ();
8686}
8787
// Removed by this change (every line below carries the '-' marker).
// Shape-less free function: when `layout` implements DistributedEncodingTrait,
// it forwards to that trait's getWarpsPerCTA(). Any other attribute reaches
// llvm::report_fatal_error with the message shown below.
88- SmallVector<unsigned > getWarpsPerCTA (Attribute layout) {
89- if (auto distributedLayout =
90- mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
91- return distributedLayout.getWarpsPerCTA ();
92- }
93-
94- llvm::report_fatal_error (" getWarpsPerCTA not implemented" );
// NOTE(review): presumably unreachable, since report_fatal_error is documented
// as not returning — confirm.
95- return SmallVector<unsigned >();
96- }
97-
// Rename only: the old header getWarpsPerCTAWithUniqueData ('-' lines) is
// replaced by getWarpsPerCTA with the same parameters ('+' lines). The body is
// unchanged context.
// Converts (layout, shape) with toLinearEncoding and returns that encoding's
// getWarpsPerCTA().
98- SmallVector<unsigned > getWarpsPerCTAWithUniqueData (Attribute layout,
99- ArrayRef<int64_t > shape) {
88+ SmallVector<unsigned > getWarpsPerCTA (Attribute layout,
89+ ArrayRef<int64_t > shape) {
10090 return toLinearEncoding (layout, shape).getWarpsPerCTA ();
10191}
10292
@@ -578,9 +568,6 @@ SmallVector<unsigned> BlockedEncodingAttr::getCTAOrder() const {
// Unchanged context. Returns a SmallVector<unsigned> constructed from
// getCTALayout().getCTASplitNum().
578568SmallVector<unsigned > BlockedEncodingAttr::getCTASplitNum () const {
579569 return SmallVector<unsigned >(getCTALayout ().getCTASplitNum ());
580570}
// Removed by this change ('-' lines). This member only wrapped the result of
// getWarpsPerCTA__() in a SmallVector<unsigned>.
// NOTE(review): getWarpsPerCTA__ is presumably the raw attribute-parameter
// accessor — confirm against the attribute definition.
581- SmallVector<unsigned > BlockedEncodingAttr::getWarpsPerCTA () const {
582- return SmallVector<unsigned >(getWarpsPerCTA__ ());
583- }
584571
585572template <class T >
586573SmallVector<T> SliceEncodingAttr::paddedShape (ArrayRef<T> shape) const {
@@ -637,15 +624,6 @@ SmallVector<unsigned> SliceEncodingAttr::getCTAsPerCGA() const {
637624 llvm::report_fatal_error (
638625 " getCTAsPerCGA for SliceEncodingAttr is not well-defined" );
639626}
// Removed by this change ('-' lines).
// Derived the slice's warps-per-CTA from the parent layout: copy the parent's
// vector, erase the entry at getDim(), then multiply the erased entry's value
// into a neighbouring entry so the product of all entries is preserved.
640- SmallVector<unsigned > SliceEncodingAttr::getWarpsPerCTA () const {
641- auto parent = getParent ();
642- auto parentWarpsPerCTA = ::getWarpsPerCTA (parent);
643- SmallVector<unsigned > warpsPerCTA = parentWarpsPerCTA;
644- warpsPerCTA.erase (warpsPerCTA.begin () + getDim ());
// The neighbour is the element now sitting at index getDim() if one exists,
// otherwise the element just before it.
// NOTE(review): if the parent vector has a single entry, the erased copy is
// empty and nextDim cannot index a valid element — confirm callers never
// reach that case.
645- int32_t nextDim = getDim () < warpsPerCTA.size () ? getDim () : getDim () - 1 ;
646- warpsPerCTA[nextDim] *= parentWarpsPerCTA[getDim ()];
647- return warpsPerCTA;
648- }
649627
650628// Wmma encoding
651629
@@ -701,14 +679,6 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
701679 res[kDim ] = 1 ;
702680 return res;
703681}
// Removed by this change ('-' lines).
// Took the parent's warps-per-CTA (the parent is cast to
// DistributedEncodingTrait) and set the entry at kDim to 1.
704- SmallVector<unsigned > DotOperandEncodingAttr::getWarpsPerCTA () const {
705- auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent ());
706- auto warps = distributedLayout.getWarpsPerCTA ();
707- auto rank = warps.size ();
// kDim is the last index when getOpIdx() == 0, otherwise the second-to-last.
708- auto kDim = getOpIdx () == 0 ? rank - 1 : rank - 2 ;
709- warps[kDim ] = 1 ;
710- return warps;
711- }
712682
713683LogicalResult DotOperandEncodingAttr::verify (
714684 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1306,7 +1276,7 @@ void NvidiaMmaEncodingAttr::print(AsmPrinter &printer) const {
13061276 << " , warpsPerCTA = [" << ArrayRef (getWarpsPerCTA ()) << " ]" ;
13071277
13081278 maybePrintCTALayout (getContext (), printer, getCTALayout (),
1309- /* rank=*/ getWarpsPerCTA (). size ());
1279+ /* rank=*/ getRank ());
13101280
13111281 printer << " , instrShape = [" << getInstrShape () << " ]}>" ;
13121282}
@@ -1386,11 +1356,11 @@ void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
13861356 printer << " <{"
13871357 << " versionMajor = " << getVersionMajor () //
13881358 << " , versionMinor = " << getVersionMinor () //
1389- << " , warpsPerCTA = [" << ArrayRef ( getWarpsPerCTA ()) << " ]" //
1359+ << " , warpsPerCTA = [" << getWarpsPerCTA () << " ]" //
13901360 << " , instrShape = [" << ArrayRef{getMDim (), getNDim ()} << " ]" //
13911361 << " , isTransposed = " << getIsTransposed ();
13921362 maybePrintCTALayout (getContext (), printer, getCTALayout (),
1393- /* rank=*/ getWarpsPerCTA (). size ());
1363+ /* rank=*/ getRank ());
13941364 printer << " }>" ;
13951365}
13961366
@@ -1721,9 +1691,6 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAOrder() const {
// Unchanged context. Returns a SmallVector<unsigned> constructed from
// getCTALayout().getCTASplitNum().
17211691SmallVector<unsigned > AMDMfmaEncodingAttr::getCTASplitNum () const {
17221692 return SmallVector<unsigned >(getCTALayout ().getCTASplitNum ());
17231693}
// Removed by this change ('-' lines). This member only wrapped the result of
// getWarpsPerCTA__() in a SmallVector<unsigned>.
// NOTE(review): getWarpsPerCTA__ is presumably the raw attribute-parameter
// accessor — confirm against the attribute definition.
1724- SmallVector<unsigned > AMDMfmaEncodingAttr::getWarpsPerCTA () const {
1725- return SmallVector<unsigned >(getWarpsPerCTA__ ());
1726- }
17271694
17281695SmallVector<int64_t >
17291696AMDMfmaEncodingAttr::getInstrShapeForOperand (int kWidth , int opIdx) const {
@@ -1842,9 +1809,6 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAOrder() const {
// Unchanged context. Returns a SmallVector<unsigned> constructed from
// getCTALayout().getCTASplitNum().
18421809SmallVector<unsigned > AMDWmmaEncodingAttr::getCTASplitNum () const {
18431810 return SmallVector<unsigned >(getCTALayout ().getCTASplitNum ());
18441811}
// Removed by this change ('-' lines). This member only wrapped the result of
// getWarpsPerCTA__() in a SmallVector<unsigned>.
// NOTE(review): getWarpsPerCTA__ is presumably the raw attribute-parameter
// accessor — confirm against the attribute definition.
1845- SmallVector<unsigned > AMDWmmaEncodingAttr::getWarpsPerCTA () const {
1846- return SmallVector<unsigned >(getWarpsPerCTA__ ());
1847- }
18481812
18491813SmallVector<int64_t > AMDWmmaEncodingAttr::getElemsPerInstrForOperands () const {
18501814 return {16 , 16 };
@@ -1916,9 +1880,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getCTAOrder() const {
// Unchanged context. Returns a SmallVector<unsigned> constructed from
// getCTALayout().getCTASplitNum().
19161880SmallVector<unsigned > NvidiaMmaEncodingAttr::getCTASplitNum () const {
19171881 return SmallVector<unsigned >(getCTALayout ().getCTASplitNum ());
19181882}
// Removed by this change ('-' lines). This member only wrapped the result of
// getWarpsPerCTA__() in a SmallVector<unsigned>.
// NOTE(review): getWarpsPerCTA__ is presumably the raw attribute-parameter
// accessor — confirm against the attribute definition.
1919- SmallVector<unsigned > NvidiaMmaEncodingAttr::getWarpsPerCTA () const {
1920- return SmallVector<unsigned >(getWarpsPerCTA__ ());
1921- }
19221883
19231884SmallVector<unsigned >
19241885NvidiaMmaEncodingAttr::getRepOrderForOperand (int opIdx) const {
@@ -1933,7 +1894,7 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
19331894 " kWidth must be >= 32 / bitwidth for this function to be well-defined" );
19341895 auto rank = shape.size ();
19351896 // Broadcast long K
1936- auto warpsPerCTA = getWarpsPerCTA ();
1897+ auto warpsPerCTA = to_vector ( getWarpsPerCTA () );
19371898 auto kDim = opIdx == 0 ? rank - 1 : rank - 2 ;
19381899 warpsPerCTA[kDim ] = 1 ;
19391900
0 commit comments