Skip to content

Commit a08247c

Browse files
Some more APIs
1 parent bafdb41 commit a08247c

File tree

3 files changed

+306
-90
lines changed

3 files changed

+306
-90
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,20 @@ bool isaConv1DNcwFcwOp(LinalgOp op);
120120
bool isaDepthwiseConv1DNcwCwOp(LinalgOp op);
121121
bool isaDepthwiseConv1DNwcWcOp(LinalgOp op);
122122
bool isaDepthwiseConv1DNwcWcmOp(LinalgOp op);
123+
bool isaConv2DOp(LinalgOp op);
124+
bool isaConv2DNhwcFhwcOp(LinalgOp op);
125+
bool isaConv2DNhwcHwcfOp(LinalgOp op);
126+
bool isaConv2DNchwFchwOp(LinalgOp op);
127+
bool isaConv2DNhwcFhwcQOp(LinalgOp op);
128+
bool isaConv2DNchwFchwQOp(LinalgOp op);
129+
bool isaConv2DNgchwFgchwOp(LinalgOp op);
130+
bool isaConv2DNgchwGfchwOp(LinalgOp op);
131+
bool isaConv2DNgchwGfchwQOp(LinalgOp op);
132+
bool isaConv2DNhwgcGfhwcOp(LinalgOp op);
133+
bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
134+
bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
135+
bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
136+
bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
123137

124138
//===----------------------------------------------------------------------===//
125139
// Fusion / Tiling utilities

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 15 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -361,10 +361,10 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
361361
// #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
362362
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
363363
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
364-
if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
365-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1))
364+
if (isaConv2DOp(genericOp))
366365
return "linalg.conv_2d";
367366

367+
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
368368
Block *body = genericOp.getBlock();
369369
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
370370
Value yieldVal = yieldOp.getOperand(0);
@@ -432,21 +432,13 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
432432
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
433433
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
434434
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
435-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
436-
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
437-
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
438-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
439-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3))
435+
if (isaDepthwiseConv2DNchwChwOp(genericOp))
440436
return "linalg.depthwise_conv_2d_nchw_chw";
441437
// depthwise_conv_2d_nhwc_hwc
442438
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
443439
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
444440
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
445-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
446-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
447-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
448-
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
449-
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3))
441+
if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
450442
return "linalg.depthwise_conv_2d_nhwc_hwc";
451443
// conv_3d
452444
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
@@ -511,90 +503,50 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
511503
static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
512504
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
513505
if (indexingMaps.size() < 3) return "";
514-
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
515506
// conv_2d_nhwc_fhwc
516507
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
517508
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
518509
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
519-
if (indexingMaps.size() == 3 &&
520-
matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
521-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
522-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
523-
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
524-
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3))
510+
if (isaConv2DNhwcFhwcOp(genericOp))
525511
return "linalg.conv_2d_nhwc_fhwc";
526512
// conv_2d_nhwc_hwcf
527513
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
528514
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
529515
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
530-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
531-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
532-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
533-
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
534-
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3))
516+
if (isaConv2DNhwcHwcfOp(genericOp))
535517
return "linalg.conv_2d_nhwc_hwcf";
536518
// conv_2d_nchw_fchw
537519
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
538520
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
539521
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
540-
if (indexingMaps.size() == 3 &&
541-
matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
542-
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
543-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
544-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
545-
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
522+
if (isaConv2DNchwFchwOp(genericOp))
546523
return "linalg.conv_2d_nchw_fchw";
547524
// conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
548525
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
549526
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
550527
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
551528
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
552-
if (indexingMaps.size() == 5 &&
553-
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
554-
matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
555-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
556-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
557-
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 3) &&
558-
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 3))
529+
if (isaConv2DNhwcFhwcQOp(genericOp))
559530
return "linalg.conv_2d_nhwc_fhwc_q";
560531
// conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps)
561532
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
562533
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
563534
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
564535
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
565-
if (indexingMaps.size() == 5 &&
566-
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
567-
matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
568-
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
569-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
570-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
571-
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
536+
if (isaConv2DNchwFchwQOp(genericOp))
572537
return "linalg.conv_2d_nchw_fchw_q";
573538
// depthwise_conv_2d_nhwc_hwcm
574539
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
575540
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
576541
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
577-
if (indexingMaps.size() == 3 &&
578-
matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
579-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
580-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
581-
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
582-
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
583-
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4))
542+
if (isaDepthwiseConv2DNhwcHwcmOp(genericOp))
584543
return "linalg.depthwise_conv_2d_nhwc_hwcm";
585544
// depthwise_conv_2d_nhwc_hwcm_q
586545
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
587546
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
588547
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
589548
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
590-
if (indexingMaps.size() == 5 &&
591-
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
592-
matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
593-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
594-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
595-
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
596-
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
597-
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4))
549+
if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp))
598550
return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
599551
return "";
600552
}
@@ -607,53 +559,26 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
607559
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
608560
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
609561
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
610-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
611-
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
612-
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
613-
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
614-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
615-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
616-
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 2))
562+
if (isaConv2DNgchwFgchwOp(genericOp))
617563
return "linalg.conv_2d_ngchw_fgchw";
618564
// conv_2d_ngchw_gfchw
619565
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
620566
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
621567
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
622-
if (indexingMaps.size() == 3 &&
623-
matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
624-
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
625-
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
626-
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
627-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
628-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
629-
matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2))
568+
if (isaConv2DNgchwGfchwOp(genericOp))
630569
return "linalg.conv_2d_ngchw_gfchw";
631570
// conv_2d_ngchw_gfchw_q
632571
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
633572
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
634573
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
635574
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
636-
if (indexingMaps.size() == 5 &&
637-
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
638-
matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
639-
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
640-
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
641-
matchConvDimExprPattern(indexingMaps, iIndex, 2, fIndex, 2) &&
642-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
643-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
644-
matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2))
575+
if (isaConv2DNgchwGfchwQOp(genericOp))
645576
return "linalg.conv_2d_ngchw_gfchw_q";
646577
// conv_2d_nhwgc_gfhwc
647578
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
648579
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
649580
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
650-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
651-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
652-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
653-
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
654-
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
655-
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
656-
matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4))
581+
if (isaConv2DNhwgcGfhwcOp(genericOp))
657582
return "linalg.conv_2d_nhwgc_gfhwc";
658583
// depthwise_conv_3d_ncdhw_cdhw
659584
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>

0 commit comments

Comments
 (0)