Skip to content

Commit bafdb41

Browse files
Start pulling out into separate APIs
1 parent aae7e04 commit bafdb41

File tree

3 files changed

+233
-29
lines changed

3 files changed

+233
-29
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,16 @@ std::optional<SmallVector<ReassociationIndices>>
111111
getReassociationMapForFoldingUnitDims(ArrayRef<OpFoldResult> mixedSizes);
112112

113113
//===----------------------------------------------------------------------===//
114-
// Fusion / Tiling utilities
114+
// Convolution matcher utilities
115115
//===----------------------------------------------------------------------===//
116116

117+
bool isaConv1DOp(LinalgOp op);
118+
bool isaConv1DNwcWcfOp(LinalgOp op);
119+
bool isaConv1DNcwFcwOp(LinalgOp op);
120+
bool isaDepthwiseConv1DNcwCwOp(LinalgOp op);
121+
bool isaDepthwiseConv1DNwcWcOp(LinalgOp op);
122+
bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op);
123+
117124
//===----------------------------------------------------------------------===//
118125
// Fusion / Tiling utilities
119126
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -338,35 +338,24 @@ bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned a
338338
}
339339

340340
static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
341-
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
342-
if (indexingMaps.size() != 3) return "";
343-
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
344-
if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0))
345-
return "linalg.conv_1d";
341+
if (isaConv1DOp(genericOp)) return "linalg.conv_1d";
346342
return "";
347343
}
348344

349345
static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
350346
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
351347
if (indexingMaps.size() != 3) return "";
352-
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
353348
// depthwise_conv_1d_ncw_cw
354349
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
355350
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
356351
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
357-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
358-
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
359-
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
360-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2))
352+
if (isaDepthwiseConv1DNcwCwOp(genericOp))
361353
return "linalg.depthwise_conv_1d_ncw_cw";
362354
// depthwise_conv_1d_nwc_wc
363355
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
364356
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
365357
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
366-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
367-
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
368-
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
369-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1))
358+
if (isaDepthwiseConv1DNwcWcOp(genericOp))
370359
return "linalg.depthwise_conv_1d_nwc_wc";
371360
// conv_2d
372361
// #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
@@ -414,34 +403,23 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
414403
static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
415404
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
416405
if (indexingMaps.size() != 3) return "";
417-
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
418406
// depthwise_conv_1d_nwc_wcm
419407
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
420408
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
421409
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
422-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
423-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
424-
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
425-
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
426-
matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3))
410+
if (isaDepthwiseConv1DNwcWcmOp(genericOp))
427411
return "linalg.depthwise_conv_1d_nwc_wcm";
428412
// conv_1d_nwc_wcf
429413
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
430414
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
431415
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
432-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
433-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
434-
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
435-
matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2))
416+
if (isaConv1DNwcWcfOp(genericOp))
436417
return "linalg.conv_1d_nwc_wcf";
437418
// conv_1d_ncw_fcw
438419
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
439420
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
440421
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
441-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
442-
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
443-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
444-
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
422+
if (isaConv1DNcwFcwOp(genericOp))
445423
return "linalg.conv_1d_ncw_fcw";
446424
return "";
447425
}

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,225 @@ bool isReductionIterator(utils::IteratorType iteratorType) {
240240
return iteratorType == utils::IteratorType::reduction;
241241
}
242242

243+
// -------------------------------
244+
// ---------- CONV ---------------
245+
// -------------------------------
246+
247+
/// Utility to match block body for linalg.pool* ops.
248+
template <typename... OpTypes>
249+
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
250+
Operation *defOp = yieldVal.getDefiningOp();
251+
// if (!defOp) return false;
252+
if (!(isa_and_present<OpTypes>(defOp) || ...)) return false;
253+
254+
BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
255+
BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
256+
if (!lhsArg || !rhsArg) return false;
257+
return true;
258+
}
259+
260+
static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
261+
return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, body);
262+
}
263+
264+
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
265+
return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, body);
266+
}
267+
268+
static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
269+
return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, body);
270+
}
271+
272+
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
273+
return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, body);
274+
}
275+
276+
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
277+
return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
278+
}
279+
280+
static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
281+
uint32_t mapIndex, uint32_t dimIndex) {
282+
auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
283+
if (dimIndex < affineMap.getNumResults())
284+
return affineMap.getResult(dimIndex);
285+
return nullptr;
286+
}
287+
288+
// Check if `expr` is either:
289+
// - a dimension expr alone (implying *1), or
290+
// - a multiplication of dimension expr by constant.
291+
static bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_t &constantValue) {
292+
if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
293+
dim = dExpr;
294+
constantValue = 1;
295+
return true;
296+
}
297+
298+
auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
299+
if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
300+
return false;
301+
302+
AffineExpr lhs = mulExpr.getLHS();
303+
AffineExpr rhs = mulExpr.getRHS();
304+
305+
if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
306+
if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
307+
dim = dExpr;
308+
constantValue = cst.getValue();
309+
return true;
310+
}
311+
}
312+
if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
313+
if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
314+
dim = dExpr;
315+
constantValue = cst.getValue();
316+
return true;
317+
}
318+
}
319+
return false;
320+
}
321+
322+
static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) {
323+
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
324+
AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
325+
auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
326+
if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
327+
return false;
328+
329+
AffineExpr dim0, dim1;
330+
// TODO(Abhishek-Varma): Use this information in specialize.cpp.
331+
int64_t c0, c1;
332+
333+
if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
334+
isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
335+
// Pattern matched with dims and constants extracted.
336+
AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
337+
AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
338+
return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr));
339+
}
340+
return false;
341+
}
342+
343+
static bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
344+
return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim);
345+
}
346+
347+
static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps, ArrayRef<int64_t> expectedSizes) {
348+
if (indexingMaps.size() != expectedSizes.size()) return false;
349+
350+
for (auto [indexingMap, expectedSize] : llvm::zip_equal(indexingMaps, expectedSizes)) {
351+
auto affineMap = cast<AffineMapAttr>(indexingMap).getValue();
352+
if (affineMap.getNumResults() != expectedSize) return false;
353+
}
354+
return true;
355+
}
356+
357+
bool isaConv1DOp(LinalgOp op) {
358+
if (isa<linalg::Conv1DOp>(op)) return true;
359+
360+
if (!isaConvolutionOpInterface(op)) return false;
361+
362+
ArrayAttr indexingMaps = op.getIndexingMaps();
363+
if (!verifyConvIndexingMapSizes(indexingMaps, {1,1,1})) return false;
364+
365+
// #map = affine_map<(d0, d1) -> (d0 + d1)>
366+
// #map1 = affine_map<(d0, d1) -> (d1)>
367+
// #map2 = affine_map<(d0, d1) -> (d0)>
368+
return matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0);
369+
}
370+
371+
bool isaConv1DNwcWcfOp(LinalgOp op) {
372+
if (isa<linalg::Conv1DNwcWcfOp>(op)) return true;
373+
374+
if (!isaConvolutionOpInterface(op)) return false;
375+
376+
ArrayAttr indexingMaps = op.getIndexingMaps();
377+
if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
378+
379+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
380+
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
381+
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
382+
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
383+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
384+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
385+
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
386+
matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 2));
387+
}
388+
389+
bool isaConv1DNcwFcwOp(LinalgOp op) {
390+
if (isa<linalg::Conv1DNcwFcwOp>(op)) return true;
391+
392+
if (!isaConvolutionOpInterface(op)) return false;
393+
394+
ArrayAttr indexingMaps = op.getIndexingMaps();
395+
if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
396+
397+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
398+
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
399+
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
400+
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
401+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
402+
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
403+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
404+
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
405+
}
406+
407+
bool isaDepthwiseConv1DNcwCwOp(LinalgOp op) {
408+
if (isa<linalg::DepthwiseConv1DNcwCwOp>(op)) return true;
409+
410+
if (!isaConvolutionOpInterface(op)) return false;
411+
412+
ArrayAttr indexingMaps = op.getIndexingMaps();
413+
if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false;
414+
415+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
416+
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
417+
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
418+
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
419+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
420+
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
421+
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
422+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2));
423+
}
424+
425+
bool isaDepthwiseConv1DNwcWcOp(LinalgOp op) {
426+
if (isa<linalg::DepthwiseConv1DNwcWcOp>(op)) return true;
427+
428+
if (!isaConvolutionOpInterface(op)) return false;
429+
430+
ArrayAttr indexingMaps = op.getIndexingMaps();
431+
if (!verifyConvIndexingMapSizes(indexingMaps, {3,2,3})) return false;
432+
433+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
434+
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
435+
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
436+
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
437+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
438+
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
439+
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
440+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1));
441+
}
442+
443+
bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op) {
444+
if (isa<linalg::DepthwiseConv1DNwcWcmOp>(op)) return true;
445+
446+
if (!isaConvolutionOpInterface(op)) return false;
447+
448+
ArrayAttr indexingMaps = op.getIndexingMaps();
449+
if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,4})) return false;
450+
451+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
452+
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
453+
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
454+
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
455+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
456+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
457+
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 1) &&
458+
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2) &&
459+
matchConvDimExprPattern(indexingMaps, fIndex, 2, oIndex, 3));
460+
}
461+
243462
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
244463
Value source, Value pad, bool nofold,
245464
ValueRange typeDynDims) {

0 commit comments

Comments
 (0)