Skip to content

Commit 8c828fd

Browse files
Change the way you compare maps
1 parent e811d48 commit 8c828fd

File tree

1 file changed

+153
-103
lines changed

1 file changed

+153
-103
lines changed

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 153 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,20 @@ static bool verifyConvIndexingMapSizes(ArrayAttr indexingMaps,
442442
// Matchers for specific convolution operation.
443443
// ---------------------------------------------
444444

445+
/// Returns true if the given indexing maps matches with the expected indexing
446+
/// maps.
447+
static bool convLayoutMatches(ArrayRef<ArrayRef<AffineExpr>> mapListExpected,
448+
ArrayAttr indexingMaps, MLIRContext *context) {
449+
SmallVector<AffineMap, 4> expectedIndexingMaps =
450+
AffineMap::inferFromExprList(mapListExpected, context);
451+
return indexingMaps ==
452+
ArrayAttr::get(
453+
context, llvm::to_vector<4>(llvm::map_range(
454+
expectedIndexingMaps, [&](AffineMap m) -> Attribute {
455+
return AffineMapAttr::get(m);
456+
})));
457+
}
458+
445459
// #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)>
446460
// #filterMap = affine_map<(N, W, C, w) -> (w, C)>
447461
// #outputMap = affine_map<(N, W, C, w) -> (N, W, C)>
@@ -459,25 +473,25 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv1DNwcWcOp>(
459473
if (!verifyConvIndexingMapSizes(indexingMaps, {3, 2, 3}))
460474
return false;
461475

462-
unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
463-
464476
*dilations = SmallVector<int64_t>(1, 1);
465477
*strides = SmallVector<int64_t>(1, 1);
466-
// Match: N
467-
if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
468-
getAffineMapDim(indexingMaps, outputMapIdx, 0))
469-
return false;
470-
// Match: C
471-
if (getAffineMapDim(indexingMaps, inputMapIdx, 2) !=
472-
getAffineMapDim(indexingMaps, filterMapIdx, 1))
473-
return false;
474-
if (getAffineMapDim(indexingMaps, inputMapIdx, 2) !=
475-
getAffineMapDim(indexingMaps, outputMapIdx, 2))
476-
return false;
477-
// Match: W + w
478+
MLIRContext *context = op->getContext();
479+
AffineExpr N = getAffineDimExpr(0, context);
480+
AffineExpr W = getAffineDimExpr(1, context);
481+
AffineExpr C = getAffineDimExpr(2, context);
482+
AffineExpr w = getAffineDimExpr(3, context);
483+
// First fetch dilations/strides :-
484+
// Match: W * stride + w * dilation
478485
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
479486
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
480487
return false;
488+
// Match expected indexing maps
489+
if (!convLayoutMatches(
490+
{/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C},
491+
/*filterMap=*/{w, C},
492+
/*outputMap=*/{N, W, C}},
493+
indexingMaps, context))
494+
return false;
481495
// Match body
482496
Block *body = op.getBlock();
483497
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
@@ -504,29 +518,32 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv2DNchwChwOp>(
504518
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 3, 4}))
505519
return false;
506520

507-
unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
508-
509521
*dilations = SmallVector<int64_t>(2, 1);
510522
*strides = SmallVector<int64_t>(2, 1);
511-
// Match: N
512-
if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
513-
getAffineMapDim(indexingMaps, outputMapIdx, 0))
514-
return false;
515-
// Match: C
516-
if (getAffineMapDim(indexingMaps, inputMapIdx, 1) !=
517-
getAffineMapDim(indexingMaps, filterMapIdx, 0))
518-
return false;
519-
if (getAffineMapDim(indexingMaps, inputMapIdx, 1) !=
520-
getAffineMapDim(indexingMaps, outputMapIdx, 1))
521-
return false;
522-
// Match: H + h
523+
MLIRContext *context = op->getContext();
524+
AffineExpr N = getAffineDimExpr(0, context);
525+
AffineExpr H = getAffineDimExpr(1, context);
526+
AffineExpr W = getAffineDimExpr(2, context);
527+
AffineExpr C = getAffineDimExpr(3, context);
528+
AffineExpr h = getAffineDimExpr(4, context);
529+
AffineExpr w = getAffineDimExpr(5, context);
530+
// First fetch dilations/strides :-
531+
// Match: H * stride + h * dilation
523532
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
524533
/*oDim=*/2, (*dilations)[0], (*strides)[0]))
525534
return false;
526-
// Match: W + w
535+
// Match: W * stride + w * dilation
527536
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
528537
/*oDim=*/3, (*dilations)[1], (*strides)[1]))
529538
return false;
539+
// Match expected indexing maps
540+
if (!convLayoutMatches(
541+
{/*inputMap=*/{N, C, H * (*strides)[0] + h * (*dilations)[0],
542+
W * (*strides)[1] + w * (*dilations)[1]},
543+
/*filterMap=*/{C, h, w},
544+
/*outputMap=*/{N, C, H, W}},
545+
indexingMaps, context))
546+
return false;
530547
// Match body
531548
Block *body = op.getBlock();
532549
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
@@ -556,36 +573,39 @@ bool isaConvolutionOpOfType<linalg::DepthwiseConv3DNdhwcDhwcmOp>(
556573
if (!verifyConvIndexingMapSizes(indexingMaps, {5, 5, 6}))
557574
return false;
558575

559-
unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = 2;
560-
561576
*dilations = SmallVector<int64_t>(3, 1);
562577
*strides = SmallVector<int64_t>(3, 1);
563-
// Match: N
564-
if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
565-
getAffineMapDim(indexingMaps, outputMapIdx, 0))
566-
return false;
567-
// Match: D + d
578+
MLIRContext *context = op->getContext();
579+
AffineExpr N = getAffineDimExpr(0, context);
580+
AffineExpr D = getAffineDimExpr(1, context);
581+
AffineExpr H = getAffineDimExpr(2, context);
582+
AffineExpr W = getAffineDimExpr(3, context);
583+
AffineExpr CM = getAffineDimExpr(4, context);
584+
AffineExpr d = getAffineDimExpr(5, context);
585+
AffineExpr h = getAffineDimExpr(6, context);
586+
AffineExpr w = getAffineDimExpr(7, context);
587+
AffineExpr C = getAffineDimExpr(8, context);
588+
// First fetch dilations/strides :-
589+
// Match: D * stride + d * dilation
568590
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
569591
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
570592
return false;
571-
// Match: H + h
593+
// Match: H * stride + h * dilation
572594
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
573595
/*oDim=*/2, (*dilations)[1], (*strides)[1]))
574596
return false;
575-
// Match: W + w
597+
// Match: W * stride + w * dilation
576598
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2,
577599
/*oDim=*/3, (*dilations)[2], (*strides)[2]))
578600
return false;
579-
// Match: C
580-
if (getAffineMapDim(indexingMaps, inputMapIdx, 4) !=
581-
getAffineMapDim(indexingMaps, filterMapIdx, 3))
582-
return false;
583-
if (getAffineMapDim(indexingMaps, inputMapIdx, 4) !=
584-
getAffineMapDim(indexingMaps, outputMapIdx, 4))
585-
return false;
586-
// Match: CM
587-
if (getAffineMapDim(indexingMaps, filterMapIdx, 4) !=
588-
getAffineMapDim(indexingMaps, outputMapIdx, 5))
601+
// Match expected indexing maps
602+
if (!convLayoutMatches(
603+
{/*inputMap=*/{N, D * (*strides)[0] + d * (*dilations)[0],
604+
H * (*strides)[1] + h * (*dilations)[1],
605+
W * (*strides)[2] + w * (*dilations)[2], C},
606+
/*filterMap=*/{d, h, w, C, CM},
607+
/*outputMap=*/{N, D, H, W, C, CM}},
608+
indexingMaps, context))
589609
return false;
590610
// Match body
591611
Block *body = op.getBlock();
@@ -613,25 +633,31 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxOp>(
613633
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
614634
return false;
615635

616-
unsigned inputMapIdx = 0, outputMapIdx = 2;
617-
618636
*dilations = SmallVector<int64_t>(2, 1);
619637
*strides = SmallVector<int64_t>(2, 1);
620-
// Match: N
621-
if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
622-
getAffineMapDim(indexingMaps, outputMapIdx, 0))
623-
return false;
624-
// Match: H + h
638+
MLIRContext *context = op->getContext();
639+
AffineExpr N = getAffineDimExpr(0, context);
640+
AffineExpr H = getAffineDimExpr(1, context);
641+
AffineExpr W = getAffineDimExpr(2, context);
642+
AffineExpr C = getAffineDimExpr(3, context);
643+
AffineExpr h = getAffineDimExpr(4, context);
644+
AffineExpr w = getAffineDimExpr(5, context);
645+
// First fetch dilations/strides :-
646+
// Match: H * stride + h * dilation
625647
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
626648
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
627649
return false;
628-
// Match: W + w
650+
// Match: W * stride + w * dilation
629651
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
630652
/*oDim=*/2, (*dilations)[1], (*strides)[1]))
631653
return false;
632-
// Match: C
633-
if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
634-
getAffineMapDim(indexingMaps, outputMapIdx, 3))
654+
// Match expected indexing maps
655+
if (!convLayoutMatches(
656+
{/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
657+
W * (*strides)[1] + w * (*dilations)[1], C},
658+
/*filterMap=*/{h, w},
659+
/*outputMap=*/{N, H, W, C}},
660+
indexingMaps, context))
635661
return false;
636662
// Match body
637663
Block *body = op.getBlock();
@@ -659,25 +685,31 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinOp>(
659685
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
660686
return false;
661687

662-
unsigned inputMapIdx = 0, outputMapIdx = 2;
663-
664688
*dilations = SmallVector<int64_t>(2, 1);
665689
*strides = SmallVector<int64_t>(2, 1);
666-
// Match: N
667-
if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
668-
getAffineMapDim(indexingMaps, outputMapIdx, 0))
669-
return false;
670-
// Match: H + h
690+
MLIRContext *context = op->getContext();
691+
AffineExpr N = getAffineDimExpr(0, context);
692+
AffineExpr H = getAffineDimExpr(1, context);
693+
AffineExpr W = getAffineDimExpr(2, context);
694+
AffineExpr C = getAffineDimExpr(3, context);
695+
AffineExpr h = getAffineDimExpr(4, context);
696+
AffineExpr w = getAffineDimExpr(5, context);
697+
// First fetch dilations/strides :-
698+
// Match: H * stride + h * dilation
671699
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
672700
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
673701
return false;
674-
// Match: W + w
702+
// Match: W * stride + w * dilation
675703
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
676704
/*oDim=*/2, (*dilations)[1], (*strides)[1]))
677705
return false;
678-
// Match: C
679-
if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
680-
getAffineMapDim(indexingMaps, outputMapIdx, 3))
706+
// Match expected indexing maps
707+
if (!convLayoutMatches(
708+
{/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
709+
W * (*strides)[1] + w * (*dilations)[1], C},
710+
/*filterMap=*/{h, w},
711+
/*outputMap=*/{N, H, W, C}},
712+
indexingMaps, context))
681713
return false;
682714
// Match body
683715
Block *body = op.getBlock();
@@ -705,25 +737,31 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcSumOp>(
705737
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
706738
return false;
707739

708-
unsigned inputMapIdx = 0, outputMapIdx = 2;
709-
710740
*dilations = SmallVector<int64_t>(2, 1);
711741
*strides = SmallVector<int64_t>(2, 1);
712-
// Match: N
713-
if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
714-
getAffineMapDim(indexingMaps, outputMapIdx, 0))
715-
return false;
716-
// Match: H + h
742+
MLIRContext *context = op->getContext();
743+
AffineExpr N = getAffineDimExpr(0, context);
744+
AffineExpr H = getAffineDimExpr(1, context);
745+
AffineExpr W = getAffineDimExpr(2, context);
746+
AffineExpr C = getAffineDimExpr(3, context);
747+
AffineExpr h = getAffineDimExpr(4, context);
748+
AffineExpr w = getAffineDimExpr(5, context);
749+
// First fetch dilations/strides :-
750+
// Match: H * stride + h * dilation
717751
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
718752
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
719753
return false;
720-
// Match: W + w
754+
// Match: W * stride + w * dilation
721755
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
722756
/*oDim=*/2, (*dilations)[1], (*strides)[1]))
723757
return false;
724-
// Match: C
725-
if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
726-
getAffineMapDim(indexingMaps, outputMapIdx, 3))
758+
// Match expected indexing maps
759+
if (!convLayoutMatches(
760+
{/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
761+
W * (*strides)[1] + w * (*dilations)[1], C},
762+
/*filterMap=*/{h, w},
763+
/*outputMap=*/{N, H, W, C}},
764+
indexingMaps, context))
727765
return false;
728766
// Match body
729767
Block *body = op.getBlock();
@@ -751,25 +789,31 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMaxUnsignedOp>(
751789
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
752790
return false;
753791

754-
unsigned inputMapIdx = 0, outputMapIdx = 2;
755-
756792
*dilations = SmallVector<int64_t>(2, 1);
757793
*strides = SmallVector<int64_t>(2, 1);
758-
// Match: N
759-
if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
760-
getAffineMapDim(indexingMaps, outputMapIdx, 0))
761-
return false;
762-
// Match: H + h
794+
MLIRContext *context = op->getContext();
795+
AffineExpr N = getAffineDimExpr(0, context);
796+
AffineExpr H = getAffineDimExpr(1, context);
797+
AffineExpr W = getAffineDimExpr(2, context);
798+
AffineExpr C = getAffineDimExpr(3, context);
799+
AffineExpr h = getAffineDimExpr(4, context);
800+
AffineExpr w = getAffineDimExpr(5, context);
801+
// First fetch dilations/strides :-
802+
// Match: H * stride + h * dilation
763803
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
764804
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
765805
return false;
766-
// Match: W + w
806+
// Match: W * stride + w * dilation
767807
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
768808
/*oDim=*/2, (*dilations)[1], (*strides)[1]))
769809
return false;
770-
// Match: C
771-
if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
772-
getAffineMapDim(indexingMaps, outputMapIdx, 3))
810+
// Match expected indexing maps
811+
if (!convLayoutMatches(
812+
{/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
813+
W * (*strides)[1] + w * (*dilations)[1], C},
814+
/*filterMap=*/{h, w},
815+
/*outputMap=*/{N, H, W, C}},
816+
indexingMaps, context))
773817
return false;
774818
// Match body
775819
Block *body = op.getBlock();
@@ -797,25 +841,31 @@ bool isaConvolutionOpOfType<linalg::PoolingNhwcMinUnsignedOp>(
797841
if (!verifyConvIndexingMapSizes(indexingMaps, {4, 2, 4}))
798842
return false;
799843

800-
unsigned inputMapIdx = 0, outputMapIdx = 2;
801-
802844
*dilations = SmallVector<int64_t>(2, 1);
803845
*strides = SmallVector<int64_t>(2, 1);
804-
// Match: N
805-
if (getAffineMapDim(indexingMaps, inputMapIdx, 0) !=
806-
getAffineMapDim(indexingMaps, outputMapIdx, 0))
807-
return false;
808-
// Match: H + h
846+
MLIRContext *context = op->getContext();
847+
AffineExpr N = getAffineDimExpr(0, context);
848+
AffineExpr H = getAffineDimExpr(1, context);
849+
AffineExpr W = getAffineDimExpr(2, context);
850+
AffineExpr C = getAffineDimExpr(3, context);
851+
AffineExpr h = getAffineDimExpr(4, context);
852+
AffineExpr w = getAffineDimExpr(5, context);
853+
// First fetch dilations/strides :-
854+
// Match: H * stride + h * dilation
809855
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0,
810856
/*oDim=*/1, (*dilations)[0], (*strides)[0]))
811857
return false;
812-
// Match: W + w
858+
// Match: W * stride + w * dilation
813859
if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1,
814860
/*oDim=*/2, (*dilations)[1], (*strides)[1]))
815861
return false;
816-
// Match: C
817-
if (getAffineMapDim(indexingMaps, inputMapIdx, 3) !=
818-
getAffineMapDim(indexingMaps, outputMapIdx, 3))
862+
// Match expected indexing maps
863+
if (!convLayoutMatches(
864+
{/*inputMap=*/{N, H * (*strides)[0] + h * (*dilations)[0],
865+
W * (*strides)[1] + w * (*dilations)[1], C},
866+
/*filterMap=*/{h, w},
867+
/*outputMap=*/{N, H, W, C}},
868+
indexingMaps, context))
819869
return false;
820870
// Match body
821871
Block *body = op.getBlock();

0 commit comments

Comments
 (0)