 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
 #include "iree/compiler/Codegen/Utils/Utils.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
@@ -124,20 +125,37 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
     return failure();
   }

-  // For now we are not being smart and trying to reshape dimensions to allow
-  // for better usage of intrinsics, and instead are tiling all dimensions
-  // except the inner most m, n, and k dimensions to 1.
-  int64_t mDim = contractionDims.m.back();
-  int64_t nDim = contractionDims.n.back();
-  int64_t kDim = contractionDims.k.back();
-
-  // Dynamic dims are expected to be taken care of earlier in the pipeline.
-  if (ShapedType::isDynamic(bounds[mDim]) ||
-      ShapedType::isDynamic(bounds[nDim]) ||
-      ShapedType::isDynamic(bounds[kDim])) {
+  // TODO(Max191): add dynamic shape support for inner most dims.
+  if (ShapedType::isDynamic(bounds[contractionDims.m.back()]) ||
+      ShapedType::isDynamic(bounds[contractionDims.n.back()]) ||
+      ShapedType::isDynamic(bounds[contractionDims.k.back()])) {
     return failure();
   }

+  // Gather all static M, N, and K dimensions to deduce the MMASchedule. Dynamic
+  // dimensions will be tiled to 1 in workgroup tiling, so they are ignored when
+  // computing an MMA schedule.
+  SmallVector<int64_t> mDims, nDims, kDims;
+  for (auto mDim : contractionDims.m) {
+    if (!ShapedType::isDynamic(bounds[mDim])) {
+      mDims.push_back(mDim);
+    }
+  }
+  for (auto nDim : contractionDims.n) {
+    if (!ShapedType::isDynamic(bounds[nDim])) {
+      nDims.push_back(nDim);
+    }
+  }
+  for (auto kDim : contractionDims.k) {
+    if (!ShapedType::isDynamic(bounds[kDim])) {
+      kDims.push_back(kDim);
+    }
+  }
+
+  auto getDimBounds = [&](SmallVector<int64_t> dims) -> SmallVector<int64_t> {
+    return llvm::map_to_vector(dims, [&](int64_t dim) { return bounds[dim]; });
+  };
+
   Value lhs = linalgOp.getDpsInputOperand(0)->get();
   Value rhs = linalgOp.getDpsInputOperand(1)->get();
   Value init = linalgOp.getDpsInitOperand(0)->get();
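As a hypothetical example of the gathering above (values invented for illustration): with loop bounds {4, ?, 128, 64} where `?` is dynamic, contractionDims.m = {0, 1}, contractionDims.n = {2}, and contractionDims.k = {3}, the dynamic dimension 1 is skipped, giving mDims = {0}, nDims = {2}, kDims = {3}; getDimBounds(mDims) then yields {4}, getDimBounds(nDims) yields {128}, and getDimBounds(kDims) yields {64}.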
@@ -146,8 +164,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   Type rhsElemType = getElementTypeOrSelf(rhs);
   Type initElemType = getElementTypeOrSelf(init);

-  GPUMatmulShapeType problem{bounds[mDim], bounds[nDim], bounds[kDim],
-                             lhsElemType, rhsElemType, initElemType};
+  GPUMatmulShapeType problem{getDimBounds(mDims), getDimBounds(nDims),
+                             getDimBounds(kDims), lhsElemType,
+                             rhsElemType, initElemType};

   SmallVector<GPUMatmulShapeType> intrinsics;
   for (IREE::GPU::MMAAttr mma : target.getWgp().getMma()) {
@@ -166,7 +185,9 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   // Note that the following heuristic seeds are just placeholder values.
   // We need to clean it up and make it adjusting to different targets.
   // See https://github.com/iree-org/iree/issues/16341 for details.
-  if (problem.mSize * problem.nSize <= 512 * 512) {
+  int64_t mSize = ShapedType::getNumElements(problem.mSizes);
+  int64_t nSize = ShapedType::getNumElements(problem.nSizes);
+  if (mSize * nSize <= 512 * 512) {
     // For matmuls with small M*N size, we want to distribute M*N onto more
     // workgroups to fill the GPU. Use a smaller bestMNTileCountPerSubgroup
     // and a larger bestKTileCountPerSubgroup.
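As a concrete reading of this heuristic with assumed sizes: ShapedType::getNumElements takes the product of the per-dimension sizes, so for a hypothetical problem with mSizes = {4, 32} and nSizes = {256}, mSize = 128 and nSize = 256, and mSize * nSize = 32768 falls under the 512 * 512 = 262144 threshold, so the small-matmul seeds are chosen.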
@@ -190,10 +211,10 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   // TODO: Drop this. This is only a consideration for other pipelines.
   SmallVector<AffineMap> maps = linalgOp.getIndexingMapsArray();
   bool transposedLhs =
-      kDim !=
+      kDims.back() !=
       llvm::cast<AffineDimExpr>(maps[0].getResults().back()).getPosition();
   bool transposedRhs =
-      nDim !=
+      nDims.back() !=
       llvm::cast<AffineDimExpr>(maps[1].getResults().back()).getPosition();

   // First try to find a schedule with an exactly matching intrinsic.
@@ -213,16 +234,13 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
   }

   LDBG("Target Subgroup size: " << targetSubgroupSize);
-  LDBG("Schedule: sizes [" << schedule->mSize << ", " << schedule->nSize << ", "
-                           << schedule->kSize << "]");
-  LDBG("Schedule: tile counts [" << schedule->mTileCount << ", "
-                                 << schedule->nTileCount << ", "
-                                 << schedule->kTileCount << "]");
-  LDBG("Schedule: warp counts [" << schedule->mWarpCount << ", "
-                                 << schedule->nWarpCount << "]");
+  LDBG("Schedule: " << schedule);

-  std::array<int64_t, 3> workgroupSize{
-      schedule->nWarpCount * targetSubgroupSize, schedule->mWarpCount, 1};
+  int64_t flatWorkgroupSize =
+      targetSubgroupSize *
+      ShapedType::getNumElements(schedule->nSubgroupCounts) *
+      ShapedType::getNumElements(schedule->mSubgroupCounts);
+  std::array<int64_t, 3> workgroupSize{flatWorkgroupSize, 1, 1};

   SmallVector<int64_t> workgroupTileSizes(linalgOp.getNumLoops(), 0);
   SmallVector<int64_t> reductionTileSizes(linalgOp.getNumLoops(), 0);
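For example, with a target subgroup size of 64 and hypothetical schedule->mSubgroupCounts = {2, 2} and schedule->nSubgroupCounts = {2}, flatWorkgroupSize = 64 * 2 * (2 * 2) = 512, and the workgroup is launched as {512, 1, 1} rather than the previous {nWarpCount * subgroupSize, mWarpCount, 1} shape.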
@@ -244,16 +262,30 @@ LogicalResult setMatmulLoweringConfig(IREE::GPU::TargetAttr target,
     reductionTileSizes[k] = 1;
   }

-  // Compute the M/N dimension tile size by multiplying subgroup information.
-  workgroupTileSizes[mDim] = schedule->mWarpCount * schedule->mTileCount;
-  workgroupTileSizes[nDim] = schedule->nWarpCount * schedule->nTileCount;
-
-  // Specify the subgroup tile sizes from the mma schedule. This is applied
-  subgroupTileSizes[mDim] = schedule->mTileCount;
-  subgroupTileSizes[nDim] = schedule->nTileCount;
+  // Adjust the inner bound size for packing to intrinsic shapes, since tiling
+  // happens after packing.
+  assert(bounds[mDims.back()] % schedule->mSize == 0 &&
+         bounds[nDims.back()] % schedule->nSize == 0 &&
+         "expected inner bound to be evenly divisible by schedule sizes.");
+  bounds[mDims.back()] /= schedule->mSize;
+  bounds[nDims.back()] /= schedule->nSize;
+
+  // Compute the M/N dimension tile sizes by multiplying subgroup information.
+  for (auto [i, mDim] : llvm::enumerate(mDims)) {
+    workgroupTileSizes[mDim] =
+        schedule->mSubgroupCounts[i] * schedule->mTileSizes[i];
+    subgroupTileSizes[mDim] = schedule->mTileSizes[i];
+  }
+  for (auto [i, nDim] : llvm::enumerate(nDims)) {
+    workgroupTileSizes[nDim] =
+        schedule->nSubgroupCounts[i] * schedule->nTileSizes[i];
+    subgroupTileSizes[nDim] = schedule->nTileSizes[i];
+  }

   // Similarly the reduction tile size is just the post-packing tile count.
-  reductionTileSizes[kDim] = schedule->kTileCount;
+  for (auto [i, kDim] : llvm::enumerate(kDims)) {
+    reductionTileSizes[kDim] = schedule->kTileSizes[i];
+  }

   IREE::GPU::MmaInterfaceAttr mmaKind =
       target.getWgp().getMma()[schedule->index];
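The per-dimension composition above can be sketched in isolation. The following is a minimal standalone example, using made-up dimension indices and schedule values rather than anything produced by a real target, that mirrors the loop over mDims in the patch:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical M dimensions of the contraction (loop indices 0 and 2), with
  // a subgroup count and a per-subgroup tile size for each of them.
  std::vector<int64_t> mDims = {0, 2};
  std::vector<int64_t> mSubgroupCounts = {2, 1};
  std::vector<int64_t> mTileSizes = {1, 4};

  std::vector<int64_t> workgroupTileSizes(4, 0);
  std::vector<int64_t> subgroupTileSizes(4, 0);
  for (size_t i = 0; i < mDims.size(); ++i) {
    // Workgroup tile = subgroups along this dimension * tile per subgroup.
    workgroupTileSizes[mDims[i]] = mSubgroupCounts[i] * mTileSizes[i];
    subgroupTileSizes[mDims[i]] = mTileSizes[i];
  }

  // Prints: workgroup tiles d0=2 d2=4, subgroup tiles d0=1 d2=4
  std::printf("workgroup tiles d0=%lld d2=%lld, subgroup tiles d0=%lld d2=%lld\n",
              (long long)workgroupTileSizes[0], (long long)workgroupTileSizes[2],
              (long long)subgroupTileSizes[0], (long long)subgroupTileSizes[2]);
  return 0;
}

The same pattern applies to nDims with the N-side schedule values, and to kDims for the reduction tile sizes, which have no subgroup count factor.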