Skip to content

Commit 1a9417d

Browse files
Add lit test and clean up
1 parent f1b8e80 commit 1a9417d

File tree

3 files changed

+627
-2
lines changed

3 files changed

+627
-2
lines changed

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
237237
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
238238
}
239239

240+
/// Utility to create a `genericOp` with a convolution op of type `ConvOpTy` with `dilations` and `strides`.
240241
template <typename ConvOpTy>
241242
static FailureOr<LinalgOp> specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef<int64_t> dilations, ArrayRef<int64_t> strides) {
242243
SmallVector<Value> inputs = genericOp.getDpsInputs();
@@ -380,7 +381,7 @@ static FailureOr<LinalgOp> inferAndSpecializeBasedOnRank9ConvIteratorTypes(Rewri
380381
return failure();
381382
}
382383

383-
// Converts linalg.generic to named linalg.*conv* where possible.
384+
// Converts linalg.generic to named linalg.*conv/pooling* where possible. To improve the search speed, the convolution ops have been segregated based on the rank of iterator types array.
384385
static FailureOr<LinalgOp> inferAndSpecializeToConvolutionOp(RewriterBase &rewriter, GenericOp genericOp) {
385386
SmallVector<utils::IteratorType> iteratorTypes = genericOp.getIteratorTypesArray();
386387
unsigned totalIterators = iteratorTypes.size();
@@ -483,7 +484,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
483484
return specializeLinalgContractions(rewriter, genericOp);
484485
}
485486

486-
// Convolution - e.g. *conv*
487+
// Convolution - e.g. *conv/pooling*
487488
if (isaConvolutionOpInterface(genericOp)) {
488489
return inferAndSpecializeToConvolutionOp(rewriter, genericOp);
489490
}

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,11 @@ static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_
319319
return false;
320320
}
321321

322+
/// Given an array of AffineMaps `indexingMaps` verify the following :-
323+
/// indexingMaps[0].getResult(iDim) ==
324+
/// indexingMaps[1].getResult(fDim) * <CST_1> +
325+
/// indexingMaps[n-1].getResult(oDim) * <CST_2>
326+
/// where, CST_1 and CST_2 can be any constant.
322327
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim,
323328
int64_t& dilation, int64_t& stride) {
324329
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
@@ -348,10 +353,13 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, un
348353
return false;
349354
}
350355

356+
/// Given an array of AffineMaps `indexingMaps` verify the following :-
357+
/// indexingMaps[aIndex].getResult(aDim) == indexingMaps[bIndex].getResult(bDim)
351358
static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
352359
return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim);
353360
}
354361

362+
/// Give an array of AffineMaps, verify each map to be of the corresponding `expectedSize`.
355363
static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t> expectedSizes) {
356364
if (indexingMaps.size() != expectedSizes.size()) return false;
357365

@@ -362,6 +370,7 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t>
362370
return true;
363371
}
364372

373+
/// Utility to update `dilations` and `strides` by copy the corresponding data from `tempDilations` and `tempStrides`.
365374
static bool updateConvDilationsAndStrides(SmallVector<int64_t>* dilations, SmallVector<int64_t>* strides, ArrayRef<int64_t> tempDilations, ArrayRef<int64_t> tempStrides) {
366375
if (!(dilations && strides))
367376
return true;

0 commit comments

Comments
 (0)