@@ -237,6 +237,39 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
237237 return replaceWithMatmulVariant<MatmulOp>(rewriter, genericOp);
238238}
239239
240+ // / Utility to match block body for linalg.pool* ops.
241+ template <typename ... OpTypes>
242+ static bool bodyMatcherForPoolOps (Value yieldVal, Block *body) {
243+ Operation *defOp = yieldVal.getDefiningOp ();
244+ // if (!defOp) return false;
245+ if (!(isa_and_present<OpTypes>(defOp) || ...)) return false ;
246+
247+ BlockArgument lhsArg = dyn_cast<BlockArgument>(defOp->getOperand (0 ));
248+ BlockArgument rhsArg = dyn_cast<BlockArgument>(defOp->getOperand (1 ));
249+ if (!lhsArg || !rhsArg) return false ;
250+ return true ;
251+ }
252+
253+ static bool bodyMatcherForMaxSignedPoolOps (Value yieldVal, Block *body) {
254+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxSIOp>(yieldVal, body);
255+ }
256+
257+ static bool bodyMatcherForMaxUnsignedPoolOps (Value yieldVal, Block *body) {
258+ return bodyMatcherForPoolOps<arith::MaximumFOp, arith::MaxUIOp>(yieldVal, body);
259+ }
260+
261+ static bool bodyMatcherForMinSignedPoolOps (Value yieldVal, Block *body) {
262+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinSIOp>(yieldVal, body);
263+ }
264+
265+ static bool bodyMatcherForMinUnsignedPoolOps (Value yieldVal, Block *body) {
266+ return bodyMatcherForPoolOps<arith::MinimumFOp, arith::MinUIOp>(yieldVal, body);
267+ }
268+
269+ static bool bodyMatcherForSumPoolOps (Value yieldVal, Block *body) {
270+ return bodyMatcherForPoolOps<arith::AddIOp, arith::AddFOp>(yieldVal, body);
271+ }
272+
240273static mlir::AffineExpr getAffineMapDim (ArrayAttr indexingMaps,
241274 uint32_t mapIndex, uint32_t dimIndex) {
242275 auto affineMap = cast<AffineMapAttr>(indexingMaps[mapIndex]).getValue ();
@@ -279,6 +312,39 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
279312 if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 0 ))) &&
280313 (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))))
281314 return " linalg.conv_2d" ;
315+
316+ Block *body = genericOp.getBlock ();
317+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
318+ Value yieldVal = yieldOp.getOperand (0 );
319+ // pooling_ncw_max
320+ // pooling_ncw_sum
321+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2 + d3)>
322+ // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
323+ // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
324+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
325+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == getAffineMapDim (indexingMaps, oIndex, 1 )) &&
326+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 2 )))) {
327+ if (bodyMatcherForMaxSignedPoolOps (yieldVal, body))
328+ return " linalg.pooling_ncw_max" ;
329+ if (bodyMatcherForSumPoolOps (yieldVal, body))
330+ return " linalg.pooling_ncw_sum" ;
331+ }
332+ // pooling_nwc_max
333+ // pooling_nwc_min
334+ // pooling_nwc_sum
335+ // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
336+ // #map3 = affine_map<(d0, d1, d2, d3) -> (d3)>
337+ // #map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
338+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
339+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 1 )) &&
340+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == getAffineMapDim (indexingMaps, oIndex, 2 ))) {
341+ if (bodyMatcherForMaxSignedPoolOps (yieldVal, body))
342+ return " linalg.pooling_nwc_max" ;
343+ if (bodyMatcherForMinSignedPoolOps (yieldVal, body))
344+ return " linalg.pooling_nwc_min" ;
345+ if (bodyMatcherForSumPoolOps (yieldVal, body))
346+ return " linalg.pooling_nwc_sum" ;
347+ }
282348 return " " ;
283349}
284350
@@ -346,6 +412,55 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
346412 (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))) &&
347413 (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 2 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))))
348414 return " linalg.conv_3d" ;
415+
416+ Block *body = genericOp.getBlock ();
417+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
418+ Value yieldVal = yieldOp.getOperand (0 );
419+ // pooling_nchw_max
420+ // pooling_nchw_sum
421+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 + d4, d3 + d5)>
422+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
423+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
424+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
425+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == getAffineMapDim (indexingMaps, oIndex, 1 )) &&
426+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
427+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 3 )))) {
428+ if (bodyMatcherForMaxSignedPoolOps (yieldVal, body))
429+ return " linalg.pooling_nchw_max" ;
430+ if (bodyMatcherForSumPoolOps (yieldVal, body))
431+ return " linalg.pooling_nchw_sum" ;
432+ }
433+ // pooling_nhwc_max
434+ // pooling_nhwc_min
435+ // pooling_nhwc_sum
436+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
437+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
438+ // #map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
439+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
440+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))) &&
441+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
442+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == getAffineMapDim (indexingMaps, oIndex, 3 ))) {
443+ if (bodyMatcherForMaxSignedPoolOps (yieldVal, body))
444+ return " linalg.pooling_nhwc_max" ;
445+ if (bodyMatcherForMinSignedPoolOps (yieldVal, body))
446+ return " linalg.pooling_nhwc_min" ;
447+ if (bodyMatcherForSumPoolOps (yieldVal, body))
448+ return " linalg.pooling_nhwc_sum" ;
449+ }
450+ // pooling_nhwc_max_unsigned
451+ // pooling_nhwc_min_unsigned
452+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
453+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5)>
454+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
455+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
456+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))) &&
457+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
458+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == getAffineMapDim (indexingMaps, oIndex, 3 ))) {
459+ if (bodyMatcherForMaxUnsignedPoolOps (yieldVal, body))
460+ return " linalg.pooling_nhwc_max_unsigned" ;
461+ if (bodyMatcherForMinUnsignedPoolOps (yieldVal, body))
462+ return " linalg.pooling_nhwc_max_unsigned" ;
463+ }
349464 return " " ;
350465}
351466
@@ -510,6 +625,28 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
510625 (getAffineMapDim (indexingMaps, iIndex, 3 ) == (getAffineMapDim (indexingMaps, fIndex , 2 ) + getAffineMapDim (indexingMaps, oIndex, 3 ))) &&
511626 (getAffineMapDim (indexingMaps, iIndex, 4 ) == getAffineMapDim (indexingMaps, fIndex , 3 ) && getAffineMapDim (indexingMaps, iIndex, 4 ) == getAffineMapDim (indexingMaps, oIndex, 4 )))
512627 return " linalg.depthwise_conv_3d_ndhwc_dhwc" ;
628+
629+ Block *body = genericOp.getBlock ();
630+ auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
631+ Value yieldVal = yieldOp.getOperand (0 );
632+ // pooling_ndhwc_max
633+ // pooling_ndhwc_min
634+ // pooling_ndhwc_sum
635+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3 + d7, d4)>
636+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)>
637+ // #map4 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
638+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
639+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))) &&
640+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
641+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == (getAffineMapDim (indexingMaps, fIndex , 2 ) + getAffineMapDim (indexingMaps, oIndex, 3 ))) &&
642+ (getAffineMapDim (indexingMaps, iIndex, 4 ) == getAffineMapDim (indexingMaps, oIndex, 4 ))) {
643+ if (bodyMatcherForMaxSignedPoolOps (yieldVal, body))
644+ return " linalg.pooling_ndhwc_max" ;
645+ if (bodyMatcherForMinSignedPoolOps (yieldVal, body))
646+ return " linalg.pooling_ndhwc_min" ;
647+ if (bodyMatcherForSumPoolOps (yieldVal, body))
648+ return " linalg.pooling_ndhwc_sum" ;
649+ }
513650 return " " ;
514651}
515652
@@ -639,6 +776,36 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
639776 namedOp = rewriter.replaceOpWithNewOp <linalg::DepthwiseConv3DNcdhwCdhwOp>(genericOp, resultTypes, inputs, outputs);
640777 } else if (convKind == " linalg.depthwise_conv_3d_ndhwc_dhwc" ) {
641778 namedOp = rewriter.replaceOpWithNewOp <linalg::DepthwiseConv3DNdhwcDhwcOp>(genericOp, resultTypes, inputs, outputs);
779+ } else if (convKind == " linalg.pooling_nchw_max" ) {
780+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNchwMaxOp>(genericOp, resultTypes, inputs, outputs);
781+ } else if (convKind == " linalg.pooling_nchw_sum" ) {
782+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNchwSumOp>(genericOp, resultTypes, inputs, outputs);
783+ } else if (convKind == " linalg.pooling_nhwc_max" ) {
784+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
785+ } else if (convKind == " linalg.pooling_nhwc_min" ) {
786+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNhwcMinOp>(genericOp, resultTypes, inputs, outputs);
787+ } else if (convKind == " linalg.pooling_nhwc_sum" ) {
788+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNhwcSumOp>(genericOp, resultTypes, inputs, outputs);
789+ } else if (convKind == " linalg.pooling_nhwc_max_unsigned" ) {
790+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNhwcMaxUnsignedOp>(genericOp, resultTypes, inputs, outputs);
791+ } else if (convKind == " linalg.pooling_nhwc_min_unsigned" ) {
792+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNhwcMinUnsignedOp>(genericOp, resultTypes, inputs, outputs);
793+ } else if (convKind == " linalg.pooling_ncw_max" ) {
794+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNcwMaxOp>(genericOp, resultTypes, inputs, outputs);
795+ } else if (convKind == " linalg.pooling_ncw_sum" ) {
796+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNcwSumOp>(genericOp, resultTypes, inputs, outputs);
797+ } else if (convKind == " linalg.pooling_nwc_max" ) {
798+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNwcMaxOp>(genericOp, resultTypes, inputs, outputs);
799+ } else if (convKind == " linalg.pooling_nwc_min" ) {
800+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNwcMinOp>(genericOp, resultTypes, inputs, outputs);
801+ } else if (convKind == " linalg.pooling_nwc_sum" ) {
802+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNwcSumOp>(genericOp, resultTypes, inputs, outputs);
803+ } else if (convKind == " linalg.pooling_ndhwc_max" ) {
804+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNdhwcMaxOp>(genericOp, resultTypes, inputs, outputs);
805+ } else if (convKind == " linalg.pooling_ndhwc_min" ) {
806+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNdhwcMinOp>(genericOp, resultTypes, inputs, outputs);
807+ } else if (convKind == " linalg.pooling_ndhwc_sum" ) {
808+ namedOp = rewriter.replaceOpWithNewOp <linalg::PoolingNdhwcSumOp>(genericOp, resultTypes, inputs, outputs);
642809 }
643810 return namedOp;
644811
0 commit comments