@@ -71,6 +71,28 @@ buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
7171 return bases;
7272}
7373
// Build the register bases expected for a contiguous sub-group shuffle.
// Each basis is a pair [register, lane]. For power-of-two sizes with
// registerSize >= laneSize the result is:
//   [[1, 0], [2, 0], ..., [registerSize / (laneSize * 2), 0],
//    [0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2]]
// i.e., the low input register bits (those below registerSize / laneSize)
// stay in the register dimension and the remaining input register bits are
// mapped, in order, onto the lane dimension.
std::vector<std::vector<int32_t>>
buildContiguousSubGroupShuffleRegisterBases(int32_t registerSize,
                                            int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;
  const int32_t registerLimit = registerSize / laneSize;

  // Input register bits that remain register bits: [bit, 0].
  int32_t inputBit = 1;
  while (inputBit < registerLimit) {
    result.push_back({inputBit, 0});
    inputBit *= 2;
  }

  // Remaining input register bits become consecutive lane bits: [0, laneBit].
  int32_t laneBit = 1;
  while (inputBit < registerSize) {
    result.push_back({0, laneBit});
    inputBit *= 2;
    laneBit *= 2;
  }
  return result;
}
95+
7496// Return a vector such as:
7597// [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
7698// i.e., mapping lanes to registers.
@@ -85,6 +107,46 @@ buildSubGroupTransposeLaneBases(int32_t laneSize) {
85107 return bases;
86108}
87109
// Build the register bases expected for a contiguous sub-group transpose.
// Each basis is a pair [register, lane]. For power-of-two sizes with
// registerSize >= laneSize the result is:
//   [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2],
//    [1, 0], [2, 0], ..., [registerSize / (laneSize * 2), 0]]
// i.e., input register bits below laneSize are mapped onto the lane
// dimension, and the remaining input register bits are renumbered from 1 in
// the register dimension.
std::vector<std::vector<int32_t>>
buildContiguousSubGroupTransposeRegisterBases(int32_t registerSize,
                                              int32_t laneSize) {
  std::vector<std::vector<int32_t>> result;

  // Input register bits that land on lanes: [0, bit].
  int32_t inputBit = 1;
  while (inputBit < laneSize) {
    result.push_back({0, inputBit});
    inputBit *= 2;
  }

  // Remaining input register bits stay registers, restarting at 1: [bit, 0].
  int32_t registerBit = 1;
  while (inputBit < registerSize) {
    result.push_back({registerBit, 0});
    inputBit *= 2;
    registerBit *= 2;
  }
  return result;
}
132+
// Build the lane bases expected for a contiguous sub-group transpose.
// Each basis is a pair [register, lane]. For power-of-two sizes with
// registerSize >= laneSize > 0 the result is:
//   [[registerSize / laneSize, 0], [registerSize / laneSize * 2, 0], ...,
//    [registerSize / 2, 0]]
// i.e., every input lane bit is mapped onto the register dimension, starting
// at registerSize / laneSize and doubling for each subsequent lane bit.
//
// An empty vector is returned when the pattern is not defined, i.e. when
// laneSize <= 0 or registerSize < laneSize.
std::vector<std::vector<int32_t>>
buildContiguousSubGroupTransposeLaneBases(int32_t registerSize,
                                          int32_t laneSize) {
  std::vector<std::vector<int32_t>> bases;
  // Without this guard, laneSize == 0 divides by zero and registerSize <
  // laneSize makes the loop start at 0, where doubling never makes progress
  // and the loop would never terminate.
  if (laneSize <= 0 || registerSize < laneSize)
    return bases;
  for (int32_t i = registerSize / laneSize; i < registerSize; i *= 2)
    bases.push_back({i, 0});
  return bases;
}
149+
88150} // namespace
89151
90152bool isDpasToDotShortcut (RankedTensorType dpasTy, RankedTensorType dotTy) {
@@ -159,6 +221,59 @@ bool cvtIsSubGroupShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
159221 laneOutDimSize);
160222}
161223
// Return true when converting srcTy to dstTy matches the "contiguous
// sub-group shuffle" pattern: after quotienting the composed layout by the
// block and warp dimensions, the lane input sub-layout must be zero and the
// register bases must equal the ones built by
// buildContiguousSubGroupShuffleRegisterBases.
// Returns false when either type has no linear layout or when one of the
// quotient operations yields no result.
224+ bool cvtIsContiguousSubGroupShuffle (RankedTensorType srcTy,
225+ RankedTensorType dstTy) {
// NOTE(review): ctx is not referenced explicitly below; presumably str_attr
// picks it up implicitly - confirm.
226+ MLIRContext *ctx = srcTy.getContext ();
227+ StringAttr kRegister = str_attr (" register" );
228+ StringAttr kLane = str_attr (" lane" );
229+ StringAttr kWarp = str_attr (" warp" );
230+ StringAttr kBlock = str_attr (" block" );
231+
// Bail out if either type cannot be expressed as a linear layout.
232+ std::optional<LinearLayout> srcLayout =
233+ toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
234+ if (!srcLayout)
235+ return false ;
236+
237+ std::optional<LinearLayout> dstLayout =
238+ toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
239+ if (!dstLayout)
240+ return false ;
241+
// Compose the two layouts, then quotient out the block and warp dimensions
// so that only the register and lane dimensions are inspected below.
242+ LinearLayout comp = dstLayout->invertAndCompose (*srcLayout);
243+ std::optional<LinearLayout> conversion = comp.quotient (kBlock );
244+ if (!conversion)
245+ return false ;
246+ conversion = conversion->quotient (kWarp );
247+ if (!conversion)
248+ return false ;
249+
250+ // TODO: Support more kinds of shuffles.
251+ // Expected conversion (see buildContiguousSubGroupShuffleRegisterBases) is:
252+ // - register=1 -> (1, 0)
253+ // ...
254+ // - register=2**i -> (2**i, 0)
255+ // ...
256+ // - register=2**(N-M) -> (0, 1)
257+ // ...
258+ // - register=2**k -> (0, 2**(k-(N-M)))
259+ // ...
260+ // - register=2**N -> (0, 2**M)
261+ // - lane=1 -> (0, 0)
262+ // ...
263+ // - lane=2**j -> (0, 0)
264+ // ...
265+ // - lane=2**M -> (0, 0)
266+ // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
267+ //
268+ // With N >= M.
269+ int32_t registerInDimSize = conversion->getInDimSize (kRegister );
270+ int32_t laneOutDimSize = conversion->getOutDimSize (kLane );
// Both conditions must hold: the lane input contributes nothing to the
// register/lane outputs, and the register bases match the expected pattern.
271+ return conversion->sublayoutIsZero ({kLane }, {kRegister , kLane }) &&
272+ conversion->getBases ().lookup (kRegister ) ==
273+ buildContiguousSubGroupShuffleRegisterBases (registerInDimSize,
274+ laneOutDimSize);
275+ }
276+
162277bool isValidElementTypeForSubGroupTranspose (Type type) {
163278 return TypeSwitch<Type, bool >(type)
164279 .Case ([](IntegerType intTy) {
@@ -196,16 +311,14 @@ bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy) {
196311 if (!conversion)
197312 return false ;
198313
199- llvm::errs () << conversion << " \n " ;
200-
201314 // Expected conversion is:
202315 // - register=1 -> (0, 1)
203316 // ...
204317 // - register=2**i -> (0, 2**i)
205318 // ...
206319 // - register=M -> (0, 2**M)
207320 // ...
208- // - register=2**k -> (2**k , 0)
321+ // - register=2**k -> (, 0)
209322 // ...
210323 // - register=N -> (2**N, 0)
211324 // - lane=1 -> (0, 1)
@@ -225,6 +338,63 @@ bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy) {
225338 buildSubGroupTransposeLaneBases (laneInDimSize);
226339}
227340
// Return true when converting srcTy to dstTy matches the "contiguous
// sub-group transpose" pattern: the source element type must pass
// canTypeBeConvertedForSubGroupTranspose and, after quotienting the composed
// layout by the block and warp dimensions, the register and lane bases must
// equal the ones built by buildContiguousSubGroupTransposeRegisterBases and
// buildContiguousSubGroupTransposeLaneBases respectively.
// Returns false when either type has no linear layout or when one of the
// quotient operations yields no result.
341+ bool cvtIsContiguousSubGroupTranspose (RankedTensorType srcTy,
342+ RankedTensorType dstTy) {
// Only the source element type is checked here.
343+ if (!canTypeBeConvertedForSubGroupTranspose (srcTy.getElementType ()))
344+ return false ;
345+
// NOTE(review): ctx is not referenced explicitly below; presumably str_attr
// picks it up implicitly - confirm.
346+ MLIRContext *ctx = srcTy.getContext ();
347+ StringAttr kRegister = str_attr (" register" );
348+ StringAttr kLane = str_attr (" lane" );
349+ StringAttr kWarp = str_attr (" warp" );
350+ StringAttr kBlock = str_attr (" block" );
351+
// Bail out if either type cannot be expressed as a linear layout.
352+ std::optional<LinearLayout> srcLayout =
353+ toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
354+ if (!srcLayout)
355+ return false ;
356+
357+ std::optional<LinearLayout> dstLayout =
358+ toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
359+ if (!dstLayout)
360+ return false ;
361+
// Compose the two layouts, then quotient out the block and warp dimensions
// so that only the register and lane dimensions are inspected below.
362+ LinearLayout comp = dstLayout->invertAndCompose (*srcLayout);
363+ std::optional<LinearLayout> conversion = comp.quotient (kBlock );
364+ if (!conversion)
365+ return false ;
366+ conversion = conversion->quotient (kWarp );
367+ if (!conversion)
368+ return false ;
369+
370+ // Expected conversion is:
371+ // - register=1 -> (0, 1)
372+ // ...
373+ // - register=2**i -> (0, 2**i)
374+ // ...
375+ // - register=2**M -> (0, 2**M)
376+ // ...
377+ // - register=2**k -> (1, 0), with k = M + 1
378+ // ...
379+ // - register=2**N -> (2**(N-k), 0)
380+ // - lane=1 -> (2**(N-k+1), 0)
381+ // ...
382+ // - lane=2**j -> (2**(N-k+1+j), 0)
383+ // ...
384+ // - lane=2**M -> (2**N, 0)
385+ // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
386+ //
387+ // With N >= M.
388+ int32_t registerInDimSize = conversion->getInDimSize (kRegister );
389+ int32_t laneInDimSize = conversion->getInDimSize (kLane );
// The register comparison is evaluated first; the lane bases are only built
// and compared when the register bases already match.
390+ return conversion->getBases ().lookup (kRegister ) ==
391+ buildContiguousSubGroupTransposeRegisterBases (registerInDimSize,
392+ laneInDimSize) &&
393+ conversion->getBases ().lookup (kLane ) ==
394+ buildContiguousSubGroupTransposeLaneBases (registerInDimSize,
395+ laneInDimSize);
396+ }
397+
228398bool cvtIsUnbroadcast (RankedTensorType srcTy, RankedTensorType dstTy) {
229399 MLIRContext *ctx = srcTy.getContext ();
230400 StringAttr kRegister = str_attr (" register" );
0 commit comments