-#include "triton/Analysis/Utility.h"
 #include "intel/include/Analysis/Utility.h"
+
+#include "llvm/ADT/TypeSwitch.h"
+
+#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
+
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"

 namespace mlir::triton::gpu::intel {
+namespace {
+constexpr inline unsigned minSubGroupTransposeWidth = 8;
+
+bool canTypeBeConvertedForSubGroupTranspose(Type type) {
+  return TypeSwitch<Type, bool>(type)
+      .Case([](FloatType floatTy) {
+        // Support via bitcasting to integer type.
+        return isValidElementTypeForSubGroupTranspose(
+            IntegerType::get(floatTy.getContext(), floatTy.getWidth()));
+      })
+      .Case([](IntegerType intTy) {
+        // Support via extending to supported type.
+        return isValidElementTypeForSubGroupTranspose(intTy) ||
+               intTy.getWidth() < minSubGroupTransposeWidth;
+      })
+      .Case([](PointerType) {
+        // Support via ptrtoint.
+        return true;
+      })
+      .Default(false);
+}
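// For illustration (a sketch of what the switch above accepts): f16/f32/f64
// qualify via a bitcast to i16/i32/i64, i1 qualifies because its width is
// below minSubGroupTransposeWidth and can be extended, pointers qualify via
// ptrtoint, and i8/i16/i32/i64 are valid element types as-is.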
+
+// Return a vector such as:
+// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [laneSize, 0], ...,
+//  [registerSize / 2, 0]],
+// i.e., mapping registers to lanes till laneSize and performing an ID
+// conversion afterwards.
+std::vector<std::vector<int32_t>>
+buildSubGroupTransposeRegisterBases(int32_t registerSize, int32_t laneSize) {
+  std::vector<std::vector<int32_t>> bases;
+  std::vector<int32_t> curr(2);
+  for (int32_t i = 1; i < laneSize; i *= 2) {
+    curr[1] = i;
+    bases.push_back(curr);
+  }
+  curr[1] = 0;
+  for (int32_t i = laneSize; i < registerSize; i *= 2) {
+    curr[0] = i;
+    bases.push_back(curr);
+  }
+  return bases;
+}
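// For illustration, assuming a sub-group of 16 lanes and 32 registers per lane
// (sizes picked for this sketch only):
//   buildSubGroupTransposeRegisterBases(/*registerSize=*/32, /*laneSize=*/16)
//       == {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {16, 0}}
// The first four bases send registers 1, 2, 4 and 8 to lanes; the last one
// keeps register 16 in the register dimension (the identity part).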
+
+// Return a vector such as:
+// [[0, 1], [0, 2], [0, 4], ..., [0, laneSize / 2], [1, 0], ...,
+//  [registerSize / (2 * laneSize), 0]],
+// i.e., mapping registers to lanes till laneSize and repeating the pattern
+// afterwards.
+std::vector<std::vector<int32_t>>
+buildSubGroupShuffleRegisterBases(int32_t registerSize, int32_t laneSize) {
+  std::vector<std::vector<int32_t>> bases;
+  std::vector<int32_t> curr(2);
+  for (int32_t i = 1; i < laneSize; i *= 2) {
+    curr[1] = i;
+    bases.push_back(curr);
+  }
+  curr[1] = 0;
+  for (int32_t i = laneSize, val = 1; i < registerSize; i *= 2, val *= 2) {
+    curr[0] = val;
+    bases.push_back(curr);
+  }
+  return bases;
+}
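// For illustration, assuming 16 lanes and 64 registers per lane (sizes picked
// for this sketch only):
//   buildSubGroupShuffleRegisterBases(/*registerSize=*/64, /*laneSize=*/16)
//       == {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {1, 0}, {2, 0}}
// Registers 1, 2, 4 and 8 map to lanes; registers 16 and 32 restart the
// register pattern at 1 and 2 ("repeating the pattern afterwards").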
+
+// Return a vector such as:
+// [[1, 0], [2, 0], [4, 0], ..., [laneSize / 2, 0]],
+// i.e., mapping lanes to registers.
+std::vector<std::vector<int32_t>>
+buildSubGroupTransposeLaneBases(int32_t laneSize) {
+  std::vector<std::vector<int32_t>> bases;
+  std::vector<int32_t> curr(2);
+  for (int32_t i = 1; i < laneSize; i *= 2) {
+    curr[0] = i;
+    bases.push_back(curr);
+  }
+  return bases;
+}
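// For illustration, with an assumed sub-group of 16 lanes:
//   buildSubGroupTransposeLaneBases(/*laneSize=*/16)
//       == {{1, 0}, {2, 0}, {4, 0}, {8, 0}}
// Every lane basis lands in the register dimension, which together with the
// register bases above describes a register<->lane transpose.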
+
+} // namespace

 bool isDpasToDotShortcut(RankedTensorType dpasTy, RankedTensorType dotTy) {
   auto dpasLayout = dyn_cast<DpasEncodingAttr>(dpasTy.getEncoding());
@@ -24,4 +107,120 @@ bool isDpasToDotShortcut(RankedTensorType dpasTy, RankedTensorType dotTy) {
   return false;
 }

+bool cvtIsSubGroupShuffle(RankedTensorType srcTy, RankedTensorType dstTy) {
+  MLIRContext *ctx = srcTy.getContext();
+  StringAttr kRegister = str_attr("register");
+  StringAttr kLane = str_attr("lane");
+  StringAttr kWarp = str_attr("warp");
+  StringAttr kBlock = str_attr("block");
+
+  std::optional<LinearLayout> srcLayout =
+      toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+  if (!srcLayout)
+    return false;
+
+  std::optional<LinearLayout> dstLayout =
+      toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+  if (!dstLayout)
+    return false;
+
+  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
+  std::optional<LinearLayout> conversion = comp.quotient(kBlock);
+  if (!conversion)
+    return false;
+  conversion = conversion->quotient(kWarp);
+  if (!conversion)
+    return false;
+
+  // TODO: Support more kinds of shuffles.
+  // Expected conversion is:
+  // - register=1 -> (0, 1)
+  //   ...
+  // - register=2**i -> (0, 2**i)
+  //   ...
+  // - register=2**M -> (0, 2**M)
+  //   ...
+  // - register=2**k -> (2**(k-M), 0)
+  //   ...
+  // - register=2**N -> (2**(N-M), 0)
+  // - lane=1 -> (0, 0)
+  //   ...
+  // - lane=2**j -> (0, 0)
+  //   ...
+  // - lane=2**M -> (0, 0)
+  // where out dims are: [register (size 2**(N - M)), lane (size 2**(M + 1))]
+  //
+  // With N >= M.
+  int32_t registerInDimSize = conversion->getInDimSize(kRegister);
+  int32_t laneOutDimSize = conversion->getOutDimSize(kLane);
+  return conversion->sublayoutIsZero({kLane}, {kRegister, kLane}) &&
+         conversion->getBases().lookup(kRegister) ==
+             buildSubGroupShuffleRegisterBases(registerInDimSize,
+                                               laneOutDimSize);
+}
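// A minimal usage sketch (the ConvertLayoutOp handling and the helper named
// below are assumptions for illustration, not defined in this change):
//   if (cvtIsSubGroupShuffle(cvtOp.getSrc().getType(), cvtOp.getType()))
//     return lowerViaSubGroupShuffles(cvtOp); // hypothetical lowering helper
// i.e., the predicate lets a conversion pattern pick a shuffle-based lowering
// instead of a more general one through shared local memory.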
+
+bool isValidElementTypeForSubGroupTranspose(Type type) {
+  return TypeSwitch<Type, bool>(type)
+      .Case([](IntegerType intTy) {
+        unsigned width = intTy.getWidth();
+        return width == 8 || width == 16 || width == 32 || width == 64;
+      })
+      .Default(false);
+}
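// For example, i8, i16, i32, and i64 are accepted here, while i1 and float
// types are only transposable after the widening/bitcasting handled by
// canTypeBeConvertedForSubGroupTranspose above.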
+
+bool cvtIsSubGroupTranspose(RankedTensorType srcTy, RankedTensorType dstTy) {
+  if (!canTypeBeConvertedForSubGroupTranspose(srcTy.getElementType()))
+    return false;
+
+  MLIRContext *ctx = srcTy.getContext();
+  StringAttr kRegister = str_attr("register");
+  StringAttr kLane = str_attr("lane");
+  StringAttr kWarp = str_attr("warp");
+  StringAttr kBlock = str_attr("block");
+
+  std::optional<LinearLayout> srcLayout =
+      toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+  if (!srcLayout)
+    return false;
+
+  std::optional<LinearLayout> dstLayout =
+      toLinearLayout(dstTy.getShape(), dstTy.getEncoding());
+  if (!dstLayout)
+    return false;
+
+  LinearLayout comp = dstLayout->invertAndCompose(*srcLayout);
+  std::optional<LinearLayout> conversion = comp.quotient(kBlock);
+  if (!conversion)
+    return false;
+  conversion = conversion->quotient(kWarp);
+  if (!conversion)
+    return false;
+
+  // Expected conversion is:
+  // - register=1 -> (0, 1)
+  //   ...
+  // - register=2**i -> (0, 2**i)
+  //   ...
+  // - register=2**M -> (0, 2**M)
+  //   ...
+  // - register=2**k -> (2**k, 0)
+  //   ...
+  // - register=2**N -> (2**N, 0)
+  // - lane=1 -> (1, 0)
+  //   ...
+  // - lane=2**j -> (2**j, 0)
+  //   ...
+  // - lane=2**M -> (2**M, 0)
+  // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
+  //
+  // With N >= M.
+  int32_t registerInDimSize = conversion->getInDimSize(kRegister);
+  int32_t laneInDimSize = conversion->getInDimSize(kLane);
+  return conversion->getBases().lookup(kRegister) ==
+             buildSubGroupTransposeRegisterBases(registerInDimSize,
+                                                 laneInDimSize) &&
+         conversion->getBases().lookup(kLane) ==
+             buildSubGroupTransposeLaneBases(laneInDimSize);
+}
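// For illustration (sizes assumed for the example): with a register input
// dimension of 32 and a lane input dimension of 16, the conversion is accepted
// only if its register bases equal
//   buildSubGroupTransposeRegisterBases(32, 16)
//       == {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {16, 0}}
// and its lane bases equal
//   buildSubGroupTransposeLaneBases(16) == {{1, 0}, {2, 0}, {4, 0}, {8, 0}},
// i.e., the conversion is exactly a register<->lane transpose within each
// sub-group.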
+
 } // namespace mlir::triton::gpu::intel