Skip to content

Commit 62ece64

Browse files
Abhishek-Varmakcloudy0717
authored andcommitted
[NFC][Linalg] Follow-up on ConvMatchBuilder (llvm#170080)
-- This commit addresses [follow-up review comments on 169704](llvm#169704 (review)). -- Contains NFC nit/minor changes. Signed-off-by: Abhishek Varma <[email protected]>
1 parent 0ee4765 commit 62ece64

File tree

1 file changed

+85
-71
lines changed

1 file changed

+85
-71
lines changed

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 85 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -430,19 +430,33 @@ static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
430430
})));
431431
}
432432

433-
/// Enum of all kinds of Pooling Op's type.
434-
enum PoolingType {
435-
NONE,
436-
MAX_SIGNED,
437-
MAX_UNSIGNED,
438-
MIN_SIGNED,
439-
MIN_UNSIGNED,
440-
SUM
433+
/// Enum representing pooling operation types used by ConvMatcherBuilder.
434+
enum class PoolingType {
435+
None,
436+
MaxSigned,
437+
MaxUnsigned,
438+
MinSigned,
439+
MinUnsigned,
440+
Sum
441441
};
442442

443443
/// Helper class for building convolution op matchers with minimal boilerplate.
444444
/// Reduces repetitive code across Conv1D/2D/3D and Depthwise variants as well
445445
/// as Pooling ops.
446+
///
447+
/// Usage: Create an instance with the op, spatial rank, and output pointers for
448+
/// extracted dilations/strides. Then chain matchStride() calls for each spatial
449+
/// dimension, followed by matchMaps() to verify indexing maps, and finally
450+
/// matchBody() to verify the operation body pattern.
451+
///
452+
/// The `matched` flag starts as `true` and is set to `false` if any match step
453+
/// fails. This allows chaining multiple match calls; once any match fails, all
454+
/// subsequent calls become no-ops and the final result is `false`.
455+
///
456+
/// The `dilations` and `strides` pointers are output parameters that get
457+
/// populated with the extracted dilation and stride values from the operation's
458+
/// indexing maps during matchStride() calls. These values are initially set to
459+
/// 1 for each spatial dimension and updated as patterns are matched.
446460
class ConvMatcherBuilder {
447461
LinalgOp op;
448462
MLIRContext *ctx;
@@ -454,7 +468,7 @@ class ConvMatcherBuilder {
454468
public:
455469
ConvMatcherBuilder(LinalgOp op, unsigned spatialRank, SmallVector<int64_t> *d,
456470
SmallVector<int64_t> *s,
457-
PoolingType poolingType = PoolingType::NONE)
471+
PoolingType poolingType = PoolingType::None)
458472
: op(op), ctx(op->getContext()), dilations(d), strides(s),
459473
indexingMaps(op.getIndexingMaps()), poolingType(poolingType) {
460474
*dilations = SmallVector<int64_t>(spatialRank, 1);
@@ -474,16 +488,16 @@ class ConvMatcherBuilder {
474488
ConvMatcherBuilder &matchStride(unsigned iDim, unsigned fDim, unsigned oDim,
475489
unsigned idx) {
476490
if (matched) {
477-
matched = matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
478-
(*dilations)[idx], (*strides)[idx]);
491+
matched &= matchConvDimAddExprPattern(indexingMaps, iDim, fDim, oDim,
492+
(*dilations)[idx], (*strides)[idx]);
479493
}
480494
return *this;
481495
}
482496

483497
/// Match expected indexing maps layout. Returns *this for method chaining.
484-
ConvMatcherBuilder &expectMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
498+
ConvMatcherBuilder &matchMaps(ArrayRef<ArrayRef<AffineExpr>> maps) {
485499
if (matched)
486-
matched = convLayoutMatches(maps, indexingMaps, ctx);
500+
matched &= convLayoutMatches(maps, indexingMaps, ctx);
487501
return *this;
488502
}
489503

@@ -494,17 +508,17 @@ class ConvMatcherBuilder {
494508
Block *body = op.getBlock();
495509
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
496510
switch (poolingType) {
497-
case PoolingType::NONE:
511+
case PoolingType::None:
498512
return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body);
499-
case PoolingType::MAX_SIGNED:
513+
case PoolingType::MaxSigned:
500514
return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body);
501-
case PoolingType::MAX_UNSIGNED:
515+
case PoolingType::MaxUnsigned:
502516
return bodyMatcherForMaxUnsignedPoolOps(yieldOp.getOperand(0), body);
503-
case PoolingType::MIN_SIGNED:
517+
case PoolingType::MinSigned:
504518
return bodyMatcherForMinSignedPoolOps(yieldOp.getOperand(0), body);
505-
case PoolingType::MIN_UNSIGNED:
519+
case PoolingType::MinUnsigned:
506520
return bodyMatcherForMinUnsignedPoolOps(yieldOp.getOperand(0), body);
507-
case PoolingType::SUM:
521+
case PoolingType::Sum:
508522
return bodyMatcherForSumPoolOps(yieldOp.getOperand(0), body);
509523
}
510524
return false;
@@ -533,9 +547,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DOp>(LinalgOp op,
533547
AffineExpr w = m.dim(1);
534548

535549
return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
536-
.expectMaps({/*inputMap=*/{m.strided(W, w, 0)},
537-
/*filterMap=*/{w},
538-
/*outputMap=*/{W}})
550+
.matchMaps({/*inputMap=*/{m.strided(W, w, 0)},
551+
/*filterMap=*/{w},
552+
/*outputMap=*/{W}})
539553
.matchBody();
540554
}
541555

@@ -560,9 +574,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNwcWcfOp>(
560574
AffineExpr c = m.dim(4);
561575

562576
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
563-
.expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
564-
/*filterMap=*/{w, c, F},
565-
/*outputMap=*/{N, W, F}})
577+
.matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), c},
578+
/*filterMap=*/{w, c, F},
579+
/*outputMap=*/{N, W, F}})
566580
.matchBody();
567581
}
568582

@@ -587,9 +601,9 @@ bool isaConvolutionOpOfType<linalg::Conv1DNcwFcwOp>(
587601
AffineExpr w = m.dim(4);
588602

589603
return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0)
590-
.expectMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
591-
/*filterMap=*/{F, c, w},
592-
/*outputMap=*/{N, F, W}})
604+
.matchMaps({/*inputMap=*/{N, c, m.strided(W, w, 0)},
605+
/*filterMap=*/{F, c, w},
606+
/*outputMap=*/{N, F, W}})
593607
.matchBody();
594608
}
595609

@@ -614,9 +628,9 @@ bool isaConvolutionOpOfType<linalg::Conv2DOp>(LinalgOp op,
614628

615629
return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
616630
.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
617-
.expectMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
618-
/*filterMap=*/{h, w},
619-
/*outputMap=*/{H, W}})
631+
.matchMaps({/*inputMap=*/{m.strided(H, h, 0), m.strided(W, w, 1)},
632+
/*filterMap=*/{h, w},
633+
/*outputMap=*/{H, W}})
620634
.matchBody();
621635
}
622636

@@ -644,10 +658,10 @@ bool isaConvolutionOpOfType<linalg::Conv3DOp>(LinalgOp op,
644658
return m.matchStride(/*iDim=*/0, /*fDim=*/0, /*oDim=*/0, /*idx=*/0)
645659
.matchStride(/*iDim=*/1, /*fDim=*/1, /*oDim=*/1, /*idx=*/1)
646660
.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/2)
647-
.expectMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
648-
m.strided(W, w, 2)},
649-
/*filterMap=*/{d, h, w},
650-
/*outputMap=*/{D, H, W}})
661+
.matchMaps({/*inputMap=*/{m.strided(D, d, 0), m.strided(H, h, 1),
662+
m.strided(W, w, 2)},
663+
/*filterMap=*/{d, h, w},
664+
/*outputMap=*/{D, H, W}})
651665
.matchBody();
652666
}
653667

@@ -671,9 +685,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNcwCwOp>(
671685
AffineExpr w = m.dim(3);
672686

673687
return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
674-
.expectMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
675-
/*filterMap=*/{C, w},
676-
/*outputMap=*/{N, C, W}})
688+
.matchMaps({/*inputMap=*/{N, C, m.strided(W, w, 0)},
689+
/*filterMap=*/{C, w},
690+
/*outputMap=*/{N, C, W}})
677691
.matchBody();
678692
}
679693

@@ -697,9 +711,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
697711
AffineExpr w = m.dim(3);
698712

699713
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
700-
.expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
701-
/*filterMap=*/{w, C},
702-
/*outputMap=*/{N, W, C}})
714+
.matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
715+
/*filterMap=*/{w, C},
716+
/*outputMap=*/{N, W, C}})
703717
.matchBody();
704718
}
705719

@@ -724,9 +738,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcmOp>(
724738
AffineExpr w = m.dim(4);
725739

726740
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
727-
.expectMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
728-
/*filterMap=*/{w, C, CM},
729-
/*outputMap=*/{N, W, C, CM}})
741+
.matchMaps({/*inputMap=*/{N, m.strided(W, w, 0), C},
742+
/*filterMap=*/{w, C, CM},
743+
/*outputMap=*/{N, W, C, CM}})
730744
.matchBody();
731745
}
732746

@@ -753,9 +767,9 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
753767

754768
return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0)
755769
.matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1)
756-
.expectMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
757-
/*filterMap=*/{C, h, w},
758-
/*outputMap=*/{N, C, H, W}})
770+
.matchMaps({/*inputMap=*/{N, C, m.strided(H, h, 0), m.strided(W, w, 1)},
771+
/*filterMap=*/{C, h, w},
772+
/*outputMap=*/{N, C, H, W}})
759773
.matchBody();
760774
}
761775

@@ -789,10 +803,10 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
789803
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
790804
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
791805
.matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2)
792-
.expectMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
793-
m.strided(W, w, 2), C},
794-
/*filterMap=*/{d, h, w, C, CM},
795-
/*outputMap=*/{N, D, H, W, C, CM}})
806+
.matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1),
807+
m.strided(W, w, 2), C},
808+
/*filterMap=*/{d, h, w, C, CM},
809+
/*outputMap=*/{N, D, H, W, C, CM}})
796810
.matchBody();
797811
}
798812

@@ -810,7 +824,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
810824
"expected op to implement ConvolutionOpInterface");
811825

812826
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
813-
PoolingType::MAX_SIGNED);
827+
PoolingType::MaxSigned);
814828
AffineExpr N = m.dim(0);
815829
AffineExpr H = m.dim(1);
816830
AffineExpr W = m.dim(2);
@@ -820,9 +834,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
820834

821835
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
822836
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
823-
.expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
824-
/*filterMap=*/{h, w},
825-
/*outputMap=*/{N, H, W, C}})
837+
.matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
838+
/*filterMap=*/{h, w},
839+
/*outputMap=*/{N, H, W, C}})
826840
.matchBody();
827841
}
828842

@@ -840,7 +854,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
840854
"expected op to implement ConvolutionOpInterface");
841855

842856
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
843-
PoolingType::MIN_SIGNED);
857+
PoolingType::MinSigned);
844858
AffineExpr N = m.dim(0);
845859
AffineExpr H = m.dim(1);
846860
AffineExpr W = m.dim(2);
@@ -850,9 +864,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
850864

851865
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
852866
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
853-
.expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
854-
/*filterMap=*/{h, w},
855-
/*outputMap=*/{N, H, W, C}})
867+
.matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
868+
/*filterMap=*/{h, w},
869+
/*outputMap=*/{N, H, W, C}})
856870
.matchBody();
857871
}
858872

@@ -870,7 +884,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
870884
"expected op to implement ConvolutionOpInterface");
871885

872886
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
873-
PoolingType::SUM);
887+
PoolingType::Sum);
874888
AffineExpr N = m.dim(0);
875889
AffineExpr H = m.dim(1);
876890
AffineExpr W = m.dim(2);
@@ -880,9 +894,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
880894

881895
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
882896
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
883-
.expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
884-
/*filterMap=*/{h, w},
885-
/*outputMap=*/{N, H, W, C}})
897+
.matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
898+
/*filterMap=*/{h, w},
899+
/*outputMap=*/{N, H, W, C}})
886900
.matchBody();
887901
}
888902

@@ -900,7 +914,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
900914
"expected op to implement ConvolutionOpInterface");
901915

902916
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
903-
PoolingType::MAX_UNSIGNED);
917+
PoolingType::MaxUnsigned);
904918
AffineExpr N = m.dim(0);
905919
AffineExpr H = m.dim(1);
906920
AffineExpr W = m.dim(2);
@@ -910,9 +924,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
910924

911925
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
912926
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
913-
.expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
914-
/*filterMap=*/{h, w},
915-
/*outputMap=*/{N, H, W, C}})
927+
.matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
928+
/*filterMap=*/{h, w},
929+
/*outputMap=*/{N, H, W, C}})
916930
.matchBody();
917931
}
918932

@@ -930,7 +944,7 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
930944
"expected op to implement ConvolutionOpInterface");
931945

932946
ConvMatcherBuilder m(op, /*spatialRank=*/2, dilations, strides,
933-
PoolingType::MIN_UNSIGNED);
947+
PoolingType::MinUnsigned);
934948
AffineExpr N = m.dim(0);
935949
AffineExpr H = m.dim(1);
936950
AffineExpr W = m.dim(2);
@@ -940,9 +954,9 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
940954

941955
return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0)
942956
.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1)
943-
.expectMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
944-
/*filterMap=*/{h, w},
945-
/*outputMap=*/{N, H, W, C}})
957+
.matchMaps({/*inputMap=*/{N, m.strided(H, h, 0), m.strided(W, w, 1), C},
958+
/*filterMap=*/{h, w},
959+
/*outputMap=*/{N, H, W, C}})
946960
.matchBody();
947961
}
948962

0 commit comments

Comments
 (0)