@@ -206,75 +206,60 @@ static std::optional<GPUMMASchedule> getMmaScheduleFromProblemAndTarget(
 
 /// Helper function to get convolution padding sizes if possible.
 static std::optional<ArrayAttr> getPaddingConvSizes(
-    Builder &b, int64_t kSize, int64_t kPaddingSize,
+    Builder &b, const SmallVector<int64_t> &bounds,
+    const SmallVector<int64_t> &paddingSizes,
     const SmallVector<int64_t> &workgroupTileSizes,
-    const SmallVector<int64_t> &mDims, const SmallVector<int64_t> &nDims,
-    const SmallVector<int64_t> &batchDims,
-    std::optional<mlir::linalg::ConvolutionDimensions> &padConvDims) {
-  if (!padConvDims.has_value())
+    const SmallVector<int64_t> &reductionTileSizes,
+    std::optional<DenseMap<int64_t, AffineExpr>> &convToIgemmDimMap,
+    std::optional<mlir::linalg::ConvolutionDimensions> &convDims) {
+  if (!convToIgemmDimMap.has_value() || !convDims.has_value())
     return std::nullopt;
 
-  SmallVector<unsigned> batchAndImageDims;
-  mlir::linalg::ConvolutionDimensions convDims = padConvDims.value();
-  bool isBatchLast = !convDims.batch.empty() &&
-                     convDims.outputImage.back() < convDims.batch.front();
-  if (isBatchLast) {
-    batchAndImageDims.append(convDims.outputImage.begin(),
-                             convDims.outputImage.end());
-    batchAndImageDims.append(convDims.batch.begin(), convDims.batch.end());
-  } else {
-    batchAndImageDims.append(convDims.batch.begin(), convDims.batch.end());
-    batchAndImageDims.append(convDims.outputImage.begin(),
-                             convDims.outputImage.end());
-  }
-
-  SmallVector<unsigned> concatMDims, concatNDims;
-  bool isOutputChannelFirst =
-      convDims.outputChannel.back() < convDims.outputImage.front();
-  if (isOutputChannelFirst) {
-    concatMDims.append(convDims.outputChannel.begin(),
-                       convDims.outputChannel.end());
-    concatNDims = batchAndImageDims;
-  } else {
-    concatMDims = batchAndImageDims;
-    concatNDims.append(convDims.outputChannel.begin(),
-                       convDims.outputChannel.end());
-  }
-
-  // Verify that the number of M, N dimensions from IGEMM match the
-  // corresponding number of convolution dimensions.
-  if (concatMDims.size() != mDims.size() ||
-      concatNDims.size() != nDims.size() ||
-      convDims.depth.size() != batchDims.size()) {
-    return std::nullopt;
-  }
-
+  DenseMap<int64_t, AffineExpr> convToIgemmMap = convToIgemmDimMap.value();
   // Padding sizes for parallel dimensions are the same as workgroup tile
   // sizes.
-  int64_t totalNumDims = convDims.batch.size() + convDims.outputImage.size() +
-                         convDims.outputChannel.size() +
-                         convDims.filterLoop.size() +
-                         convDims.inputChannel.size() + convDims.depth.size();
-  SmallVector<int64_t> paddingConvSizes(totalNumDims, 0);
-  if (batchDims.size() != 0) {
-    for (auto [dim, bDim] : llvm::zip(convDims.depth, batchDims)) {
-      paddingConvSizes[dim] = workgroupTileSizes[bDim];
+  DenseSet<int64_t> paddedIGEMMDims;
+  DenseMap<int64_t, SmallVector<int64_t>> paddedReductionConvDims;
+  SetVector<int64_t> inputChannelDims(convDims->inputChannel.begin(),
+                                      convDims->inputChannel.end());
+  SmallVector<int64_t> paddingConvSizes(convToIgemmMap.size(), 0);
+  for (auto [convDim, IGEMMExpr] : convToIgemmMap) {
+    auto IGEMMDimExpr = cast<AffineDimExpr>(IGEMMExpr);
+    unsigned IGEMMPos = IGEMMDimExpr.getPosition();
+    if (reductionTileSizes[IGEMMPos] != 0) {
+      // For reduction dimensions, avoid setting padding on the convolution
+      // if the product of the corresponding conv sizes is already divisible
+      // by the padding size.
+      if (paddingSizes[IGEMMPos] &&
+          bounds[IGEMMPos] % paddingSizes[IGEMMPos] == 0) {
+        paddedIGEMMDims.insert(IGEMMPos);
+        continue;
+      }
+      // Only pad input channel dims. If we need to pad filter dims, then we
+      // would rather just do padding on the GEMM instead.
+      if (inputChannelDims.contains(convDim)) {
+        // Multiple input channel dims for a single IGEMMPos is not supported.
+        if (paddedIGEMMDims.contains(IGEMMPos)) {
+          return std::nullopt;
+        }
+        paddingConvSizes[convDim] = paddingSizes[IGEMMPos];
+        paddedIGEMMDims.insert(IGEMMPos);
+      }
+      continue;
     }
+    // Multiple padded parallel dims mapping to the same IGEMM dim is not
+    // supported.
+    if (workgroupTileSizes[IGEMMPos] != 0 &&
+        paddedIGEMMDims.contains(IGEMMPos)) {
+      return std::nullopt;
+    }
+    paddingConvSizes[convDim] = paddingSizes[IGEMMPos];
+    paddedIGEMMDims.insert(IGEMMPos);
   }
-  for (auto [dim, mDim] : llvm::zip(concatMDims, mDims))
-    paddingConvSizes[dim] = workgroupTileSizes[mDim];
-  for (auto [dim, nDim] : llvm::zip(concatNDims, nDims))
-    paddingConvSizes[dim] = workgroupTileSizes[nDim];
-
-  // To avoid over-padding, no padding for channel dimensions is needed if
-  // the product of reduction sizes is already multiples of k padding
-  // size. Otherwise, pad the innermost channel dimension.
-  // TODO(vivian): Padding the innermost channel dimension to a multiple
-  // of vector size may still be needed even if the K-dim is aligned, and
-  // this should be validated based on performance.
-  if (kSize % kPaddingSize != 0) {
-    int64_t innerChannelDim = convDims.inputChannel.back();
-    paddingConvSizes[innerChannelDim] = kPaddingSize;
+
+  // Ensure that all dimensions have been padded.
+  if (paddedIGEMMDims.size() != paddingSizes.size()) {
+    return std::nullopt;
   }
   return b.getI64ArrayAttr(paddingConvSizes);
 }
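For readers following the hunk above, here is a minimal standalone sketch of the control flow the new getPaddingConvSizes implements. It is a toy model, not the patch itself: MLIR's DenseMap/AffineExpr/SetVector are replaced with standard containers, an integer IGEMM position stands in for each AffineDimExpr, and every name below is illustrative.

#include <cstdint>
#include <optional>
#include <set>
#include <unordered_map>
#include <vector>

// Toy model of the new padding propagation. convToIgemm maps each
// convolution loop dimension to the IGEMM loop position its AffineDimExpr
// resolves to in the real code.
std::optional<std::vector<int64_t>> computeConvPaddingSizes(
    const std::vector<int64_t> &bounds,             // IGEMM loop bounds
    const std::vector<int64_t> &paddingSizes,       // per IGEMM loop
    const std::vector<int64_t> &workgroupTileSizes, // nonzero => parallel dim
    const std::vector<int64_t> &reductionTileSizes, // nonzero => reduction dim
    const std::unordered_map<int64_t, int64_t> &convToIgemm,
    const std::set<int64_t> &inputChannelDims) {
  std::set<int64_t> paddedIgemmDims;
  std::vector<int64_t> paddingConvSizes(convToIgemm.size(), 0);
  for (auto [convDim, igemmPos] : convToIgemm) {
    if (reductionTileSizes[igemmPos] != 0) {
      // Reduction dim: skip conv padding when the IGEMM bound is already
      // divisible by the requested padding size.
      if (paddingSizes[igemmPos] != 0 &&
          bounds[igemmPos] % paddingSizes[igemmPos] == 0) {
        paddedIgemmDims.insert(igemmPos);
        continue;
      }
      // Only input-channel dims are padded on the conv; filter dims fall
      // back to padding the GEMM.
      if (inputChannelDims.count(convDim)) {
        if (paddedIgemmDims.count(igemmPos))
          return std::nullopt; // two input-channel dims -> one IGEMM dim
        paddingConvSizes[convDim] = paddingSizes[igemmPos];
        paddedIgemmDims.insert(igemmPos);
      }
      continue;
    }
    // Parallel dim: two padded conv dims may not share one IGEMM dim.
    if (workgroupTileSizes[igemmPos] != 0 && paddedIgemmDims.count(igemmPos))
      return std::nullopt;
    paddingConvSizes[convDim] = paddingSizes[igemmPos];
    paddedIgemmDims.insert(igemmPos);
  }
  // Bail out unless every IGEMM dim ended up covered.
  if (paddedIgemmDims.size() != paddingSizes.size())
    return std::nullopt;
  return paddingConvSizes;
}

Each std::nullopt exit corresponds to a case where per-convolution padding cannot be expressed, which is what makes the caller below fall back to emitting a `padding` attribute on the IGEMM instead.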
@@ -291,7 +276,9 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
     SmallVector<int64_t> bounds, ArrayRef<AffineMap> maps,
     ArrayRef<Value> operands, IREE::GPU::TargetAttr target, bool useDirectLoad,
     bool scaled,
-    std::optional<mlir::linalg::ConvolutionDimensions> padConvDims = {}) {
+    std::optional<DenseMap<int64_t, AffineExpr>> convToIgemmDimMap =
+        std::nullopt,
+    std::optional<linalg::ConvolutionDimensions> convDims = std::nullopt) {
   if (target.getWgp().getMma().empty()) {
     return failure();
   }
@@ -537,9 +524,9 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
 
   // Create `padding_conv` attribute when padding convolutions before IGEMM is
   // possible, otherwise fallback to pad IGEMM.
-  if (auto attr = getPaddingConvSizes(
-          b, bounds[innerKDim], paddingTileSizes[innerKDim],
-          workgroupTileSizes, mDims, nDims, batchDims, padConvDims)) {
+  if (auto attr = getPaddingConvSizes(b, bounds, paddingTileSizes,
+                                      workgroupTileSizes, reductionTileSizes,
+                                      convToIgemmDimMap, convDims)) {
     attrs.emplace_back(StringAttr::get(context, "padding_conv"), *attr);
   } else {
     attrs.emplace_back(StringAttr::get(context, "padding"),
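To make the two branches concrete: on success the config carries padding sizes indexed by the convolution's own loops; on failure they are indexed by the IGEMM loops. A hypothetical entry for an NHWC convolution with loop order (n, oh, ow, oc, fh, fw, ic) could look like the following, where the loop order and values are illustrative assumptions, not taken from this patch:

  padding_conv = [1, 1, 64, 128, 0, 0, 16]

The zeros sit on the filter loops fh/fw, which getPaddingConvSizes refuses to pad, and 16 pads the input channel loop ic that feeds the IGEMM K dimension.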
@@ -580,15 +567,18 @@ LogicalResult setIGEMMConvolutionLoweringConfig(
       igemmGenericConvDetails->igemmLoopBounds;
   SmallVector<Value> igemmOperands = igemmGenericConvDetails->igemmOperands;
 
-  std::optional<mlir::linalg::ConvolutionDimensions> padConvDims;
-  if (padConv)
-    padConvDims = igemmGenericConvDetails->convDims;
+  std::optional<DenseMap<int64_t, AffineExpr>> convToIgemmDimMap;
+  std::optional<linalg::ConvolutionDimensions> convDims;
+  if (padConv) {
+    convDims = igemmGenericConvDetails->convDims;
+    convToIgemmDimMap = igemmGenericConvDetails->convToIgemmDimMap;
+  }
 
   SmallVector<int64_t> bounds = igemmLoopBounds;
   FailureOr<std::pair<LoweringConfigAttr, int64_t>> configAndWgSize =
       getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
           bounds, igemmContractionMaps, igemmOperands, target, useDirectLoad,
-          /*scaled*/ false, padConvDims);
+          /*scaled*/ false, convToIgemmDimMap, convDims);
   if (failed(configAndWgSize)) {
     return failure();
   }