Skip to content

Commit ab35e1b

Browse files
authored
Simplify GPUTileSwizzleUtils and avoid creating unit dims. (iree-org#19105)
In `getIntrinsicSwizzle`, we had a slightly roundabout way of constructing the swizzle from the `SingleSubgroupLayout`. We started from the `thread` dims, which we used unconditionally even if they had the value 1, leading to unit dims; and then we inserted the `element` dims *on the inside*, which required custom manipulation of the `swizzle` field. Now we just start from the `element` dims and work our way outwards from there, which means we can reuse the same helper that used to be named `unroll` and that we rename here to `expand` in preparation for iree-org#19102, and which we also move to be a `static` helper since it's no longer used outside of this file. --------- Signed-off-by: Benoit Jacob <[email protected]>
1 parent b08ea12 commit ab35e1b

File tree

5 files changed

+93
-137
lines changed

5 files changed

+93
-137
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_materialize_encoding_gfx1100.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ func.func @matmul_lowering_WMMA_F32_16x16x16_F16() {
5050
// CHECK-DAG: %[[LHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(0)
5151
// CHECK-DAG: %[[RHS_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(1)
5252
// CHECK-DAG: %[[ACC_BINDING:.+]] = hal.interface.binding.subspan {{.+}} binding(2)
53-
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x1x16x16xf16>
54-
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x1x16x16xf16>
53+
// CHECK-DAG: %[[LHS:.+]] = flow.dispatch.tensor.load %[[LHS_BINDING]]{{.+}} -> tensor<?x?x4x16x16xf16>
54+
// CHECK-DAG: %[[RHS:.+]] = flow.dispatch.tensor.load %[[RHS_BINDING]]{{.+}} -> tensor<?x?x4x16x16xf16>
5555
// CHECK-DAG: %[[ACC:.+]] = flow.dispatch.tensor.load %[[ACC_BINDING]]{{.+}} -> tensor<?x?x4x4x8x2x16xf32>
5656
// CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS]], %[[RHS]], %[[ACC]]
5757
// CHECK-SAME: indexing_maps = [#[[MAP0]], #[[MAP1]], #[[MAP2]]],

compiler/src/iree/compiler/Codegen/Common/TileSwizzle.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ struct TileSwizzle {
4949

5050
// The size of the dimension.
5151
int16_t size = 0;
52+
53+
// Support constructing from any size type.
54+
template <typename T>
55+
Dim(Kind kind, T size) : kind(kind), size(size) {}
5256
};
5357

5458
using ExpandShapeDimVectorType = llvm::SmallVector<Dim, 4>;

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp

Lines changed: 84 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -3,85 +3,77 @@
33
// Licensed under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6-
76
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
8-
97
namespace mlir::iree_compiler {
108

11-
// Given an `expandShape` vector-of-vectors describing the mapping from source
12-
// dimensions to expanded dimensions, returns the index of the first expanded
13-
// dimension corresponding to the given source dimension index.
14-
static int64_t
15-
getExpandedDimFirstIdx(const TileSwizzle::ExpandShapeType &expandShape,
16-
int64_t srcIndex) {
17-
int dstIndexFirst = 0;
18-
for (int i = 0; i < srcIndex; ++i) {
19-
dstIndexFirst += expandShape[i].size();
9+
using Kind = TileSwizzle::Dim::Kind;
10+
11+
// Returns the index of the first destination dimension corresponding to the
12+
// given source dimension `srcIdx`.
13+
static int64_t expandedDimIdx(const TileSwizzle::ExpandShapeType &expandShape,
14+
int srcIdx) {
15+
int dstIdx = 0;
16+
for (int i = 0; i < srcIdx; ++i) {
17+
dstIdx += expandShape[i].size();
2018
}
21-
return dstIndexFirst;
19+
return dstIdx;
2220
}
2321

24-
void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
25-
TileSwizzle::Dim::Kind kind) {
26-
assert(unrollFactor > 1);
27-
int dstIndexFirst = getExpandedDimFirstIdx(swizzle.expandShape, srcIndex);
28-
TileSwizzle::Dim unrollDim;
29-
unrollDim.size = unrollFactor;
30-
unrollDim.kind = kind;
22+
// Pushes `dim` to the front of `swizzle.expandShape[srcIdx]`, and updates
23+
// `swizzle.permutation` to make the new dimension outer-most among the dims in
24+
// `swizzle.expandShape[srcIdx]`.
25+
//
26+
// This can be used to unroll a kernel with kind = CrossIntrinsic,
27+
// or to expand a kernel to multiple subgroups with kind = CrossThread.
28+
//
29+
// Example:
30+
// Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
31+
// Input srcIdx = 1
32+
// Input dim.size = 4
33+
// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
34+
//
35+
static void expand(TileSwizzle &swizzle, int srcIdx, TileSwizzle::Dim dim) {
36+
int dstIdx = expandedDimIdx(swizzle.expandShape, srcIdx);
3137
// The new unrolling dimension is inserted at the start of the expandShape
32-
// dimensions group corresponding to srcIndex.
33-
swizzle.expandShape[srcIndex].insert(swizzle.expandShape[srcIndex].begin(),
34-
unrollDim);
38+
// dimensions group corresponding to srcIdx.
39+
swizzle.expandShape[srcIdx].insert(swizzle.expandShape[srcIdx].begin(), dim);
3540
// Since we are not interleaving here, generating side-by-side copies of the
3641
// original layout, the new unrolling dimension is the new outermost
3742
// dimension. Existing entries get shifted to make room for it.
3843
for (auto &p : swizzle.permutation) {
39-
p += (p >= dstIndexFirst);
44+
p += (p >= dstIdx);
4045
}
41-
swizzle.permutation.insert(swizzle.permutation.begin(), dstIndexFirst);
46+
swizzle.permutation.insert(swizzle.permutation.begin(), dstIdx);
4247
}
4348

44-
void interleave(TileSwizzle &swizzle, int srcIndex,
45-
int expandedDimIndexToInterleaveAt) {
46-
// Compute which inner dimension to permute the current outer dimension into.
47-
int dstIndexFirst = getExpandedDimFirstIdx(swizzle.expandShape, srcIndex);
48-
int dstIndexToInterleaveAt = dstIndexFirst + expandedDimIndexToInterleaveAt;
49-
49+
// Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
50+
// move permutation[0], the outer-most dimension (which the unroll() function
51+
// created to be the unrolling dimension), to the inner dimension given by
52+
// `expandedIdx`.
53+
//
54+
// Example:
55+
// Input swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
56+
// Input srcIdx = 1
57+
// Input expandedIdx = 1
58+
// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [2, 0, 1] }
59+
//
60+
static void interleave(TileSwizzle &swizzle, int srcIdx, int expandedIdx) {
61+
int dstIdx = expandedDimIdx(swizzle.expandShape, srcIdx) + expandedIdx;
5062
SmallVector<int64_t> outPermutation(swizzle.permutation.size());
5163
// The leading dimension, permutation[0], gets moved inwards to the
52-
// position that we just computed, dstIndexToInterleaveAt.
53-
outPermutation[dstIndexToInterleaveAt] = swizzle.permutation[0];
64+
// position that we just computed, dstIdx.
65+
outPermutation[dstIdx] = swizzle.permutation[0];
5466
// Outer dimensions get shifted outwards to fill the gap.
55-
for (int i = 0; i < dstIndexToInterleaveAt; ++i) {
67+
for (int i = 0; i < dstIdx; ++i) {
5668
outPermutation[i] = swizzle.permutation[i + 1];
5769
}
58-
// Inner dimensions don't change. That is to say that we only interleave
59-
// at `targetInterleavedElements` granularity, we don't swizzle further
60-
// internally to that.
61-
for (int i = dstIndexToInterleaveAt + 1; i < outPermutation.size(); ++i) {
70+
// Inner dimensions don't change.
71+
for (int i = dstIdx + 1; i < outPermutation.size(); ++i) {
6272
outPermutation[i] = swizzle.permutation[i];
6373
}
6474
swizzle.permutation = outPermutation;
6575
}
6676

67-
// Returns the permutation of indices that sorts `v` with the given comparator.
68-
template <template <typename U> class Comparator, typename T>
69-
static SmallVector<int64_t> getSortingPermutation(ArrayRef<T> v) {
70-
using P = std::pair<int64_t, T>;
71-
SmallVector<P> pairs;
72-
pairs.reserve(v.size());
73-
for (auto [i, x] : llvm::enumerate(v)) {
74-
pairs.push_back({i, x});
75-
}
76-
std::sort(pairs.begin(), pairs.end(),
77-
[](P p1, P p2) { return Comparator<T>{}(p1.second, p2.second); });
78-
SmallVector<int64_t> indices;
79-
for (auto p : pairs) {
80-
indices.push_back(p.first);
81-
}
82-
return indices;
83-
}
84-
8577
TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
8678
IREE::GPU::MMAFragment fragment) {
8779
auto layout = IREE::GPU::getSingleSubgroupLayout(intrinsic, fragment);
@@ -95,57 +87,48 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
9587
std::swap(layout.element[0], layout.element[1]);
9688
}
9789

98-
// Initially populate swizzle.expandShape with just the thread sizes, no
99-
// shape expansion for now.
10090
TileSwizzle swizzle;
101-
for (auto t : layout.thread) {
102-
TileSwizzle::Dim dim;
103-
dim.size = t;
104-
dim.kind = TileSwizzle::Dim::Kind::CrossThread; // Because `layout.thread`.
105-
swizzle.expandShape.push_back({dim});
106-
}
107-
// The layout strides decide the initial swizzle.permutation.
108-
// Some WMMA intrinsics have tstrides=0 value. That always indicates an outer
109-
// dimension, so overwrite 0 with a large value to get the right order.
110-
SmallVector<int64_t, 2> order = layout.tstrides;
111-
for (auto &val : order) {
112-
val = (val == 0) ? INT64_MAX : val;
113-
}
114-
swizzle.permutation = getSortingPermutation<std::greater, int64_t>(order);
115-
// Deal with any element size greater than 1 by inserting it innermost.
116-
// Notice that this is similar to the unroll() function, just creating an
117-
// inner dimension instead of an outer dimension.
91+
// There are two source dimensions, corresponding to the arrays in `layout`
92+
// all having size 2. Let's just guard that assumption with one assert here.
93+
assert(layout.thread.size() == 2);
94+
swizzle.expandShape.resize(2);
95+
// Expand the shape from inner-most to outer-most dimension, so that we can
96+
// simply use the `expand` helper function, which creates new outer dims.
97+
// `layout.element` dims are inner-most, so we add them first.
11898
for (auto [i, e] : llvm::enumerate(layout.element)) {
11999
if (e != 1) {
120-
TileSwizzle::Dim dim;
121-
dim.size = e;
122-
dim.kind = TileSwizzle::Dim::Kind::Internal; // Because `layout.element`.
123-
swizzle.expandShape[i].push_back(dim);
124-
int newIndex = getExpandedDimFirstIdx(swizzle.expandShape, i + 1) - 1;
125-
for (auto &p : swizzle.permutation) {
126-
p += (p >= newIndex);
127-
}
128-
swizzle.permutation.push_back(newIndex);
100+
expand(swizzle, i, {Kind::Internal, e});
129101
}
130102
}
131-
// Deal with any outer size greater than 1 as just a call to unroll.
132-
// Iterate over dims in reverse order because we are creating a new outermost
133-
// dimension each time.
103+
// Next come `layout.thread` dims.
104+
for (auto [i, t] : llvm::enumerate(layout.thread)) {
105+
if (t != 1) {
106+
expand(swizzle, i, {Kind::CrossThread, t});
107+
}
108+
}
109+
// `layout.thread` dims are special in that they come with `layout.tstrides`
110+
// which may call for a swap in `swizzle.permutation`. We only need to worry
111+
// about that when both `layout.thread` sizes are greater than 1, so we didn't
112+
// skip them above. Note that this condition also implies that we don't need
113+
// to worry about `layout.tstrides == 0` which only happens with
114+
// `layout.thread == 1`.
115+
if (layout.thread[0] != 1 && layout.thread[1] != 1 &&
116+
layout.tstrides[0] > layout.tstrides[1]) {
117+
std::swap(swizzle.permutation[0], swizzle.permutation[1]);
118+
}
119+
// Finally come `layout.outer` dims, added last so they are outer-most.
134120
for (auto [i, o] : llvm::enumerate(layout.outer)) {
135121
if (o != 1) {
136-
// `layout.outer` means additional Internal dimensions, just like
137-
// `layout.element`, just swizzled outermost.
138-
unroll(swizzle, i, o, TileSwizzle::Dim::Kind::Internal);
122+
expand(swizzle, i, {Kind::Internal, o});
139123
}
140124
}
141-
142125
return swizzle;
143126
}
144127

145128
static int getInnermostNonInternalDimIdx(
146129
const TileSwizzle::ExpandShapeDimVectorType &shape) {
147130
for (int idx = shape.size() - 1; idx >= 0; --idx) {
148-
if (shape[idx].kind != TileSwizzle::Dim::Kind::Internal) {
131+
if (shape[idx].kind != Kind::Internal) {
149132
return idx;
150133
}
151134
}
@@ -156,55 +139,54 @@ static int getInnermostNonInternalDimIdx(
156139
TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
157140
IREE::GPU::MMAFragment fragment) {
158141
auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
159-
using Kind = TileSwizzle::Dim::Kind;
160142
switch (fragment) {
161143
case IREE::GPU::MMAFragment::Lhs:
162144
// A-matrix (LHS). Source dimensions are M (index 0) and K (index 1).
163145
// Unroll on K with interleaving, then on M.
164146
if (mma.getUnrollK() > 1) {
165-
unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
147+
expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollK()});
166148
int interleavingIdx =
167149
getInnermostNonInternalDimIdx(swizzle.expandShape[1]);
168150
interleave(swizzle, 1, interleavingIdx);
169151
}
170152
if (mma.getUnrollM() > 1) {
171-
unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
153+
expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
172154
}
173155
if (mma.getUnrollMToSubgroups() > 1) {
174-
unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
156+
expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollMToSubgroups()});
175157
}
176158
break;
177159
case IREE::GPU::MMAFragment::Rhs:
178160
// B-matrix (RHS). Since the pack ops already took care of transposing B,
179161
// source dimensions are N (index 0) and K (index 1).
180162
// Unroll on K with interleaving, then on N.
181163
if (mma.getUnrollK() > 1) {
182-
unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
164+
expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollK()});
183165
int interleavingIdx =
184166
getInnermostNonInternalDimIdx(swizzle.expandShape[1]);
185167
interleave(swizzle, 1, interleavingIdx);
186168
}
187169
if (mma.getUnrollN() > 1) {
188-
unroll(swizzle, 0, mma.getUnrollN(), Kind::CrossIntrinsic);
170+
expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollN()});
189171
}
190172
if (mma.getUnrollNToSubgroups() > 1) {
191-
unroll(swizzle, 0, mma.getUnrollNToSubgroups(), Kind::CrossThread);
173+
expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollNToSubgroups()});
192174
}
193175
break;
194176
case IREE::GPU::MMAFragment::Acc:
195177
// C-matrix (accumulator). Source dimensions are M (index 0) and N (index
196178
// 1). Unroll on N, then on M.
197179
if (mma.getUnrollN() > 1) {
198-
unroll(swizzle, 1, mma.getUnrollN(), Kind::CrossIntrinsic);
180+
expand(swizzle, 1, {Kind::CrossIntrinsic, mma.getUnrollN()});
199181
}
200182
if (mma.getUnrollNToSubgroups() > 1) {
201-
unroll(swizzle, 1, mma.getUnrollNToSubgroups(), Kind::CrossThread);
183+
expand(swizzle, 1, {Kind::CrossThread, mma.getUnrollNToSubgroups()});
202184
}
203185
if (mma.getUnrollM() > 1) {
204-
unroll(swizzle, 0, mma.getUnrollM(), Kind::CrossIntrinsic);
186+
expand(swizzle, 0, {Kind::CrossIntrinsic, mma.getUnrollM()});
205187
}
206188
if (mma.getUnrollMToSubgroups() > 1) {
207-
unroll(swizzle, 0, mma.getUnrollMToSubgroups(), Kind::CrossThread);
189+
expand(swizzle, 0, {Kind::CrossThread, mma.getUnrollMToSubgroups()});
208190
}
209191
break;
210192
}

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h

Lines changed: 1 addition & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -19,40 +19,10 @@ TileSwizzle getIntrinsicSwizzle(IREE::GPU::MMAIntrinsic intrinsic,
1919
IREE::GPU::MMAFragment fragment);
2020

2121
// Returns the swizzle for the full data-tiled-mma tile, including all the
22-
// relevant unrolling factors.
22+
// relevant unrolling and expansion factors.
2323
TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
2424
IREE::GPU::MMAFragment fragment);
2525

26-
// Unrolls the dimension given by `srcIndex` by the given `unrollFactor`.
27-
// This is not interleaving layouts. The layout will consist of multiple copies
28-
// of the input tile, side by side.
29-
//
30-
// The enum parameter `kind` initializes the corresponding member on the newly
31-
// created TileSwizzle::Dim.
32-
//
33-
// Example:
34-
// Input swizzle = { expandShape = [[16], [4]], permutation = [1, 0] }
35-
// Input srcIndex = 1
36-
// Input unrollFactor = 4
37-
// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
38-
//
39-
void unroll(TileSwizzle &swizzle, int srcIndex, int unrollFactor,
40-
TileSwizzle::Dim::Kind kind);
41-
42-
// Interleaves the layout in `swizzle` by mutating `swizzle.permutation` to
43-
// move permutation[0], the outer-most dimension (which the unroll() function
44-
// created to be the unrolling dimension), to the inner dimension given by
45-
// `expandedDimIndexToInterleaveAt`.
46-
//
47-
// Example:
48-
// Input swizzle = { expandShape = [[16], [4, 4]], permutation = [1, 2, 0] }
49-
// Input srcIndex = 1
50-
// Input expandedDimIndexToInterleaveAt = 1
51-
// -> Output swizzle = { expandShape = [[16], [4, 4]], permutation = [2, 0, 1] }
52-
//
53-
void interleave(TileSwizzle &swizzle, int srcIndex,
54-
int expandedDimIndexToInterleaveAt);
55-
5626
} // namespace mlir::iree_compiler
5727

5828
#endif // IREE_COMPILER_CODEGEN_DIALECT_GPU_IR_GPUTILESWIZZLEUTILS_H_

compiler/src/iree/compiler/Codegen/LLVMGPU/test/amdgpu_set_anchor_layouts.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,9 @@ builtin.module attributes { transform.with_named_sequence } {
7676
%c0 = arith.constant 0 : index
7777
%cst_0 = arith.constant 0.0 : f16
7878
%lhs = vector.transfer_read %a[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
79-
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 1, 16]>>}}
79+
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, VECTORX], [1, 16]>>}}
8080
%rhs = vector.transfer_read %b[%c0, %c0], %cst_0 {in_bounds = [true, true]} : memref<16x16xf16>, vector<16x16xf16>
81-
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, LANEY, VECTORX], [1, 1, 16]>>}}
81+
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, LANEX], [1, 16]>, <[ BATCHY, VECTORX], [1, 16]>>}}
8282
%output = vector.contract {indexing_maps = [#map1, #map2, #map3], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %lhs, %rhs, %init : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf32>
8383
// expected-remark @above {{layout of result #0 is #iree_vector_ext.layout<<[ BATCHX, VECTORX, LANEY], [1, 8, 2]>, <[ BATCHY, LANEX], [1, 16]>>}}
8484
return %output : vector<16x16xf32>

0 commit comments

Comments
 (0)