Skip to content

Commit 789fb85

Browse files
Add pooling ops to the mix - has a few issues, but we can shift to considering dilations/strides now
1 parent dac92f1 commit 789fb85

File tree

1 file changed

+167
-0
lines changed

1 file changed

+167
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,39 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
237237
return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
238238
}
239239

240+
/// Utility to match block body for linalg.pool* ops.
241+
template <typename... OpTypes>
242+
static bool bodyMatcherForPoolOps(Value yieldVal, Block *body) {
243+
Operation *defOp = yieldVal.getDefiningOp();
244+
// if (!defOp) return false;
245+
if (!(isa_and_present<OpTypes>(defOp) || ...)) return false;
246+
247+
BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand(0));
248+
BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand(1));
249+
if (!lhsArg || !rhsArg) return false;
250+
return true;
251+
}
252+
253+
/// Body matcher for (signed) max pooling: delegates to bodyMatcherForPoolOps
/// with the reduction kinds arith::MaximumFOp and arith::MaxSIOp.
static bool bodyMatcherForMaxSignedPoolOps(Value yieldVal, Block *body) {
  const bool matchesSignedMax =
      bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, body);
  return matchesSignedMax;
}
256+
257+
/// Body matcher for unsigned max pooling: delegates to bodyMatcherForPoolOps
/// with the reduction kinds arith::MaximumFOp and arith::MaxUIOp.
static bool bodyMatcherForMaxUnsignedPoolOps(Value yieldVal, Block *body) {
  const bool matchesUnsignedMax =
      bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, body);
  return matchesUnsignedMax;
}
260+
261+
/// Body matcher for (signed) min pooling: delegates to bodyMatcherForPoolOps
/// with the reduction kinds arith::MinimumFOp and arith::MinSIOp.
static bool bodyMatcherForMinSignedPoolOps(Value yieldVal, Block *body) {
  const bool matchesSignedMin =
      bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, body);
  return matchesSignedMin;
}
264+
265+
/// Body matcher for unsigned min pooling: delegates to bodyMatcherForPoolOps
/// with the reduction kinds arith::MinimumFOp and arith::MinUIOp.
static bool bodyMatcherForMinUnsignedPoolOps(Value yieldVal, Block *body) {
  const bool matchesUnsignedMin =
      bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, body);
  return matchesUnsignedMin;
}
268+
269+
/// Body matcher for sum pooling: delegates to bodyMatcherForPoolOps with the
/// reduction kinds arith::AddIOp and arith::AddFOp.
static bool bodyMatcherForSumPoolOps(Value yieldVal, Block *body) {
  const bool matchesSum =
      bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
  return matchesSum;
}
272+
240273
static mlir::AffineExpr getAffineMapDim(ArrayAttr indexingMaps,
241274
uint32_t mapIndex, uint32_t dimIndex) {
242275
auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue();
@@ -279,6 +312,39 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
279312
if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
280313
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))))
281314
return "linalg.conv_2d";
315+
316+
Block *body = genericOp.getBlock();
317+
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
318+
Value yieldVal = yieldOp.getOperand(0);
319+
// pooling_ncw_max
320+
// pooling_ncw_sum
321+
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
322+
// #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
323+
// #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
324+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
325+
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
326+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2)))) {
327+
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
328+
return "linalg.pooling_ncw_max";
329+
if (bodyMatcherForSumPoolOps(yieldVal, body))
330+
return "linalg.pooling_ncw_sum";
331+
}
332+
// pooling_nwc_max
333+
// pooling_nwc_min
334+
// pooling_nwc_sum
335+
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
336+
// #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
337+
// #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
338+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
339+
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1)) &&
340+
(getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, oIndex, 2))) {
341+
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
342+
return "linalg.pooling_nwc_max";
343+
if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
344+
return "linalg.pooling_nwc_min";
345+
if (bodyMatcherForSumPoolOps(yieldVal, body))
346+
return "linalg.pooling_nwc_sum";
347+
}
282348
return "";
283349
}
284350

@@ -346,6 +412,55 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
346412
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
347413
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))))
348414
return "linalg.conv_3d";
415+
416+
Block *body = genericOp.getBlock();
417+
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
418+
Value yieldVal = yieldOp.getOperand(0);
419+
// pooling_nchw_max
420+
// pooling_nchw_sum
421+
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
422+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
423+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
424+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
425+
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
426+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
427+
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 3)))) {
428+
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
429+
return "linalg.pooling_nchw_max";
430+
if (bodyMatcherForSumPoolOps(yieldVal, body))
431+
return "linalg.pooling_nchw_sum";
432+
}
433+
// pooling_nhwc_max
434+
// pooling_nhwc_min
435+
// pooling_nhwc_sum
436+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
437+
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
438+
// #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
439+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
440+
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
441+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
442+
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
443+
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
444+
return "linalg.pooling_nhwc_max";
445+
if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
446+
return "linalg.pooling_nhwc_min";
447+
if (bodyMatcherForSumPoolOps(yieldVal, body))
448+
return "linalg.pooling_nhwc_sum";
449+
}
450+
// pooling_nhwc_max_unsigned
451+
// pooling_nhwc_min_unsigned
452+
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
453+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
454+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
455+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
456+
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
457+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
458+
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3))) {
459+
// Unsigned max/min pooling bodies. Each matcher must report its own op name:
// the returned string is what selects the named op to instantiate, so the
// min-unsigned body has to yield the "min_unsigned" name, not "max_unsigned".
if (bodyMatcherForMaxUnsignedPoolOps(yieldVal, body))
  return "linalg.pooling_nhwc_max_unsigned";
if (bodyMatcherForMinUnsignedPoolOps(yieldVal, body))
  return "linalg.pooling_nhwc_min_unsigned";
463+
}
349464
return "";
350465
}
351466

@@ -510,6 +625,28 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
510625
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
511626
(getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
512627
return "linalg.depthwise_conv_3d_ndhwc_dhwc";
628+
629+
Block *body = genericOp.getBlock();
630+
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
631+
Value yieldVal = yieldOp.getOperand(0);
632+
// pooling_ndhwc_max
633+
// pooling_ndhwc_min
634+
// pooling_ndhwc_sum
635+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
636+
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
637+
// #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
638+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
639+
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
640+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
641+
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
642+
(getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4))) {
643+
if (bodyMatcherForMaxSignedPoolOps(yieldVal, body))
644+
return "linalg.pooling_ndhwc_max";
645+
if (bodyMatcherForMinSignedPoolOps(yieldVal, body))
646+
return "linalg.pooling_ndhwc_min";
647+
if (bodyMatcherForSumPoolOps(yieldVal, body))
648+
return "linalg.pooling_ndhwc_sum";
649+
}
513650
return "";
514651
}
515652

@@ -639,6 +776,36 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
639776
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(genericOp, resultTypes, inputs, outputs);
640777
} else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwc") {
641778
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(genericOp, resultTypes, inputs, outputs);
779+
} else if (convKind == "linalg.pooling_nchw_max") {
780+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNchwMaxOp>(genericOp, resultTypes, inputs, outputs);
781+
} else if (convKind == "linalg.pooling_nchw_sum") {
782+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNchwSumOp>(genericOp, resultTypes, inputs, outputs);
783+
} else if (convKind == "linalg.pooling_nhwc_max") {
784+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
785+
} else if (convKind == "linalg.pooling_nhwc_min") {
786+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMinOp>(genericOp, resultTypes, inputs, outputs);
787+
} else if (convKind == "linalg.pooling_nhwc_sum") {
788+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcSumOp>(genericOp, resultTypes, inputs, outputs);
789+
} else if (convKind == "linalg.pooling_nhwc_max_unsigned") {
790+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMaxUnsignedOp>(genericOp, resultTypes, inputs, outputs);
791+
} else if (convKind == "linalg.pooling_nhwc_min_unsigned") {
792+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNhwcMinUnsignedOp>(genericOp, resultTypes, inputs, outputs);
793+
} else if (convKind == "linalg.pooling_ncw_max") {
794+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNcwMaxOp>(genericOp, resultTypes, inputs, outputs);
795+
} else if (convKind == "linalg.pooling_ncw_sum") {
796+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNcwSumOp>(genericOp, resultTypes, inputs, outputs);
797+
} else if (convKind == "linalg.pooling_nwc_max") {
798+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcMaxOp>(genericOp, resultTypes, inputs, outputs);
799+
} else if (convKind == "linalg.pooling_nwc_min") {
800+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcMinOp>(genericOp, resultTypes, inputs, outputs);
801+
} else if (convKind == "linalg.pooling_nwc_sum") {
802+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNwcSumOp>(genericOp, resultTypes, inputs, outputs);
803+
} else if (convKind == "linalg.pooling_ndhwc_max") {
804+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
805+
} else if (convKind == "linalg.pooling_ndhwc_min") {
806+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcMinOp>(genericOp, resultTypes, inputs, outputs);
807+
} else if (convKind == "linalg.pooling_ndhwc_sum") {
808+
namedOp = rewriter.replaceOpWithNewOp<linalg::PoolingNdhwcSumOp>(genericOp, resultTypes, inputs, outputs);
642809
}
643810
return namedOp;
644811

0 commit comments

Comments
 (0)