Skip to content

Commit 87b91ee

Browse files
Add 3D APIs
1 parent 053d912 commit 87b91ee

File tree

3 files changed

+132
-61
lines changed

3 files changed

+132
-61
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
134134
bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
135135
bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
136136
bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
137+
bool isaConv3DOp(LinalgOp op);
138+
bool isaConv3DNcdhwFcdhwOp(LinalgOp op);
139+
bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
140+
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
141+
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
142+
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);
137143

138144
//===----------------------------------------------------------------------===//
139145
// Fusion / Tiling utilities

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 9 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -406,13 +406,7 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
406406
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
407407
if (indexingMaps.size() < 3) return "";
408408
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
409-
// conv_3d
410-
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
411-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
412-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
413-
if (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
414-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
415-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2))
409+
if (isaConv3DOp(genericOp))
416410
return "linalg.conv_3d";
417411

418412
Block *body = genericOp.getBlock();
@@ -493,29 +487,14 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
493487
return "linalg.conv_2d_ngchw_gfchw_q";
494488
if (isaConv2DNhwgcGfhwcOp(genericOp))
495489
return "linalg.conv_2d_nhwgc_gfhwc";
496-
// depthwise_conv_3d_ncdhw_cdhw
497-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
498-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
499-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
500-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
501-
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
502-
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
503-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
504-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
505-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4))
490+
if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp))
506491
return "linalg.depthwise_conv_3d_ncdhw_cdhw";
507-
// depthwise_conv_3d_ndhwc_dhwc
508-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
509-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
510-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
511-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
512-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
513-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
514-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
515-
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
516-
matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4))
492+
if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp))
517493
return "linalg.depthwise_conv_3d_ndhwc_dhwc";
518494

495+
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
496+
if (indexingMaps.size() < 3) return "";
497+
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
519498
Block *body = genericOp.getBlock();
520499
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
521500
Value yieldVal = yieldOp.getOperand(0);
@@ -541,42 +520,11 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
541520
}
542521

543522
static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
544-
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
545-
if (indexingMaps.size() < 3) return "";
546-
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
547-
// conv_3d_ncdhw_fcdhw
548-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
549-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
550-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
551-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
552-
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
553-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
554-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
555-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
556-
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1))
523+
if (isaConv3DNcdhwFcdhwOp(genericOp))
557524
return "linalg.conv_3d_ncdhw_fcdhw";
558-
// conv_3d_ndhwc_dhwcf
559-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
560-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
561-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
562-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
563-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
564-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
565-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
566-
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
567-
matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4))
525+
if (isaConv3DNdhwcDhwcfOp(genericOp))
568526
return "linalg.conv_3d_ndhwc_dhwcf";
569-
// depthwise_conv_3d_ndhwc_dhwcm
570-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
571-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
572-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
573-
if (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
574-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
575-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
576-
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
577-
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
578-
matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
579-
matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5))
527+
if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp))
580528
return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
581529
return "";
582530
}

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -736,6 +736,123 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
736736
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
737737
}
738738

739+
bool isaConv3DOp(LinalgOp op) {
740+
if (isa<linalg::Conv1DOp>(op)) return true;
741+
742+
if (!isaConvolutionOpInterface(op)) return false;
743+
744+
ArrayAttr indexingMaps = op.getIndexingMaps();
745+
if (!verifyConvIndexingMapSizes(indexingMaps, {3,3,3})) return false;
746+
747+
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
748+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
749+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
750+
return (matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, /*oDim=*/0) &&
751+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, /*oDim=*/1) &&
752+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2));
753+
}
754+
755+
bool isaConv3DNcdhwFcdhwOp(LinalgOp op) {
756+
if (isa<linalg::Conv1DOp>(op)) return true;
757+
758+
if (!isaConvolutionOpInterface(op)) return false;
759+
760+
ArrayAttr indexingMaps = op.getIndexingMaps();
761+
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
762+
763+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
764+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
765+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
766+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
767+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
768+
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 1) &&
769+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, /*oDim=*/2) &&
770+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/3, /*oDim=*/3) &&
771+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/4, /*oDim=*/4) &&
772+
matchConvDimExprPattern(indexingMaps, fIndex, 0, oIndex, 1));
773+
}
774+
775+
bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
776+
if (isa<linalg::Conv1DOp>(op)) return true;
777+
778+
if (!isaConvolutionOpInterface(op)) return false;
779+
780+
ArrayAttr indexingMaps = op.getIndexingMaps();
781+
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,5})) return false;
782+
783+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
784+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
785+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
786+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
787+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
788+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
789+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
790+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
791+
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
792+
matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
793+
}
794+
795+
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
796+
if (isa<linalg::Conv1DOp>(op)) return true;
797+
798+
if (!isaConvolutionOpInterface(op)) return false;
799+
800+
ArrayAttr indexingMaps = op.getIndexingMaps();
801+
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,6})) return false;
802+
803+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
804+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
805+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
806+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
807+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
808+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
809+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
810+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
811+
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
812+
matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4) &&
813+
matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 5));
814+
}
815+
816+
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op) {
817+
if (isa<linalg::Conv1DOp>(op)) return true;
818+
819+
if (!isaConvolutionOpInterface(op)) return false;
820+
821+
ArrayAttr indexingMaps = op.getIndexingMaps();
822+
if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
823+
824+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
825+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
826+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
827+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
828+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
829+
matchConvDimExprPattern(indexingMaps, iIndex, 1, fIndex, 0) &&
830+
matchConvDimExprPattern(indexingMaps, iIndex, 1, oIndex, 1) &&
831+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
832+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
833+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/4, /*fDim=*/3, /*oDim=*/4));
834+
}
835+
836+
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op) {
837+
if (isa<linalg::Conv1DOp>(op)) return true;
838+
839+
if (!isaConvolutionOpInterface(op)) return false;
840+
841+
ArrayAttr indexingMaps = op.getIndexingMaps();
842+
if (!verifyConvIndexingMapSizes(indexingMaps, {5,4,5})) return false;
843+
844+
unsigned iIndex = 0, fIndex = 1, oIndex = 2;
845+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
846+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
847+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
848+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
849+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
850+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
851+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
852+
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
853+
matchConvDimExprPattern(indexingMaps, iIndex, 4, oIndex, 4));
854+
}
855+
739856
Value makeComposedPadHighOp(OpBuilder &b, Location loc, RankedTensorType type,
740857
Value source, Value pad, bool nofold,
741858
ValueRange typeDynDims) {

0 commit comments

Comments
 (0)