@@ -88,18 +88,8 @@ SmallVector<unsigned> getThreadsPerWarp(Attribute layout,
8888 return toLinearEncoding (layout, shape).getThreadsPerWarp ();
8989}
9090
91- SmallVector<unsigned > getWarpsPerCTA (Attribute layout) {
92- if (auto distributedLayout =
93- mlir::dyn_cast<DistributedEncodingTrait>(layout)) {
94- return distributedLayout.getWarpsPerCTA ();
95- }
96-
97- llvm::report_fatal_error (" getWarpsPerCTA not implemented" );
98- return SmallVector<unsigned >();
99- }
100-
101- SmallVector<unsigned > getWarpsPerCTAWithUniqueData (Attribute layout,
102- ArrayRef<int64_t > shape) {
91+ SmallVector<unsigned > getWarpsPerCTA (Attribute layout,
92+ ArrayRef<int64_t > shape) {
10393 return toLinearEncoding (layout, shape).getWarpsPerCTA ();
10494}
10595
@@ -581,9 +571,6 @@ SmallVector<unsigned> BlockedEncodingAttr::getCTAOrder() const {
581571SmallVector<unsigned > BlockedEncodingAttr::getCTASplitNum () const {
582572 return SmallVector<unsigned >(getCTALayout ().getCTASplitNum ());
583573}
584- SmallVector<unsigned > BlockedEncodingAttr::getWarpsPerCTA () const {
585- return SmallVector<unsigned >(getWarpsPerCTA__ ());
586- }
587574
588575template <class T >
589576SmallVector<T> SliceEncodingAttr::paddedShape (ArrayRef<T> shape) const {
@@ -640,15 +627,6 @@ SmallVector<unsigned> SliceEncodingAttr::getCTAsPerCGA() const {
640627 llvm::report_fatal_error (
641628 " getCTAsPerCGA for SliceEncodingAttr is not well-defined" );
642629}
643- SmallVector<unsigned > SliceEncodingAttr::getWarpsPerCTA () const {
644- auto parent = getParent ();
645- auto parentWarpsPerCTA = ::getWarpsPerCTA (parent);
646- SmallVector<unsigned > warpsPerCTA = parentWarpsPerCTA;
647- warpsPerCTA.erase (warpsPerCTA.begin () + getDim ());
648- int32_t nextDim = getDim () < warpsPerCTA.size () ? getDim () : getDim () - 1 ;
649- warpsPerCTA[nextDim] *= parentWarpsPerCTA[getDim ()];
650- return warpsPerCTA;
651- }
652630
653631// Wmma encoding
654632
@@ -704,14 +682,6 @@ SmallVector<unsigned> DotOperandEncodingAttr::getCTASplitNum() const {
704682 res[kDim ] = 1 ;
705683 return res;
706684}
707- SmallVector<unsigned > DotOperandEncodingAttr::getWarpsPerCTA () const {
708- auto distributedLayout = mlir::cast<DistributedEncodingTrait>(getParent ());
709- auto warps = distributedLayout.getWarpsPerCTA ();
710- auto rank = warps.size ();
711- auto kDim = getOpIdx () == 0 ? rank - 1 : rank - 2 ;
712- warps[kDim ] = 1 ;
713- return warps;
714- }
715685
716686LogicalResult DotOperandEncodingAttr::verify (
717687 ::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
@@ -1339,7 +1309,7 @@ void NvidiaMmaEncodingAttr::print(AsmPrinter &printer) const {
13391309 << " , warpsPerCTA = [" << ArrayRef (getWarpsPerCTA ()) << " ]" ;
13401310
13411311 maybePrintCTALayout (getContext (), printer, getCTALayout (),
1342- /* rank=*/ getWarpsPerCTA (). size ());
1312+ /* rank=*/ getRank ());
13431313
13441314 printer << " , instrShape = [" << getInstrShape () << " ]}>" ;
13451315}
@@ -1419,11 +1389,11 @@ void AMDMfmaEncodingAttr::print(AsmPrinter &printer) const {
14191389 printer << " <{"
14201390 << " versionMajor = " << getVersionMajor () //
14211391 << " , versionMinor = " << getVersionMinor () //
1422- << " , warpsPerCTA = [" << ArrayRef ( getWarpsPerCTA ()) << " ]" //
1392+ << " , warpsPerCTA = [" << getWarpsPerCTA () << " ]" //
14231393 << " , instrShape = [" << ArrayRef{getMDim (), getNDim ()} << " ]" //
14241394 << " , isTransposed = " << getIsTransposed ();
14251395 maybePrintCTALayout (getContext (), printer, getCTALayout (),
1426- /* rank=*/ getWarpsPerCTA (). size ());
1396+ /* rank=*/ getRank ());
14271397 printer << " }>" ;
14281398}
14291399
@@ -1754,9 +1724,6 @@ SmallVector<unsigned> AMDMfmaEncodingAttr::getCTAOrder() const {
17541724SmallVector<unsigned > AMDMfmaEncodingAttr::getCTASplitNum () const {
17551725 return SmallVector<unsigned >(getCTALayout ().getCTASplitNum ());
17561726}
1757- SmallVector<unsigned > AMDMfmaEncodingAttr::getWarpsPerCTA () const {
1758- return SmallVector<unsigned >(getWarpsPerCTA__ ());
1759- }
17601727
17611728SmallVector<int64_t >
17621729AMDMfmaEncodingAttr::getInstrShapeForOperand (int kWidth , int opIdx) const {
@@ -1875,9 +1842,6 @@ SmallVector<unsigned> AMDWmmaEncodingAttr::getCTAOrder() const {
18751842SmallVector<unsigned > AMDWmmaEncodingAttr::getCTASplitNum () const {
18761843 return SmallVector<unsigned >(getCTALayout ().getCTASplitNum ());
18771844}
1878- SmallVector<unsigned > AMDWmmaEncodingAttr::getWarpsPerCTA () const {
1879- return SmallVector<unsigned >(getWarpsPerCTA__ ());
1880- }
18811845
18821846SmallVector<int64_t > AMDWmmaEncodingAttr::getElemsPerInstrForOperands () const {
18831847 return {16 , 16 };
@@ -1949,9 +1913,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getCTAOrder() const {
19491913SmallVector<unsigned > NvidiaMmaEncodingAttr::getCTASplitNum () const {
19501914 return SmallVector<unsigned >(getCTALayout ().getCTASplitNum ());
19511915}
1952- SmallVector<unsigned > NvidiaMmaEncodingAttr::getWarpsPerCTA () const {
1953- return SmallVector<unsigned >(getWarpsPerCTA__ ());
1954- }
19551916
19561917SmallVector<unsigned >
19571918NvidiaMmaEncodingAttr::getRepOrderForOperand (int opIdx) const {
@@ -1966,7 +1927,7 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
19661927 " kWidth must be >= 32 / bitwidth for this function to be well-defined" );
19671928 auto rank = shape.size ();
19681929 // Broadcast long K
1969- auto warpsPerCTA = getWarpsPerCTA ();
1930+ auto warpsPerCTA = to_vector ( getWarpsPerCTA () );
19701931 auto kDim = opIdx == 0 ? rank - 1 : rank - 2 ;
19711932 warpsPerCTA[kDim ] = 1 ;
19721933
0 commit comments