Skip to content

Commit 5aeb371

Browse files
Missing ops
1 parent b06ba75 commit 5aeb371

File tree

3 files changed

+103
-0
lines changed

3 files changed

+103
-0
lines changed

mlir/include/mlir/Dialect/Linalg/Utils/Utils.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,19 @@ bool isaConv2DNhwcFhwcQOp(LinalgOp op);
128128
bool isaConv2DNchwFchwQOp(LinalgOp op);
129129
bool isaConv2DNgchwFgchwOp(LinalgOp op);
130130
bool isaConv2DNgchwGfchwOp(LinalgOp op);
131+
bool isaConv2DNhwcHwcfQOp(LinalgOp op);
132+
bool isaConv2DNhwgcGfhwcQOp(LinalgOp op);
131133
bool isaConv2DNgchwGfchwQOp(LinalgOp op);
132134
bool isaConv2DNhwgcGfhwcOp(LinalgOp op);
133135
bool isaDepthwiseConv2DNchwChwOp(LinalgOp op);
134136
bool isaDepthwiseConv2DNhwcHwcOp(LinalgOp op);
135137
bool isaDepthwiseConv2DNhwcHwcmOp(LinalgOp op);
138+
bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op);
136139
bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op);
137140
bool isaConv3DOp(LinalgOp op);
138141
bool isaConv3DNcdhwFcdhwOp(LinalgOp op);
139142
bool isaConv3DNdhwcDhwcfOp(LinalgOp op);
143+
bool isaConv3DNdhwcDhwcfQOp(LinalgOp op);
140144
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op);
141145
bool isaDepthwiseConv3DNcdhwCdhwOp(LinalgOp op);
142146
bool isaDepthwiseConv3DNdhwcDhwcOp(LinalgOp op);

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,8 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
277277
return "linalg.depthwise_conv_2d_nchw_chw";
278278
if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
279279
return "linalg.depthwise_conv_2d_nhwc_hwc";
280+
if (isaDepthwiseConv2DNhwcHwcQOp(genericOp))
281+
return "linalg.depthwise_conv_2d_nhwc_hwc_q";
280282
if (isaConv3DOp(genericOp))
281283
return "linalg.conv_3d";
282284
if (isaPoolingNchwMaxOp(genericOp))
@@ -307,6 +309,8 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
307309
return "linalg.conv_2d_nhwc_fhwc_q";
308310
if (isaConv2DNchwFchwQOp(genericOp))
309311
return "linalg.conv_2d_nchw_fchw_q";
312+
if (isaConv2DNhwcHwcfQOp(genericOp))
313+
return "linalg.conv_2d_nhwc_hwcf_q";
310314
if (isaDepthwiseConv2DNhwcHwcmOp(genericOp))
311315
return "linalg.depthwise_conv_2d_nhwc_hwcm";
312316
if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp))
@@ -323,6 +327,8 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
323327
return "linalg.conv_2d_ngchw_gfchw_q";
324328
if (isaConv2DNhwgcGfhwcOp(genericOp))
325329
return "linalg.conv_2d_nhwgc_gfhwc";
330+
if (isaConv2DNhwgcGfhwcQOp(genericOp))
331+
return "linalg.conv_2d_nhwgc_gfhwc_q";
326332
if (isaDepthwiseConv3DNcdhwCdhwOp(genericOp))
327333
return "linalg.depthwise_conv_3d_ncdhw_cdhw";
328334
if (isaDepthwiseConv3DNdhwcDhwcOp(genericOp))
@@ -341,6 +347,8 @@ static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
341347
return "linalg.conv_3d_ncdhw_fcdhw";
342348
if (isaConv3DNdhwcDhwcfOp(genericOp))
343349
return "linalg.conv_3d_ndhwc_dhwcf";
350+
if (isaConv3DNdhwcDhwcfQOp(genericOp))
351+
return "linalg.conv_3d_ndhwc_dhwcf_q";
344352
if (isaDepthwiseConv3DNdhwcDhwcmOp(genericOp))
345353
return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
346354
return "";
@@ -412,6 +420,10 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
412420
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwQOp>(genericOp, resultTypes, inputs, outputs);
413421
} else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") {
414422
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcOp>(genericOp, resultTypes, inputs, outputs);
423+
} else if (convKind == "linalg.conv_2d_nhwc_hwcf_q") {
424+
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwcHwcfQOp>(genericOp, resultTypes, inputs, outputs);
425+
} else if (convKind == "linalg.conv_2d_nhwgc_gfhwc_q") {
426+
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcQOp>(genericOp, resultTypes, inputs, outputs);
415427
} else if (convKind == "linalg.depthwise_conv_2d_nchw_chw") {
416428
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNchwChwOp>(genericOp, resultTypes, inputs, outputs);
417429
} else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc") {
@@ -420,12 +432,16 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
420432
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmOp>(genericOp, resultTypes, inputs, outputs);
421433
} else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm_q") {
422434
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(genericOp, resultTypes, inputs, outputs);
435+
} else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc_q") {
436+
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcQOp>(genericOp, resultTypes, inputs, outputs);
423437
} else if (convKind == "linalg.conv_3d") {
424438
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DOp>(genericOp, resultTypes, inputs, outputs);
425439
} else if (convKind == "linalg.conv_3d_ncdhw_fcdhw") {
426440
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNcdhwFcdhwOp>(genericOp, resultTypes, inputs, outputs);
427441
} else if (convKind == "linalg.conv_3d_ndhwc_dhwcf") {
428442
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfOp>(genericOp, resultTypes, inputs, outputs);
443+
} else if (convKind == "linalg.conv_3d_ndhwc_dhwcf_q") {
444+
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfQOp>(genericOp, resultTypes, inputs, outputs);
429445
} else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwcm") {
430446
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(genericOp, resultTypes, inputs, outputs);
431447
} else if (convKind == "linalg.depthwise_conv_3d_ncdhw_cdhw") {

mlir/lib/Dialect/Linalg/Utils/Utils.cpp

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,48 @@ bool isaConv2DNgchwGfchwOp(LinalgOp op) {
614614
matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 2));
615615
}
616616

617+
bool isaConv2DNhwcHwcfQOp(LinalgOp op) {
618+
if (isa<linalg::Conv2DNhwcHwcfQOp>(op)) return true;
619+
620+
if (!isaConvolutionOpInterface(op)) return false;
621+
622+
ArrayAttr indexingMaps = op.getIndexingMaps();
623+
if (!verifyConvIndexingMapSizes(indexingMaps, {4,4,0,0,4})) return false;
624+
625+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
626+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
627+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
628+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
629+
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
630+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
631+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
632+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
633+
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
634+
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 3));
635+
}
636+
637+
bool isaConv2DNhwgcGfhwcQOp(LinalgOp op) {
638+
if (isa<linalg::Conv2DNhwgcGfhwcQOp>(op)) return true;
639+
640+
if (!isaConvolutionOpInterface(op)) return false;
641+
642+
ArrayAttr indexingMaps = op.getIndexingMaps();
643+
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
644+
645+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
646+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
647+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
648+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
649+
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)
650+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
651+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/2, /*oDim=*/1) &&
652+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/3, /*oDim=*/2) &&
653+
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 0) &&
654+
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3) &&
655+
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 4) &&
656+
matchConvDimExprPattern(indexingMaps, fIndex, 1, oIndex, 4));
657+
}
658+
617659
bool isaConv2DNgchwGfchwQOp(LinalgOp op) {
618660
if (isa<linalg::Conv2DNgchwGfchwQOp>(op)) return true;
619661

@@ -736,6 +778,26 @@ bool isaDepthwiseConv2DNhwcHwcmQOp(LinalgOp op) {
736778
matchConvDimExprPattern(indexingMaps, fIndex, 3, oIndex, 4));
737779
}
738780

781+
bool isaDepthwiseConv2DNhwcHwcQOp(LinalgOp op) {
782+
if (isa<linalg::DepthwiseConv2DNhwcHwcQOp>(op)) return true;
783+
784+
if (!isaConvolutionOpInterface(op)) return false;
785+
786+
ArrayAttr indexingMaps = op.getIndexingMaps();
787+
if (!verifyConvIndexingMapSizes(indexingMaps, {4,3,0,0,4})) return false;
788+
789+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
790+
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
791+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
792+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
793+
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
794+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
795+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
796+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
797+
matchConvDimExprPattern(indexingMaps, iIndex, 3, fIndex, 2) &&
798+
matchConvDimExprPattern(indexingMaps, iIndex, 3, oIndex, 3));
799+
}
800+
739801
bool isaConv3DOp(LinalgOp op) {
740802
if (isa<linalg::Conv3DOp>(op)) return true;
741803

@@ -792,6 +854,27 @@ bool isaConv3DNdhwcDhwcfOp(LinalgOp op) {
792854
matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
793855
}
794856

857+
bool isaConv3DNdhwcDhwcfQOp(LinalgOp op) {
858+
if (isa<linalg::Conv3DNdhwcDhwcfQOp>(op)) return true;
859+
860+
if (!isaConvolutionOpInterface(op)) return false;
861+
862+
ArrayAttr indexingMaps = op.getIndexingMaps();
863+
if (!verifyConvIndexingMapSizes(indexingMaps, {5,5,0,0,5})) return false;
864+
865+
unsigned iIndex = 0, fIndex = 1, oIndex = 4;
866+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
867+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
868+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> ()>
869+
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
870+
return (matchConvDimExprPattern(indexingMaps, iIndex, 0, oIndex, 0) &&
871+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, /*oDim=*/1) &&
872+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, /*oDim=*/2) &&
873+
matchConvDimAddExprPattern(indexingMaps, /*iDim=*/3, /*fDim=*/2, /*oDim=*/3) &&
874+
matchConvDimExprPattern(indexingMaps, iIndex, 4, fIndex, 3) &&
875+
matchConvDimExprPattern(indexingMaps, fIndex, 4, oIndex, 4));
876+
}
877+
795878
bool isaDepthwiseConv3DNdhwcDhwcmOp(LinalgOp op) {
796879
if (isa<linalg::DepthwiseConv3DNdhwcDhwcmOp>(op)) return true;
797880

0 commit comments

Comments
 (0)