33// Licensed under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6-
76#include " iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
8-
97namespace mlir ::iree_compiler {
108
11- // Given an `expandShape` vector-of-vectors describing the mapping from source
12- // dimensions to expanded dimensions, returns the index of the first expanded
13- // dimension corresponding to the given source dimension index.
14- static int64_t
15- getExpandedDimFirstIdx (const TileSwizzle::ExpandShapeType &expandShape,
16- int64_t srcIndex ) {
17- int dstIndexFirst = 0 ;
18- for (int i = 0 ; i < srcIndex ; ++i) {
19- dstIndexFirst += expandShape[i].size ();
9+ using Kind = TileSwizzle::Dim::Kind;
10+
11+ // Returns the index of the first destination dimension corresponding to the
12+ // given source dimension `srcIdx`.
13+ static int64_t expandedDimIdx (const TileSwizzle::ExpandShapeType &expandShape,
14+ int srcIdx ) {
15+ int dstIdx = 0 ;
16+ for (int i = 0 ; i < srcIdx ; ++i) {
17+ dstIdx += expandShape[i].size ();
2018 }
21- return dstIndexFirst ;
19+ return dstIdx ;
2220}
2321
24- void unroll (TileSwizzle &swizzle, int srcIndex, int unrollFactor,
25- TileSwizzle::Dim::Kind kind) {
26- assert (unrollFactor > 1 );
27- int dstIndexFirst = getExpandedDimFirstIdx (swizzle.expandShape , srcIndex);
28- TileSwizzle::Dim unrollDim;
29- unrollDim.size = unrollFactor;
30- unrollDim.kind = kind;
22+ // Pushes `dim` to the front of `swizzle.expandShape[srcIdx]`, and updates
23+ // `swizzle.permutation` to make the new dimension outer-most among the dims in
24+ // `swizzle.expandShape[srcIdx]`.
25+ //
26+ // This can be used to unroll a kernel with kind = CrossIntrinsic,
27+ // or to expand a kernel to multiple subgroups with kind = CrossThread.
28+ //
29+ // Example:
30+ // Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
31+ // Input srcIdx = 1
32+ // Input dim.size = 4
33+ // -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
34+ //
35+ static void expand (TileSwizzle &swizzle, int srcIdx, TileSwizzle::Dim dim) {
36+ int dstIdx = expandedDimIdx (swizzle.expandShape , srcIdx);
3137 // The new unrolling dimension is inserted at the start of the expandShape
32- // dimensions group corresponding to srcIndex.
33- swizzle.expandShape [srcIndex].insert (swizzle.expandShape [srcIndex].begin (),
34- unrollDim);
38+ // dimensions group corresponding to srcIdx.
39+ swizzle.expandShape [srcIdx].insert (swizzle.expandShape [srcIdx].begin (), dim);
3540 // Since we are not interleaving here, generating side-by-side copies of the
3641 // original layout, the new unrolling dimension is the new outermost
3742 // dimension. Existing entries get shifted to make room for it.
3843 for (auto &p : swizzle.permutation ) {
39- p += (p >= dstIndexFirst );
44+ p += (p >= dstIdx );
4045 }
41- swizzle.permutation .insert (swizzle.permutation .begin (), dstIndexFirst );
46+ swizzle.permutation .insert (swizzle.permutation .begin (), dstIdx );
4247}
4348
44- void interleave (TileSwizzle &swizzle, int srcIndex,
45- int expandedDimIndexToInterleaveAt) {
46- // Compute which inner dimension to permute the current outer dimension into.
47- int dstIndexFirst = getExpandedDimFirstIdx (swizzle.expandShape , srcIndex);
48- int dstIndexToInterleaveAt = dstIndexFirst + expandedDimIndexToInterleaveAt;
49-
49+ // Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
50+ // move permutation[0], the outer-most dimension (which the unroll() function
51+ // created to be the unrolling dimension), to the inner dimension given by
52+ // `expandedIdx`.
53+ //
54+ // Example:
55+ // Input swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
56+ // Input srcIdx = 1
57+ // Input expandedIdx = 1
58+ // -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [2, 0, 1] }
59+ //
60+ static void interleave (TileSwizzle &swizzle, int srcIdx, int expandedIdx) {
61+ int dstIdx = expandedDimIdx (swizzle.expandShape , srcIdx) + expandedIdx;
5062 SmallVector<int64_t > outPermutation (swizzle.permutation .size ());
5163 // The leading dimension, permutation[0], gets moved inwards to the
52- // position that we just computed, dstIndexToInterleaveAt .
53- outPermutation[dstIndexToInterleaveAt ] = swizzle.permutation [0 ];
64+ // position that we just computed, dstIdx .
65+ outPermutation[dstIdx ] = swizzle.permutation [0 ];
5466 // Outer dimensions get shifted outwards to fill the gap.
55- for (int i = 0 ; i < dstIndexToInterleaveAt ; ++i) {
67+ for (int i = 0 ; i < dstIdx ; ++i) {
5668 outPermutation[i] = swizzle.permutation [i + 1 ];
5769 }
58- // Inner dimensions don't change. That is to say that we only interleave
59- // at `targetInterleavedElements` granularity, we don't swizzle further
60- // internally to that.
61- for (int i = dstIndexToInterleaveAt + 1 ; i < outPermutation.size (); ++i) {
70+ // Inner dimensions don't change.
71+ for (int i = dstIdx + 1 ; i < outPermutation.size (); ++i) {
6272 outPermutation[i] = swizzle.permutation [i];
6373 }
6474 swizzle.permutation = outPermutation;
6575}
6676
67- // Returns the permutation of indices that sorts `v` with the given comparator.
68- template <template <typename U> class Comparator , typename T>
69- static SmallVector<int64_t > getSortingPermutation (ArrayRef<T> v) {
70- using P = std::pair<int64_t , T>;
71- SmallVector<P> pairs;
72- pairs.reserve (v.size ());
73- for (auto [i, x] : llvm::enumerate (v)) {
74- pairs.push_back ({i, x});
75- }
76- std::sort (pairs.begin (), pairs.end (),
77- [](P p1, P p2) { return Comparator<T>{}(p1.second , p2.second ); });
78- SmallVector<int64_t > indices;
79- for (auto p : pairs) {
80- indices.push_back (p.first );
81- }
82- return indices;
83- }
84-
8577TileSwizzle getIntrinsicSwizzle (IREE::GPU::MMAIntrinsic intrinsic,
8678 IREE::GPU::MMAFragment fragment) {
8779 auto layout = IREE::GPU::getSingleSubgroupLayout (intrinsic, fragment);
@@ -95,57 +87,48 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
9587 std::swap (layout.element [0 ], layout.element [1 ]);
9688 }
9789
98- // Initially populate swizzle.expandShape with just the thread sizes, no
99- // shape expansion for now.
10090 TileSwizzle swizzle;
101- for (auto t : layout.thread ) {
102- TileSwizzle::Dim dim;
103- dim.size = t;
104- dim.kind = TileSwizzle::Dim::Kind::CrossThread; // Because `layout.thread`.
105- swizzle.expandShape .push_back ({dim});
106- }
107- // The layout strides decide the initial swizzle.permutation.
108- // Some WMMA intrinsics have tstrides=0 value. That always indicates an outer
109- // dimension, so overwrite 0 with a large value to get the right order.
110- SmallVector<int64_t , 2 > order = layout.tstrides ;
111- for (auto &val : order) {
112- val = (val == 0 ) ? INT64_MAX : val;
113- }
114- swizzle.permutation = getSortingPermutation<std::greater, int64_t >(order);
115- // Deal with any element size greater than 1 by inserting it innermost.
116- // Notice that this is similar to the unroll() function, just creating an
117- // inner dimension instead of an outer dimension.
91+ // There are two source dimensions, corresponding to the arrays in `layout`
92+ // all having size 2. Let's just guard that assumption with one assert here.
93+ assert (layout.thread .size () == 2 );
94+ swizzle.expandShape .resize (2 );
95+ // Expand the shape from inner-most to outer-most dimension, so that we can
96+ // simply use the `expand` helper function, which creates new outer dims.
97+ // `layout.element` dims are inner-most, so we add them first.
11898 for (auto [i, e] : llvm::enumerate (layout.element )) {
11999 if (e != 1 ) {
120- TileSwizzle::Dim dim;
121- dim.size = e;
122- dim.kind = TileSwizzle::Dim::Kind::Internal; // Because `layout.element`.
123- swizzle.expandShape [i].push_back (dim);
124- int newIndex = getExpandedDimFirstIdx (swizzle.expandShape , i + 1 ) - 1 ;
125- for (auto &p : swizzle.permutation ) {
126- p += (p >= newIndex);
127- }
128- swizzle.permutation .push_back (newIndex);
100+ expand (swizzle, i, {Kind::Internal, e});
129101 }
130102 }
131- // Deal with any outer size greater than 1 as just a call to unroll.
132- // Iterate over dims in reverse order because we are creating a new outermost
133- // dimension each time.
103+ // Next come `layout.thread` dims.
104+ for (auto [i, t] : llvm::enumerate (layout.thread )) {
105+ if (t != 1 ) {
106+ expand (swizzle, i, {Kind::CrossThread, t});
107+ }
108+ }
109+ // `layout.thread` dims are special in that they come with `layout.tstrides`
110+ // which may call for a swap in `swizzle.permutation`. We only need to worry
111+ // about that when both `layout.thread` sizes are greater than 1, so we didn't
112+ // skip them above. Note that this condition also implies that we don't need
113+ // to worry about `layout.tstrides == 0` which only happens with
114+ // `layout.thread == 1`.
115+ if (layout.thread [0 ] != 1 && layout.thread [1 ] != 1 &&
116+ layout.tstrides [0 ] > layout.tstrides [1 ]) {
117+ std::swap (swizzle.permutation [0 ], swizzle.permutation [1 ]);
118+ }
119+ // Finally come `layout.outer` dims, added last so they are outer-most.
134120 for (auto [i, o] : llvm::enumerate (layout.outer )) {
135121 if (o != 1 ) {
136- // `layout.outer` means additional Internal dimensions, just like
137- // `layout.element`, just swizzled outermost.
138- unroll (swizzle, i, o, TileSwizzle::Dim::Kind::Internal);
122+ expand (swizzle, i, {Kind::Internal, o});
139123 }
140124 }
141-
142125 return swizzle;
143126}
144127
145128static int getInnermostNonInternalDimIdx (
146129 const TileSwizzle::ExpandShapeDimVectorType &shape) {
147130 for (int idx = shape.size () - 1 ; idx >= 0 ; --idx) {
148- if (shape[idx].kind != TileSwizzle::Dim:: Kind::Internal) {
131+ if (shape[idx].kind != Kind::Internal) {
149132 return idx;
150133 }
151134 }
@@ -156,55 +139,54 @@ static int getInnermostNonInternalDimIdx(
156139TileSwizzle getSwizzle (IREE::GPU::DataTiledMMAAttr mma,
157140 IREE::GPU::MMAFragment fragment) {
158141 auto swizzle = getIntrinsicSwizzle (mma.getIntrinsic ().getValue (), fragment);
159- using Kind = TileSwizzle::Dim::Kind;
160142 switch (fragment) {
161143 case IREE::GPU::MMAFragment::Lhs:
162144 // A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
163145 // Unroll on K with interleaving, then on M.
164146 if (mma.getUnrollK () > 1 ) {
165- unroll (swizzle, 1 , mma.getUnrollK (), Kind::CrossIntrinsic );
147+ expand (swizzle, 1 , {Kind::CrossIntrinsic, mma.getUnrollK ()} );
166148 int interleavingIdx =
167149 getInnermostNonInternalDimIdx (swizzle.expandShape [1 ]);
168150 interleave (swizzle, 1 , interleavingIdx);
169151 }
170152 if (mma.getUnrollM () > 1 ) {
171- unroll (swizzle, 0 , mma.getUnrollM (), Kind::CrossIntrinsic );
153+ expand (swizzle, 0 , {Kind::CrossIntrinsic, mma.getUnrollM ()} );
172154 }
173155 if (mma.getUnrollMToSubgroups () > 1 ) {
174- unroll (swizzle, 0 , mma.getUnrollMToSubgroups (), Kind::CrossThread );
156+ expand (swizzle, 0 , {Kind::CrossThread, mma.getUnrollMToSubgroups ()} );
175157 }
176158 break ;
177159 case IREE::GPU::MMAFragment::Rhs:
178160 // B-matrix (RHS). Since the pack ops already took care of transposing B,
179161 // source dimensions are N (index 0) and K (index 1).
180162 // Unroll on K with interleaving, then on N.
181163 if (mma.getUnrollK () > 1 ) {
182- unroll (swizzle, 1 , mma.getUnrollK (), Kind::CrossIntrinsic );
164+ expand (swizzle, 1 , {Kind::CrossIntrinsic, mma.getUnrollK ()} );
183165 int interleavingIdx =
184166 getInnermostNonInternalDimIdx (swizzle.expandShape [1 ]);
185167 interleave (swizzle, 1 , interleavingIdx);
186168 }
187169 if (mma.getUnrollN () > 1 ) {
188- unroll (swizzle, 0 , mma.getUnrollN (), Kind::CrossIntrinsic );
170+ expand (swizzle, 0 , {Kind::CrossIntrinsic, mma.getUnrollN ()} );
189171 }
190172 if (mma.getUnrollNToSubgroups () > 1 ) {
191- unroll (swizzle, 0 , mma.getUnrollNToSubgroups (), Kind::CrossThread );
173+ expand (swizzle, 0 , {Kind::CrossThread, mma.getUnrollNToSubgroups ()} );
192174 }
193175 break ;
194176 case IREE::GPU::MMAFragment::Acc:
195177 // C-matrix (accumulator). Source dimensions are M (index 0) and N (index
196178 // 1). Unroll on N, then on M.
197179 if (mma.getUnrollN () > 1 ) {
198- unroll (swizzle, 1 , mma.getUnrollN (), Kind::CrossIntrinsic );
180+ expand (swizzle, 1 , {Kind::CrossIntrinsic, mma.getUnrollN ()} );
199181 }
200182 if (mma.getUnrollNToSubgroups () > 1 ) {
201- unroll (swizzle, 1 , mma.getUnrollNToSubgroups (), Kind::CrossThread );
183+ expand (swizzle, 1 , {Kind::CrossThread, mma.getUnrollNToSubgroups ()} );
202184 }
203185 if (mma.getUnrollM () > 1 ) {
204- unroll (swizzle, 0 , mma.getUnrollM (), Kind::CrossIntrinsic );
186+ expand (swizzle, 0 , {Kind::CrossIntrinsic, mma.getUnrollM ()} );
205187 }
206188 if (mma.getUnrollMToSubgroups () > 1 ) {
207- unroll (swizzle, 0 , mma.getUnrollMToSubgroups (), Kind::CrossThread );
189+ expand (swizzle, 0 , {Kind::CrossThread, mma.getUnrollMToSubgroups ()} );
208190 }
209191 break ;
210192 }
0 commit comments