@@ -90,6 +90,7 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
9090 // we split the MMA into 4 sub-MMAs, each with a stride 4 x 32-bit along the
9191 // K dimension.
9292 llvm::SmallVector<unsigned > si;
93+ auto kIters = kWidth / (32 / bitwidth);
9394
9495 if (dot.getOpIdx () == 0 ) {
9596 // Original register layout:
@@ -106,11 +107,63 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
106107 // 2nd MMA: [[2, 3], [10, 11], [18, 19], [26, 27]]
107108 // 3rd MMA: [[4, 5], [12, 13], [20, 21], [28, 29]]
108109 // 4th MMA: [[6, 7], [14, 15], [22, 23], [30, 31]]
109- for (size_t kRep = 0 ; kRep < kWidth / numElemsPerVec; ++kRep )
110- for (size_t tile = 0 ; tile < 4 ; ++tile)
111- for (size_t e = 0 ; e < numElemsPerVec; ++e) {
112- si.push_back (kRep * numElemsPerVec + tile * kWidth + e);
113- }
110+ if (kIters <= repK) {
111+ for (size_t kRep = 0 ; kRep < kWidth / numElemsPerVec; ++kRep )
112+ for (size_t tile = 0 ; tile < 4 ; ++tile)
113+ for (size_t e = 0 ; e < numElemsPerVec; ++e) {
114+ si.push_back (kRep * numElemsPerVec + tile * kWidth + e);
115+ }
116+ } else {
117+ // Suppose kWidth=4 and type=fp32, so numElemsPerVec=1.
118+ // Each tile of the dot operand layout has a size of 16x32.
119+ // However, if the triton tensor size is 16x16, elements along the k
120+ // dimension are duplicated. Within each tile, each register
121+ // contains 2x8 elements arranged as follows:
122+ //
123+ // tile0/0 tile0/1
124+ // |<--kWidth=4-->| |<--kWidth-->|
125+ // |<-mmaWidth=2->|
126+ // [0, 1, 2, 3] [0, 1, 2, 3]
127+ // [4, 5, 6, 7] [4, 5, 6, 7]
128+ //
129+ // tile0/1 replicates the elements in tile0/0 along the k dimension.
130+ // For a tensor size of 32x32, the next tile on the m dimension is as
131+ // follows:
132+ //
133+ // tile1/0 tile1/1
134+ // |<--kWidth-->| |<--kWidth-->|
135+ // [8, 9, 10, 11], [8, 9, 10, 11]
136+ // [12, 13, 14, 15], [12, 13, 14, 15]
137+ //
138+ // Within a single tile, we can perform two MMAs, and the
139+ // resulting register layout for each MMA is as follows:
140+ //
141+ // 1st MMA: [0, 4, 1, 5]
142+ // 2nd MMA: [2, 6, 3, 7]
143+ // 3rd MMA: [8, 12, 9, 13]
144+ // 4th MMA: [10, 14, 11, 15]
145+ //
146+ // Additionally, we should reorder the elements by moving the duplicated
147+ // elements to the end. In the example above, we convert the order from
148+ // tile0/0, tile0/1, tile1/0, tile1/1 to tile0/0, tile1/0, tile0/1,
149+ // tile1/1, so that only the first two tiles will be used in the
150+ // computation.
151+ size_t elemsPerTile = 2 * 2 * kWidth ;
152+ size_t elemsPerMma = 2 * 2 * numElemsPerVec;
153+ size_t mmaWidth = kWidth / numElemsPerVec / 2 ;
154+ size_t repMma = elemsPerTile / (mmaWidth * elemsPerMma);
155+ for (size_t rep = 0 ; rep < repMma; ++rep)
156+ for (size_t tile = 0 ; tile < elems.size () / elemsPerTile; ++tile)
157+ for (size_t mmaKWidth = 0 ; mmaKWidth < mmaWidth; ++mmaKWidth)
158+ for (size_t kTile = 0 ; kTile < 2 ; ++kTile )
159+ for (size_t mTile = 0 ; mTile < 2 ; ++mTile )
160+ for (size_t e = 0 ; e < numElemsPerVec; ++e) {
161+ si.push_back (rep * mmaWidth * elemsPerMma +
162+ mmaKWidth * 2 * numElemsPerVec +
163+ tile * elemsPerTile + mTile * kWidth +
164+ kTile * numElemsPerVec + e);
165+ }
166+ }
114167 } else {
115168 // Original register layout:
116169 //
@@ -122,11 +175,36 @@ ValueTableV2 getValuesFromDotOperandLayoutStruct(
122175 // 2nd MMA: [[2, 3], [10, 11]]
123176 // 3rd MMA: [[4, 5], [12, 13]]
124177 // 4th MMA: [[6, 7], [14, 15]]
125- for (size_t kRep = 0 ; kRep < kWidth / numElemsPerVec; ++kRep )
126- for (size_t tile = 0 ; tile < 2 ; ++tile)
127- for (size_t e = 0 ; e < numElemsPerVec; ++e) {
128- si.push_back (kRep * numElemsPerVec + tile * kWidth + e);
129- }
178+ if (kIters <= repK) {
179+ for (size_t kRep = 0 ; kRep < kWidth / numElemsPerVec; ++kRep )
180+ for (size_t tile = 0 ; tile < 2 ; ++tile)
181+ for (size_t e = 0 ; e < numElemsPerVec; ++e) {
182+ si.push_back (kRep * numElemsPerVec + tile * kWidth + e);
183+ }
184+ } else {
185+ // Suppose kWidth=4 and type=fp32.
186+ // Original register layout:
187+ //
188+ // tile0/0 tile0/1
189+ // [0, 1, 2, 3]^T, [0, 1, 2, 3]^T
190+ //
191+ // Similar to the opIdx=0 situation, we should reorder the elements by
192+ // moving the duplicated elements to the end.
193+ size_t elemsPerTile = 2 * kWidth ;
194+ size_t elemsPerMma = 2 * numElemsPerVec;
195+ size_t mmaWidth = kWidth / numElemsPerVec / 2 ;
196+ size_t repMma = elemsPerTile / (mmaWidth * elemsPerMma);
197+ for (size_t rep = 0 ; rep < repMma; ++rep)
198+ for (size_t tile = 0 ; tile < elems.size () / elemsPerTile; ++tile)
199+ for (size_t mmaKWidth = 0 ; mmaKWidth < mmaWidth; ++mmaKWidth)
200+ for (size_t kTile = 0 ; kTile < 2 ; ++kTile )
201+ for (size_t e = 0 ; e < numElemsPerVec; ++e) {
202+ si.push_back (rep * mmaWidth * elemsPerMma +
203+ mmaKWidth * 2 * numElemsPerVec +
204+ tile * elemsPerTile + kTile * numElemsPerVec +
205+ e);
206+ }
207+ }
130208 }
131209
132210 auto step = si.size ();
0 commit comments