
Commit 03c744e

[GPU] Support multiple contraction dims in MmaSchedules (#18720)
This adds support for multiple M, N, and K dims in problems when deducing a GPUMMASchedule. The new heuristic is similar to the old one, but works on pairs of M and N dims. For example:

```
tensor<M1xM0xK1xK0> * tensor<N1xN0xK1xK0> -> tensor<M1xN1xM0xN0>
```

This will try to distribute the seeded tile counts to `M0` and `N0` (first attempting to distribute evenly, and then distributing to N followed by M), and then distribute the residual counts to `M1` and `N1`. The K tile counts will be partitioned to `K0` first, and then the residual tile counts will be partitioned to `K1`.

This PR also updates the config selection logic for the TileAndFuse pipeline to make use of the multiple contraction dimensions in MMA schedules.

---------

Signed-off-by: Max Dawkins <[email protected]>
1 parent 0c2c627 commit 03c744e
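To make the distribution concrete, here is a minimal, illustrative sketch of the greedy partitioning described in the commit message. The helper name `partitionTileCount` and the specific bounds are invented for this example; they are not the actual functions or values used in GPUHeuristics.cpp.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <utility>

// Illustrative only: split a seeded tile count between an inner and an outer
// dimension, filling the inner dimension first and giving the residual to the
// outer one. The real heuristic also accounts for evenness, alignment, and
// intrinsic sizes.
static std::pair<int64_t, int64_t>
partitionTileCount(int64_t seed, int64_t innerBound, int64_t outerBound) {
  int64_t inner = std::min(seed, innerBound);         // fill K0 (or M0/N0) first
  int64_t outer = std::min(seed / inner, outerBound); // residual goes to K1
  return {inner, outer};
}

int main() {
  // Suppose K0 has 4 tiles available, K1 has 8, and the seed asks for 16.
  auto [k0, k1] = partitionTileCount(/*seed=*/16, /*innerBound=*/4,
                                     /*outerBound=*/8);
  std::cout << "K0 tiles: " << k0 << ", K1 tiles: " << k1 << "\n"; // 4, 4
  return 0;
}
```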

File tree

8 files changed: +492 −232 lines changed


compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.cpp

Lines changed: 239 additions & 122 deletions
Large diffs are not rendered by default.

compiler/src/iree/compiler/Codegen/Common/GPU/GPUHeuristics.h

Lines changed: 46 additions & 12 deletions
@@ -10,15 +10,18 @@ namespace mlir::iree_compiler {

 /// Struct containing information about a matmul's shape and type.
 struct GPUMatmulShapeType {
-  int64_t mSize;
-  int64_t nSize;
-  int64_t kSize;
+  SmallVector<int64_t> mSizes;
+  SmallVector<int64_t> nSizes;
+  SmallVector<int64_t> kSizes;
   Type aType;
   Type bType;
   Type cType;

   GPUMatmulShapeType(int64_t m, int64_t n, int64_t k, Type a, Type b, Type c)
-      : mSize(m), nSize(n), kSize(k), aType(a), bType(b), cType(c) {}
+      : mSizes({m}), nSizes({n}), kSizes({k}), aType(a), bType(b), cType(c) {}
+  GPUMatmulShapeType(SmallVector<int64_t> m, SmallVector<int64_t> n,
+                     SmallVector<int64_t> k, Type a, Type b, Type c)
+      : mSizes(m), nSizes(n), kSizes(k), aType(a), bType(b), cType(c) {}
 };

 /// Struct containing seed tile sizes for GPU MMA heuristics deduction logic.
@@ -38,14 +41,42 @@ struct GPUMMAHeuristicSeeds {
 struct GPUMMASchedule {
   // Index of the chosen intrinsic into the list of given MMA intrinsics
   uint64_t index;
-  int64_t mSize;      // Native MMA size along M dimension
-  int64_t nSize;      // Native MMA size along N dimension
-  int64_t kSize;      // Native MMA size along K dimension
-  int64_t mWarpCount; // Number of subgroups along M dimension
-  int64_t nWarpCount; // Number of subgroups along N dimension
-  int64_t mTileCount; // Number of tiles per subgroup along M dimension
-  int64_t nTileCount; // Number of tiles per subgroup along N dimension
-  int64_t kTileCount; // Number of tiles along K dimension
+  int64_t mSize; // Native MMA intrinsic size along M dimension for a subgroup.
+  int64_t nSize; // Native MMA intrinsic size along N dimension for a subgroup.
+  int64_t kSize; // Native MMA intrinsic size along K dimension for a subgroup.
+
+  // Number of subgroups along each M and N dimension.
+  SmallVector<int64_t> mSubgroupCounts;
+  SmallVector<int64_t> nSubgroupCounts;
+
+  // Tile sizes for each M, N, and K dimension. When there are multiple M, N,
+  // or K dimensions, the intrinsic sizes are targeted to the innermost
+  // dimension, and the outer dimensions can be thought of as unrolling factors
+  // along M, N, or K.
+  SmallVector<int64_t> mTileSizes; // M tile sizes per subgroup.
+  SmallVector<int64_t> nTileSizes; // N tile sizes per subgroup.
+  SmallVector<int64_t> kTileSizes; // K tile sizes.
+
+  // Constructor for multi M, N, K dim schedules.
+  GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize,
+                 int64_t kIntrinsicSize, SmallVector<int64_t> mSubgroupCounts,
+                 SmallVector<int64_t> nSubgroupCounts,
+                 SmallVector<int64_t> mTileSizes,
+                 SmallVector<int64_t> nTileSizes,
+                 SmallVector<int64_t> kTileSizes)
+      : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize),
+        kSize(kIntrinsicSize), mSubgroupCounts(mSubgroupCounts),
+        nSubgroupCounts(nSubgroupCounts), mTileSizes(mTileSizes),
+        nTileSizes(nTileSizes), kTileSizes(kTileSizes) {}
+
+  // Constructor for single M, N, K dim schedules.
+  GPUMMASchedule(uint64_t i, int64_t mIntrinsicSize, int64_t nIntrinsicSize,
+                 int64_t kIntrinsicSize, int64_t mSubgroup, int64_t nSubgroup,
+                 int64_t mTileSize, int64_t nTileSize, int64_t kTileSize)
+      : index(i), mSize(mIntrinsicSize), nSize(nIntrinsicSize),
+        kSize(kIntrinsicSize), mSubgroupCounts({mSubgroup}),
+        nSubgroupCounts({nSubgroup}), mTileSizes({mTileSize}),
+        nTileSizes({nTileSize}), kTileSizes({kTileSize}) {}
 };

 /// Returns a schedule for using one of the given MMA |intrinsics| to target the
@@ -69,4 +100,7 @@ FailureOr<GPUMMASchedule> deduceAttentionSchedule(
     bool transposedV = false, bool canUpcastAcc = false,
     bool mustBeAligned = true);

+llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
+                              const GPUMMASchedule &schedule);
+
} // namespace mlir::iree_compiler
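For reference, a hypothetical usage sketch of the new multi-dimension constructors declared above. The shapes, intrinsic size, and subgroup/tile counts are invented for illustration and are not taken from the PR:

```cpp
#include "iree/compiler/Codegen/Common/GPU/GPUHeuristics.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

using namespace mlir;
using namespace mlir::iree_compiler;

void buildExampleSchedule(MLIRContext &ctx) {
  Type f16 = Float16Type::get(&ctx);
  Type f32 = Float32Type::get(&ctx);

  // A contraction with two M, N, and K dims, e.g.
  //   tensor<8x64x16x32xf16> * tensor<4x128x16x32xf16> -> tensor<8x4x64x128xf32>
  // so mSizes = {M1, M0}, nSizes = {N1, N0}, kSizes = {K1, K0}.
  GPUMatmulShapeType problem({8, 64}, {4, 128}, {16, 32}, f16, f16, f32);

  // A schedule built on a 16x16x16 intrinsic. The intrinsic targets the
  // innermost M/N/K dims; the outer entries act as unrolling factors.
  GPUMMASchedule schedule(/*index=*/0, /*mSize=*/16, /*nSize=*/16,
                          /*kSize=*/16,
                          /*mSubgroupCounts=*/{1, 2}, /*nSubgroupCounts=*/{1, 2},
                          /*mTileSizes=*/{2, 2}, /*nTileSizes=*/{1, 4},
                          /*kTileSizes=*/{1, 2});
  (void)problem;
  (void)schedule;
}
```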

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 65 additions & 33 deletions
@@ -13,6 +13,7 @@
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -124,20 +125,37 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
     return failure();
   }

-  // For now we are not being smart and trying to reshape dimensions to allow
-  // for better usage of intrinsics, and instead are tiling all dimensions
-  // except the inner most m, n, and k dimensions to 1.
-  int64_t mDim = contractionDims.m.back();
-  int64_t nDim = contractionDims.n.back();
-  int64_t kDim = contractionDims.k.back();
-
-  // Dynamic dims are expected to be taken care of earlier in the pipeline.
-  if (ShapedType::isDynamic(bounds[mDim]) ||
-      ShapedType::isDynamic(bounds[nDim]) ||
-      ShapedType::isDynamic(bounds[kDim])) {
+  // TODO(Max191): add dynamic shape support for inner most dims.
+  if (ShapedType::isDynamic(bounds[contractionDims.m.back()]) ||
+      ShapedType::isDynamic(bounds[contractionDims.n.back()]) ||
+      ShapedType::isDynamic(bounds[contractionDims.k.back()])) {
     return failure();
   }

+  // Gather all static M, N, and K dimensions to deduce the MMASchedule. Dynamic
+  // dimensions will be tiled to 1 in workgroup tiling, so they are ignored when
+  // computing an MMA schedule.
+  SmallVector<int64_t> mDims, nDims, kDims;
+  for (auto mDim : contractionDims.m) {
+    if (!ShapedType::isDynamic(bounds[mDim])) {
+      mDims.push_back(mDim);
+    }
+  }
+  for (auto nDim : contractionDims.n) {
+    if (!ShapedType::isDynamic(bounds[nDim])) {
+      nDims.push_back(nDim);
+    }
+  }
+  for (auto kDim : contractionDims.k) {
+    if (!ShapedType::isDynamic(bounds[kDim])) {
+      kDims.push_back(kDim);
+    }
+  }
+
+  auto getDimBounds = [&](SmallVector<int64_t> dims) -> SmallVector<int64_t> {
+    return llvm::map_to_vector(dims, [&](int64_t dim) { return bounds[dim]; });
+  };
+
   Value lhs = linalgOp.getDpsInputOperand(0)->get();
   Value rhs = linalgOp.getDpsInputOperand(1)->get();
   Value init = linalgOp.getDpsInitOperand(0)->get();
@@ -146,8 +164,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   Type rhsElemType = getElementTypeOrSelf(rhs);
   Type initElemType = getElementTypeOrSelf(init);

-  GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
-                             lhsElemType, rhsElemType, initElemType};
+  GPUMatmulShapeType problem{getDimBounds(mDims), getDimBounds(nDims),
+                             getDimBounds(kDims), lhsElemType,
+                             rhsElemType, initElemType};

   SmallVector<GPUMatmulShapeType> intrinsics;
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
@@ -166,7 +185,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   // Note that the following heuristic seeds are just placeholder values.
   // We need to clean it up and make it adjusting to different targets.
   // See https://github.com/iree-org/iree/issues/16341 for details.
-  if (problem.mSize * problem.nSize <= 512 * 512) {
+  int64_t mSize = ShapedType::getNumElements(problem.mSizes);
+  int64_t nSize = ShapedType::getNumElements(problem.nSizes);
+  if (mSize * nSize <= 512 * 512) {
     // For matmuls with small M*N size, we want to distribute M*N onto more
     // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
     // and a larger bestKTileCountPerSubgroup.
@@ -190,10 +211,10 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   // TODO: Drop this. This is only a consideration for other pipelines.
   SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
   bool transposedLhs =
-      kDim !=
+      kDims.back() !=
       llvm::cast<AffineDimExpr>(maps[0].getResults().back()).getPosition();
   bool transposedRhs =
-      nDim !=
+      nDims.back() !=
       llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();

   // First try to find a schedule with an exactly matching intrinsic.
@@ -213,16 +234,13 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   }

   LDBG("Target Subgroup size: " << targetSubgroupSize);
-  LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", "
-                           << schedule->kSize << "]");
-  LDBG("Schedule: tile counts [" << schedule->mTileCount << ", "
-                                 << schedule->nTileCount << ", "
-                                 << schedule->kTileCount << "]");
-  LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", "
-                                 << schedule->nWarpCount << "]");
+  LDBG("Schedule: " << schedule);

-  std::array<int64_t, 3> workgroupSize{
-      schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
+  int64_t flatWorkgroupSize =
+      targetSubgroupSize *
+      ShapedType::getNumElements(schedule->nSubgroupCounts) *
+      ShapedType::getNumElements(schedule->mSubgroupCounts);
+  std::array<int64_t, 3> workgroupSize{flatWorkgroupSize, 1, 1};

   SmallVector<int64_t> workgroupTileSizes(linalgOp.getNumLoops(), 0);
   SmallVector<int64_t> reductionTileSizes(linalgOp.getNumLoops(), 0);
@@ -244,16 +262,30 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
     reductionTileSizes[k] = 1;
   }

-  // Compute the M/N dimension tile size by multiplying subgroup information.
-  workgroupTileSizes[mDim] = schedule->mWarpCount * schedule->mTileCount;
-  workgroupTileSizes[nDim] = schedule->nWarpCount * schedule->nTileCount;
-
-  // Specify the subgroup tile sizes from the mma schedule. This is applied
-  subgroupTileSizes[mDim] = schedule->mTileCount;
-  subgroupTileSizes[nDim] = schedule->nTileCount;
+  // Adjust the inner bound size for packing to intrinsic shapes, since tiling
+  // happens after packing.
+  assert(bounds[mDims.back()] % schedule->mSize == 0 &&
+         bounds[nDims.back()] % schedule->nSize == 0 &&
+         "expected inner bound to be evenly divisible by schedule sizes.");
+  bounds[mDims.back()] /= schedule->mSize;
+  bounds[nDims.back()] /= schedule->nSize;
+
+  // Compute the M/N dimension tile sizes by multiplying subgroup information.
+  for (auto [i, mDim] : llvm::enumerate(mDims)) {
+    workgroupTileSizes[mDim] =
+        schedule->mSubgroupCounts[i] * schedule->mTileSizes[i];
+    subgroupTileSizes[mDim] = schedule->mTileSizes[i];
+  }
+  for (auto [i, nDim] : llvm::enumerate(nDims)) {
+    workgroupTileSizes[nDim] =
+        schedule->nSubgroupCounts[i] * schedule->nTileSizes[i];
+    subgroupTileSizes[nDim] = schedule->nTileSizes[i];
+  }

   // Similarly the reduction tile size is just the post-packing tile count.
-  reductionTileSizes[kDim] = schedule->kTileCount;
+  for (auto [i, kDim] : llvm::enumerate(kDims)) {
+    reductionTileSizes[kDim] = schedule->kTileSizes[i];
+  }

   IREE::GPU::MmaInterfaceAttr mmaKind =
