@@ -597,6 +597,15 @@ We call each individual tile "rep".
597597 /*defaultImplementation=*/[{
598598 return toLinearEncoding($_self, shape).getElemsPerThread(shape);
599599 }]>,
600+ // Interface for the meta information about the multiple thread hierarchy.
601+ InterfaceMethod<"Get the shape of the warps per CTA.",
602+ "SmallVector<unsigned>",
603+ "getWarpsPerCTA">,
604+
605+
606+ InterfaceMethod<"Get the shape of the threads per warp",
607+ "SmallVector<unsigned>",
608+ "getThreadsPerWarp">,
600609 InterfaceMethod<"Convert to LinearLayout.",
601610 "LinearLayout",
602611 "toLinearLayout",
@@ -662,6 +671,8 @@ L(T) = [ {0,8} , {1,9} , {2,10}, {3,11}, {0,8} , {1, 9} , {2, 10}, {3, 11},
662671 SmallVector<unsigned> getCTAsPerCGA() const;
663672 SmallVector<unsigned> getCTAOrder() const;
664673 SmallVector<unsigned> getCTASplitNum() const;
674+ SmallVector<unsigned> getWarpsPerCTA() const;
675+ SmallVector<unsigned> getThreadsPerWarp() const;
665676
666677 LinearLayout toLinearLayout(ArrayRef<int64_t> shape) const;
667678
@@ -714,8 +725,6 @@ def LinearEncodingAttr : DistributedEncoding<"LinearEncoding", "linear_encoding"
714725 // If skipBroadcast is false, we count a base zero
715726 SmallVector<unsigned> basesPerDim(StringAttr dimName,
716727 bool skipBroadcast = true) const;
717- SmallVector<unsigned> getThreadsPerWarp() const;
718- SmallVector<unsigned> getWarpsPerCTA() const;
719728
720729 // [FIXME LL] Supports legacy behaviour. We should remove these functions
721730 SmallVector<unsigned> getShapePerCTATile() const;
825834 let parameters = (
826835 ins
827836 ArrayRefParameter<"unsigned">:$sizePerThread,
828- ArrayRefParameter<"unsigned">:$threadsPerWarp ,
829- ArrayRefParameter<"unsigned">:$warpsPerCTA ,
837+ ArrayRefParameter<"unsigned">:$threadsPerWarp__ ,
838+ ArrayRefParameter<"unsigned">:$warpsPerCTA__ ,
830839 ArrayRefParameter<"unsigned">:$order, // the fastest-changing axis first
831840
832841 // CTALayout is optional in the textual IR. If omitted, we infer it to be a
@@ -1030,7 +1039,7 @@ V [ 0,4,8...60 1,5...61 2,6...62 3,7...63 ] [ 128,132...188 129,
10301039 ins
10311040 "unsigned": $versionMajor,
10321041 "unsigned": $versionMinor,
1033- ArrayRefParameter<"unsigned">:$warpsPerCTA ,
1042+ ArrayRefParameter<"unsigned">:$warpsPerCTA__ ,
10341043 "unsigned":$MDim,
10351044 "unsigned":$NDim,
10361045 "bool":$isTransposed,
@@ -1151,7 +1160,7 @@ Row |
11511160 ins
11521161 "unsigned": $version,
11531162 "bool":$isTransposed,
1154- ArrayRefParameter<"unsigned">:$warpsPerCTA ,
1163+ ArrayRefParameter<"unsigned">:$warpsPerCTA__ ,
11551164 "CTALayoutAttr":$CTALayout
11561165 );
11571166
@@ -1257,7 +1266,7 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
12571266 ins
12581267 "unsigned":$versionMajor,
12591268 "unsigned":$versionMinor,
1260- ArrayRefParameter<"unsigned">:$warpsPerCTA ,
1269+ ArrayRefParameter<"unsigned">:$warpsPerCTA__ ,
12611270 "CTALayoutAttr":$CTALayout,
12621271 ArrayRefParameter<"unsigned">:$instrShape
12631272 );
0 commit comments