Skip to content

Commit 535f7e9

Browse files
Pooling ops'
1 parent e5acca4 commit 535f7e9

File tree

3 files changed

+373
-213
lines changed

3 files changed

+373
-213
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,21 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
140140
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
141141
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
142142
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);
143+
bool isaPoolingNchwMaxOp(LinalgOp op);
144+
bool isaPoolingNchwSumOp(LinalgOp op);
145+
bool isaPoolingNhwcMaxOp(LinalgOp op);
146+
bool isaPoolingNhwcMinOp(LinalgOp op);
147+
bool isaPoolingNhwcSumOp(LinalgOp op);
148+
bool isaPoolingNhwcMaxUnsignedOp(LinalgOp op);
149+
bool isaPoolingNhwcMinUnsignedOp(LinalgOp op);
150+
bool isaPoolingNcwMaxOp(LinalgOp op);
151+
bool isaPoolingNcwSumOp(LinalgOp op);
152+
bool isaPoolingNwcMaxOp(LinalgOp op);
153+
bool isaPoolingNwcMinOp(LinalgOp op);
154+
bool isaPoolingNwcSumOp(LinalgOp op);
155+
bool isaPoolingNdhwcMaxOp(LinalgOp op);
156+
bool isaPoolingNdhwcMinOp(LinalgOp op);
157+
bool isaPoolingNdhwcSumOp(LinalgOp op);
143158

144159
//===----------------------------------------------------------------------===//
145160
// Fusion / Tiling utilities

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 30 additions & 213 deletions
Original file line numberDiff line numberDiff line change
@@ -237,106 +237,6 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
237237
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
238238
}
239239

240-
/// Utility to match block body for linalg.pool* ops.
241-
template <typename... OpTypes>
242-
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
243-
Operation *defOp = yieldVal.getDefiningOp();
244-
// if (!defOp) return false;
245-
if (!(isa_and_present<OpTypes>(defOp) || ...)) return false;
246-
247-
BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
248-
BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
249-
if (!lhsArg || !rhsArg) return false;
250-
return true;
251-
}
252-
253-
static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
254-
return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, body);
255-
}
256-
257-
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
258-
return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, body);
259-
}
260-
261-
static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
262-
return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, body);
263-
}
264-
265-
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
266-
return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, body);
267-
}
268-
269-
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
270-
return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
271-
}
272-
273-
static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
274-
uint32_t mapIndex, uint32_t dimIndex) {
275-
auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
276-
if (dimIndex < affineMap.getNumResults())
277-
return affineMap.getResult(dimIndex);
278-
return nullptr;
279-
}
280-
281-
// Check if `expr` is either:
282-
// - a dimension expr alone (implying *1), or
283-
// - a multiplication of dimension expr by constant.
284-
bool isDimTimesConstantOrDimOnly(AffineExpr expr, AffineExpr &dim, int64_t &constantValue) {
285-
if (auto dExpr = dyn_cast<AffineDimExpr>(expr)) {
286-
dim = dExpr;
287-
constantValue = 1;
288-
return true;
289-
}
290-
291-
auto mulExpr = dyn_cast<AffineBinaryOpExpr>(expr);
292-
if (!mulExpr || mulExpr.getKind() != AffineExprKind::Mul)
293-
return false;
294-
295-
AffineExpr lhs = mulExpr.getLHS();
296-
AffineExpr rhs = mulExpr.getRHS();
297-
298-
if (auto dExpr = dyn_cast<AffineDimExpr>(lhs)) {
299-
if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
300-
dim = dExpr;
301-
constantValue = cst.getValue();
302-
return true;
303-
}
304-
}
305-
if (auto cst = dyn_cast<AffineConstantExpr>(lhs)) {
306-
if (auto dExpr = dyn_cast<AffineDimExpr>(rhs)) {
307-
dim = dExpr;
308-
constantValue = cst.getValue();
309-
return true;
310-
}
311-
}
312-
return false;
313-
}
314-
315-
bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned fDim, unsigned oDim) {
316-
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
317-
AffineExpr inpExpr = getAffineMapDim(indexingMaps, iIndex, iDim);
318-
auto addExpr = dyn_cast<AffineBinaryOpExpr>(inpExpr);
319-
if (!addExpr || addExpr.getKind() != AffineExprKind::Add)
320-
return false;
321-
322-
AffineExpr dim0, dim1;
323-
// TODO(Abhishek-Varma): Use this information in specialize.cpp.
324-
int64_t c0, c1;
325-
326-
if (isDimTimesConstantOrDimOnly(addExpr.getLHS(), dim0, c0) &&
327-
isDimTimesConstantOrDimOnly(addExpr.getRHS(), dim1, c1)) {
328-
// Pattern matched with dims and constants extracted.
329-
AffineExpr fExpr = getAffineMapDim(indexingMaps, fIndex, fDim);
330-
AffineExpr oExpr = getAffineMapDim(indexingMaps, oIndex, oDim);
331-
return ((dim0 == fExpr && dim1 == oExpr) || (dim1 == fExpr && dim0 == oExpr));
332-
}
333-
return false;
334-
}
335-
336-
bool matchConvDimExprPattern(ArrayAttr indexingMaps, unsigned aIndex, unsigned aDim, unsigned bIndex, unsigned bDim) {
337-
return getAffineMapDim(indexingMaps, aIndex, aDim) == getAffineMapDim(indexingMaps, bIndex, bDim);
338-
}
339-
340240
static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
341241
if (isaConv1DOp(genericOp)) return "linalg.conv_1d";
342242
return "";
@@ -349,41 +249,16 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
349249
return "linalg.depthwise_conv_1d_nwc_wc";
350250
if (isaConv2DOp(genericOp))
351251
return "linalg.conv_2d";
352-
353-
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
354-
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
355-
Block *body = genericOp.getBlock();
356-
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
357-
Value yieldVal = yieldOp.getOperand(0);
358-
// pooling_ncw_max
359-
// pooling_ncw_sum
360-
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
361-
// #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
362-
// #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
363-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
364-
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
365-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2)) {
366-
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
367-
return "linalg.pooling_ncw_max";
368-
if (bodyMatcherForSumPoolOps(yieldVal, body))
369-
return "linalg.pooling_ncw_sum";
370-
}
371-
// pooling_nwc_max
372-
// pooling_nwc_min
373-
// pooling_nwc_sum
374-
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
375-
// #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
376-
// #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
377-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
378-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
379-
matchConvDimExprPattern(indexingMaps, iIndex, 2, oIndex, 2)) {
380-
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
381-
return "linalg.pooling_nwc_max";
382-
if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
383-
return "linalg.pooling_nwc_min";
384-
if (bodyMatcherForSumPoolOps(yieldVal, body))
385-
return "linalg.pooling_nwc_sum";
386-
}
252+
if (isaPoolingNcwMaxOp(genericOp))
253+
return "linalg.pooling_ncw_max";
254+
if (isaPoolingNcwSumOp(genericOp))
255+
return "linalg.pooling_ncw_sum";
256+
if (isaPoolingNwcMaxOp(genericOp))
257+
return "linalg.pooling_nwc_max";
258+
if (isaPoolingNwcMinOp(genericOp))
259+
return "linalg.pooling_nwc_min";
260+
if (isaPoolingNwcSumOp(genericOp))
261+
return "linalg.pooling_nwc_sum";
387262
return "";
388263
}
389264

@@ -402,61 +277,22 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
402277
return "linalg.depthwise_conv_2d_nchw_chw";
403278
if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
404279
return "linalg.depthwise_conv_2d_nhwc_hwc";
405-
406-
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
407-
if (indexingMaps.size() < 3) return "";
408-
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
409280
if (isaConv3DOp(genericOp))
410281
return "linalg.conv_3d";
411-
412-
Block *body = genericOp.getBlock();
413-
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
414-
Value yieldVal = yieldOp.getOperand(0);
415-
// pooling_nchw_max
416-
// pooling_nchw_sum
417-
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
418-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
419-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
420-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
421-
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
422-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/0, /*oDim=*/2) &&
423-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/1, /*oDim=*/3)) {
424-
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
425-
return "linalg.pooling_nchw_max";
426-
if (bodyMatcherForSumPoolOps(yieldVal, body))
427-
return "linalg.pooling_nchw_sum";
428-
}
429-
// pooling_nhwc_max
430-
// pooling_nhwc_min
431-
// pooling_nhwc_sum
432-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
433-
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
434-
// #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
435-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
436-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
437-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
438-
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) {
439-
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
440-
return "linalg.pooling_nhwc_max";
441-
if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
442-
return "linalg.pooling_nhwc_min";
443-
if (bodyMatcherForSumPoolOps(yieldVal, body))
444-
return "linalg.pooling_nhwc_sum";
445-
}
446-
// pooling_nhwc_max_unsigned
447-
// pooling_nhwc_min_unsigned
448-
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
449-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
450-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
451-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
452-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
453-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
454-
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3)) {
455-
if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
456-
return "linalg.pooling_nhwc_max_unsigned";
457-
if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
458-
return "linalg.pooling_nhwc_min_unsigned";
459-
}
282+
if (isaPoolingNchwMaxOp(genericOp))
283+
return "linalg.pooling_nchw_max";
284+
if (isaPoolingNchwSumOp(genericOp))
285+
return "linalg.pooling_nchw_sum";
286+
if (isaPoolingNhwcMaxOp(genericOp))
287+
return "linalg.pooling_nhwc_max";
288+
if (isaPoolingNhwcMinOp(genericOp))
289+
return "linalg.pooling_nhwc_min";
290+
if (isaPoolingNhwcSumOp(genericOp))
291+
return "linalg.pooling_nhwc_sum";
292+
if (isaPoolingNhwcMaxUnsignedOp(genericOp))
293+
return "linalg.pooling_nhwc_max_unsigned";
294+
if (isaPoolingNhwcMinUnsignedOp(genericOp))
295+
return "linalg.pooling_nhwc_min_unsigned";
460296
return "";
461297
}
462298

@@ -491,31 +327,12 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
491327
return "linalg.depthwise_conv_3d_ncdhw_cdhw";
492328
if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp))
493329
return "linalg.depthwise_conv_3d_ndhwc_dhwc";
494-
495-
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
496-
if (indexingMaps.size() < 3) return "";
497-
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
498-
Block *body = genericOp.getBlock();
499-
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
500-
Value yieldVal = yieldOp.getOperand(0);
501-
// pooling_ndhwc_max
502-
// pooling_ndhwc_min
503-
// pooling_ndhwc_sum
504-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
505-
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
506-
// #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
507-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
508-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
509-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
510-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
511-
matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4)) {
512-
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
513-
return "linalg.pooling_ndhwc_max";
514-
if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
515-
return "linalg.pooling_ndhwc_min";
516-
if (bodyMatcherForSumPoolOps(yieldVal, body))
517-
return "linalg.pooling_ndhwc_sum";
518-
}
330+
if (isaPoolingNdhwcMaxOp(genericOp))
331+
return "linalg.pooling_ndhwc_max";
332+
if (isaPoolingNdhwcMinOp(genericOp))
333+
return "linalg.pooling_ndhwc_min";
334+
if (isaPoolingNdhwcSumOp(genericOp))
335+
return "linalg.pooling_ndhwc_sum";
519336
return "";
520337
}
521338

0 commit comments

Comments
 (0)