@@ -430,19 +430,33 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
430430 })));
431431}
432432
433- // / Enum of all kinds of Pooling Op's type .
434- enum PoolingType {
435- NONE ,
436- MAX_SIGNED ,
437- MAX_UNSIGNED ,
438- MIN_SIGNED ,
439- MIN_UNSIGNED ,
440- SUM
433+ // / Enum representing pooling operation types used by ConvMatcherBuilder .
434+ enum class PoolingType {
435+ None ,
436+ MaxSigned ,
437+ MaxUnsigned ,
438+ MinSigned ,
439+ MinUnsigned ,
440+ Sum
441441};
442442
443443// / Helper class for building convolution op matchers with minimal boilerplate.
444444// / Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
445445// / as Pooling ops.
446+ // /
447+ // / Usage: Create an instance with the op, spatial rank, and output pointers for
448+ // / extracted dilations/strides. Then chain matchStride() calls for each spatial
449+ // / dimension, followed by matchMaps() to verify indexing maps, and finally
450+ // / matchBody() to verify the operation body pattern.
451+ // /
452+ // / The `matched` flag starts as `true` and is set to `false` if any match step
453+ // / fails. This allows chaining multiple match calls; once any match fails, all
454+ // / subsequent calls become no-ops and the final result is `false`.
455+ // /
456+ // / The `dilations` and `strides` pointers are output parameters that get
457+ // / populated with the extracted dilation and stride values from the operation's
458+ // / indexing maps during matchStride() calls. These values are initially set to
459+ // / 1 for each spatial dimension and updated as patterns are matched.
446460class ConvMatcherBuilder {
447461 LinalgOp op;
448462 MLIRContext *ctx;
@@ -454,7 +468,7 @@ class ConvMatcherBuilder {
454468public:
455469 ConvMatcherBuilder (LinalgOp op, unsigned spatialRank, SmallVector<int64_t > *d,
456470 SmallVector<int64_t > *s,
457- PoolingType poolingType = PoolingType::NONE )
471+ PoolingType poolingType = PoolingType::None )
458472 : op(op), ctx(op->getContext ()), dilations(d), strides(s),
459473 indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
460474 *dilations = SmallVector<int64_t >(spatialRank, 1 );
@@ -474,16 +488,16 @@ class ConvMatcherBuilder {
474488 ConvMatcherBuilder &matchStride (unsigned iDim, unsigned fDim , unsigned oDim,
475489 unsigned idx) {
476490 if (matched) {
477- matched = matchConvDimAddExprPattern (indexingMaps, iDim, fDim , oDim,
478- (*dilations)[idx], (*strides)[idx]);
491+ matched & = matchConvDimAddExprPattern (indexingMaps, iDim, fDim , oDim,
492+ (*dilations)[idx], (*strides)[idx]);
479493 }
480494 return *this ;
481495 }
482496
483497 // / Match expected indexing maps layout. Returns *this for method chaining.
484- ConvMatcherBuilder &expectMaps (ArrayRef<ArrayRef<AffineExpr>> maps) {
498+ ConvMatcherBuilder &matchMaps (ArrayRef<ArrayRef<AffineExpr>> maps) {
485499 if (matched)
486- matched = convLayoutMatches (maps, indexingMaps, ctx);
500+ matched & = convLayoutMatches (maps, indexingMaps, ctx);
487501 return *this ;
488502 }
489503
@@ -494,17 +508,17 @@ class ConvMatcherBuilder {
494508 Block *body = op.getBlock ();
495509 auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
496510 switch (poolingType) {
497- case PoolingType::NONE :
511+ case PoolingType::None :
498512 return bodyMatcherForConvolutionOps (yieldOp.getOperand (0 ), body);
499- case PoolingType::MAX_SIGNED :
513+ case PoolingType::MaxSigned :
500514 return bodyMatcherForMaxSignedPoolOps (yieldOp.getOperand (0 ), body);
501- case PoolingType::MAX_UNSIGNED :
515+ case PoolingType::MaxUnsigned :
502516 return bodyMatcherForMaxUnsignedPoolOps (yieldOp.getOperand (0 ), body);
503- case PoolingType::MIN_SIGNED :
517+ case PoolingType::MinSigned :
504518 return bodyMatcherForMinSignedPoolOps (yieldOp.getOperand (0 ), body);
505- case PoolingType::MIN_UNSIGNED :
519+ case PoolingType::MinUnsigned :
506520 return bodyMatcherForMinUnsignedPoolOps (yieldOp.getOperand (0 ), body);
507- case PoolingType::SUM :
521+ case PoolingType::Sum :
508522 return bodyMatcherForSumPoolOps (yieldOp.getOperand (0 ), body);
509523 }
510524 return false ;
@@ -533,9 +547,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
533547 AffineExpr w = m.dim (1 );
534548
535549 return m.matchStride (/* iDim=*/ 0 , /* fDim=*/ 0 , /* oDim=*/ 0 , /* idx=*/ 0 )
536- .expectMaps ({/* inputMap=*/ {m.strided (W, w, 0 )},
537- /* filterMap=*/ {w},
538- /* outputMap=*/ {W}})
550+ .matchMaps ({/* inputMap=*/ {m.strided (W, w, 0 )},
551+ /* filterMap=*/ {w},
552+ /* outputMap=*/ {W}})
539553 .matchBody ();
540554}
541555
@@ -560,9 +574,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
560574 AffineExpr c = m.dim (4 );
561575
562576 return m.matchStride (/* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 , /* idx=*/ 0 )
563- .expectMaps ({/* inputMap=*/ {N, m.strided (W, w, 0 ), c},
564- /* filterMap=*/ {w, c, F},
565- /* outputMap=*/ {N, W, F}})
577+ .matchMaps ({/* inputMap=*/ {N, m.strided (W, w, 0 ), c},
578+ /* filterMap=*/ {w, c, F},
579+ /* outputMap=*/ {N, W, F}})
566580 .matchBody ();
567581}
568582
@@ -587,9 +601,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
587601 AffineExpr w = m.dim (4 );
588602
589603 return m.matchStride (/* iDim=*/ 2 , /* fDim=*/ 2 , /* oDim=*/ 2 , /* idx=*/ 0 )
590- .expectMaps ({/* inputMap=*/ {N, c, m.strided (W, w, 0 )},
591- /* filterMap=*/ {F, c, w},
592- /* outputMap=*/ {N, F, W}})
604+ .matchMaps ({/* inputMap=*/ {N, c, m.strided (W, w, 0 )},
605+ /* filterMap=*/ {F, c, w},
606+ /* outputMap=*/ {N, F, W}})
593607 .matchBody ();
594608}
595609
@@ -614,9 +628,9 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
614628
615629 return m.matchStride (/* iDim=*/ 0 , /* fDim=*/ 0 , /* oDim=*/ 0 , /* idx=*/ 0 )
616630 .matchStride (/* iDim=*/ 1 , /* fDim=*/ 1 , /* oDim=*/ 1 , /* idx=*/ 1 )
617- .expectMaps ({/* inputMap=*/ {m.strided (H, h, 0 ), m.strided (W, w, 1 )},
618- /* filterMap=*/ {h, w},
619- /* outputMap=*/ {H, W}})
631+ .matchMaps ({/* inputMap=*/ {m.strided (H, h, 0 ), m.strided (W, w, 1 )},
632+ /* filterMap=*/ {h, w},
633+ /* outputMap=*/ {H, W}})
620634 .matchBody ();
621635}
622636
@@ -644,10 +658,10 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
644658 return m.matchStride (/* iDim=*/ 0 , /* fDim=*/ 0 , /* oDim=*/ 0 , /* idx=*/ 0 )
645659 .matchStride (/* iDim=*/ 1 , /* fDim=*/ 1 , /* oDim=*/ 1 , /* idx=*/ 1 )
646660 .matchStride (/* iDim=*/ 2 , /* fDim=*/ 2 , /* oDim=*/ 2 , /* idx=*/ 2 )
647- .expectMaps ({/* inputMap=*/ {m.strided (D, d, 0 ), m.strided (H, h, 1 ),
648- m.strided (W, w, 2 )},
649- /* filterMap=*/ {d, h, w},
650- /* outputMap=*/ {D, H, W}})
661+ .matchMaps ({/* inputMap=*/ {m.strided (D, d, 0 ), m.strided (H, h, 1 ),
662+ m.strided (W, w, 2 )},
663+ /* filterMap=*/ {d, h, w},
664+ /* outputMap=*/ {D, H, W}})
651665 .matchBody ();
652666}
653667
@@ -671,9 +685,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
671685 AffineExpr w = m.dim (3 );
672686
673687 return m.matchStride (/* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 , /* idx=*/ 0 )
674- .expectMaps ({/* inputMap=*/ {N, C, m.strided (W, w, 0 )},
675- /* filterMap=*/ {C, w},
676- /* outputMap=*/ {N, C, W}})
688+ .matchMaps ({/* inputMap=*/ {N, C, m.strided (W, w, 0 )},
689+ /* filterMap=*/ {C, w},
690+ /* outputMap=*/ {N, C, W}})
677691 .matchBody ();
678692}
679693
@@ -697,9 +711,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
697711 AffineExpr w = m.dim (3 );
698712
699713 return m.matchStride (/* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 , /* idx=*/ 0 )
700- .expectMaps ({/* inputMap=*/ {N, m.strided (W, w, 0 ), C},
701- /* filterMap=*/ {w, C},
702- /* outputMap=*/ {N, W, C}})
714+ .matchMaps ({/* inputMap=*/ {N, m.strided (W, w, 0 ), C},
715+ /* filterMap=*/ {w, C},
716+ /* outputMap=*/ {N, W, C}})
703717 .matchBody ();
704718}
705719
@@ -724,9 +738,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
724738 AffineExpr w = m.dim (4 );
725739
726740 return m.matchStride (/* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 , /* idx=*/ 0 )
727- .expectMaps ({/* inputMap=*/ {N, m.strided (W, w, 0 ), C},
728- /* filterMap=*/ {w, C, CM},
729- /* outputMap=*/ {N, W, C, CM}})
741+ .matchMaps ({/* inputMap=*/ {N, m.strided (W, w, 0 ), C},
742+ /* filterMap=*/ {w, C, CM},
743+ /* outputMap=*/ {N, W, C, CM}})
730744 .matchBody ();
731745}
732746
@@ -753,9 +767,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
753767
754768 return m.matchStride (/* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 , /* idx=*/ 0 )
755769 .matchStride (/* iDim=*/ 3 , /* fDim=*/ 2 , /* oDim=*/ 3 , /* idx=*/ 1 )
756- .expectMaps ({/* inputMap=*/ {N, C, m.strided (H, h, 0 ), m.strided (W, w, 1 )},
757- /* filterMap=*/ {C, h, w},
758- /* outputMap=*/ {N, C, H, W}})
770+ .matchMaps ({/* inputMap=*/ {N, C, m.strided (H, h, 0 ), m.strided (W, w, 1 )},
771+ /* filterMap=*/ {C, h, w},
772+ /* outputMap=*/ {N, C, H, W}})
759773 .matchBody ();
760774}
761775
@@ -789,10 +803,10 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
789803 return m.matchStride (/* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 , /* idx=*/ 0 )
790804 .matchStride (/* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 , /* idx=*/ 1 )
791805 .matchStride (/* iDim=*/ 3 , /* fDim=*/ 2 , /* oDim=*/ 3 , /* idx=*/ 2 )
792- .expectMaps ({/* inputMap=*/ {N, m.strided (D, d, 0 ), m.strided (H, h, 1 ),
793- m.strided (W, w, 2 ), C},
794- /* filterMap=*/ {d, h, w, C, CM},
795- /* outputMap=*/ {N, D, H, W, C, CM}})
806+ .matchMaps ({/* inputMap=*/ {N, m.strided (D, d, 0 ), m.strided (H, h, 1 ),
807+ m.strided (W, w, 2 ), C},
808+ /* filterMap=*/ {d, h, w, C, CM},
809+ /* outputMap=*/ {N, D, H, W, C, CM}})
796810 .matchBody ();
797811}
798812
@@ -810,7 +824,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
810824 " expected op to implement ConvolutionOpInterface" );
811825
812826 ConvMatcherBuilder m (op, /* spatialRank=*/ 2 , dilations, strides,
813- PoolingType::MAX_SIGNED );
827+ PoolingType::MaxSigned );
814828 AffineExpr N = m.dim (0 );
815829 AffineExpr H = m.dim (1 );
816830 AffineExpr W = m.dim (2 );
@@ -820,9 +834,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
820834
821835 return m.matchStride (/* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 , /* idx=*/ 0 )
822836 .matchStride (/* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 , /* idx=*/ 1 )
823- .expectMaps ({/* inputMap=*/ {N, m.strided (H, h, 0 ), m.strided (W, w, 1 ), C},
824- /* filterMap=*/ {h, w},
825- /* outputMap=*/ {N, H, W, C}})
837+ .matchMaps ({/* inputMap=*/ {N, m.strided (H, h, 0 ), m.strided (W, w, 1 ), C},
838+ /* filterMap=*/ {h, w},
839+ /* outputMap=*/ {N, H, W, C}})
826840 .matchBody ();
827841}
828842
@@ -840,7 +854,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
840854 " expected op to implement ConvolutionOpInterface" );
841855
842856 ConvMatcherBuilder m (op, /* spatialRank=*/ 2 , dilations, strides,
843- PoolingType::MIN_SIGNED );
857+ PoolingType::MinSigned );
844858 AffineExpr N = m.dim (0 );
845859 AffineExpr H = m.dim (1 );
846860 AffineExpr W = m.dim (2 );
@@ -850,9 +864,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
850864
851865 return m.matchStride (/* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 , /* idx=*/ 0 )
852866 .matchStride (/* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 , /* idx=*/ 1 )
853- .expectMaps ({/* inputMap=*/ {N, m.strided (H, h, 0 ), m.strided (W, w, 1 ), C},
854- /* filterMap=*/ {h, w},
855- /* outputMap=*/ {N, H, W, C}})
867+ .matchMaps ({/* inputMap=*/ {N, m.strided (H, h, 0 ), m.strided (W, w, 1 ), C},
868+ /* filterMap=*/ {h, w},
869+ /* outputMap=*/ {N, H, W, C}})
856870 .matchBody ();
857871}
858872
@@ -870,7 +884,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
870884 " expected op to implement ConvolutionOpInterface" );
871885
872886 ConvMatcherBuilder m (op, /* spatialRank=*/ 2 , dilations, strides,
873- PoolingType::SUM );
887+ PoolingType::Sum );
874888 AffineExpr N = m.dim (0 );
875889 AffineExpr H = m.dim (1 );
876890 AffineExpr W = m.dim (2 );
@@ -880,9 +894,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
880894
881895 return m.matchStride (/* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 , /* idx=*/ 0 )
882896 .matchStride (/* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 , /* idx=*/ 1 )
883- .expectMaps ({/* inputMap=*/ {N, m.strided (H, h, 0 ), m.strided (W, w, 1 ), C},
884- /* filterMap=*/ {h, w},
885- /* outputMap=*/ {N, H, W, C}})
897+ .matchMaps ({/* inputMap=*/ {N, m.strided (H, h, 0 ), m.strided (W, w, 1 ), C},
898+ /* filterMap=*/ {h, w},
899+ /* outputMap=*/ {N, H, W, C}})
886900 .matchBody ();
887901}
888902
@@ -900,7 +914,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
900914 " expected op to implement ConvolutionOpInterface" );
901915
902916 ConvMatcherBuilder m (op, /* spatialRank=*/ 2 , dilations, strides,
903- PoolingType::MAX_UNSIGNED );
917+ PoolingType::MaxUnsigned );
904918 AffineExpr N = m.dim (0 );
905919 AffineExpr H = m.dim (1 );
906920 AffineExpr W = m.dim (2 );
@@ -910,9 +924,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
910924
911925 return m.matchStride (/* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 , /* idx=*/ 0 )
912926 .matchStride (/* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 , /* idx=*/ 1 )
913- .expectMaps ({/* inputMap=*/ {N, m.strided (H, h, 0 ), m.strided (W, w, 1 ), C},
914- /* filterMap=*/ {h, w},
915- /* outputMap=*/ {N, H, W, C}})
927+ .matchMaps ({/* inputMap=*/ {N, m.strided (H, h, 0 ), m.strided (W, w, 1 ), C},
928+ /* filterMap=*/ {h, w},
929+ /* outputMap=*/ {N, H, W, C}})
916930 .matchBody ();
917931}
918932
@@ -930,7 +944,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
930944 " expected op to implement ConvolutionOpInterface" );
931945
932946 ConvMatcherBuilder m (op, /* spatialRank=*/ 2 , dilations, strides,
933- PoolingType::MIN_UNSIGNED );
947+ PoolingType::MinUnsigned );
934948 AffineExpr N = m.dim (0 );
935949 AffineExpr H = m.dim (1 );
936950 AffineExpr W = m.dim (2 );
@@ -940,9 +954,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
940954
941955 return m.matchStride (/* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 , /* idx=*/ 0 )
942956 .matchStride (/* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 , /* idx=*/ 1 )
943- .expectMaps ({/* inputMap=*/ {N, m.strided (H, h, 0 ), m.strided (W, w, 1 ), C},
944- /* filterMap=*/ {h, w},
945- /* outputMap=*/ {N, H, W, C}})
957+ .matchMaps ({/* inputMap=*/ {N, m.strided (H, h, 0 ), m.strided (W, w, 1 ), C},
958+ /* filterMap=*/ {h, w},
959+ /* outputMap=*/ {N, H, W, C}})
946960 .matchBody ();
947961}
948962
0 commit comments