Skip to content

Commit dac92f1

Browse files
Conv complete -> start Pool op now
1 parent 89b7190 commit dac92f1

File tree

1 file changed

+142
-4
lines changed

1 file changed

+142
-4
lines changed

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 142 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,39 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
316316
return "";
317317
}
318318

319+
static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
320+
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
321+
if (indexingMaps.size() < 3) return "";
322+
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
323+
// depthwise_conv_2d_nchw_chw
324+
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
325+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
326+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
327+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
328+
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
329+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
330+
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))))
331+
return "linalg.depthwise_conv_2d_nchw_chw";
332+
// depthwise_conv_2d_nhwc_hwc
333+
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
334+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
335+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
336+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
337+
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
338+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
339+
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)))
340+
return "linalg.depthwise_conv_2d_nhwc_hwc";
341+
// conv_3d
342+
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
343+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
344+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
345+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 0))) &&
346+
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
347+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))))
348+
return "linalg.conv_3d";
349+
return "";
350+
}
351+
319352
static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
320353
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
321354
if (indexingMaps.size() < 3) return "";
@@ -370,9 +403,6 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
370403
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
371404
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
372405
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
373-
llvm::outs()<<"Indexing map size = "<<indexingMaps.size()<<"\n";
374-
llvm::outs()<<"(indexingMaps[2] == indexingMaps[3]) == "<<(indexingMaps[2] == indexingMaps[3])<<"\n";
375-
llvm::outs()<<"cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() = "<<cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults()<<"\n";
376406
if (indexingMaps.size() == 5 &&
377407
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
378408
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
@@ -381,6 +411,30 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
381411
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
382412
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
383413
return "linalg.conv_2d_nchw_fchw_q";
414+
// depthwise_conv_2d_nhwc_hwcm
415+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
416+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
417+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
418+
if (indexingMaps.size() == 3 &&
419+
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
420+
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
421+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
422+
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
423+
(getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
424+
return "linalg.depthwise_conv_2d_nhwc_hwcm";
425+
// depthwise_conv_2d_nhwc_hwcm_q
426+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
427+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
428+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
429+
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
430+
if (indexingMaps.size() == 5 &&
431+
(indexingMaps[2] == indexingMaps[3] && cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() == 0) &&
432+
(getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
433+
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
434+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
435+
(getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, fIndex, 2) && getAffineMapDim(indexingMaps, iIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 3)) &&
436+
(getAffineMapDim(indexingMaps, fIndex, 3) == getAffineMapDim(indexingMaps, oIndex, 4)))
437+
return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
384438
return "";
385439
}
386440

@@ -397,7 +451,7 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
397451
(getAffineMapDim(indexingMaps, iIndex, 2) == getAffineMapDim(indexingMaps, fIndex, 2)) &&
398452
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
399453
(getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
400-
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
454+
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 2)))
401455
return "linalg.conv_2d_ngchw_fgchw";
402456
// conv_2d_ngchw_gfchw
403457
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
@@ -436,6 +490,66 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
436490
(getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 4)) &&
437491
(getAffineMapDim(indexingMaps, fIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 4)))
438492
return "linalg.conv_2d_nhwgc_gfhwc";
493+
// depthwise_conv_3d_ncdhw_cdhw
494+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
495+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
496+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
497+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
498+
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 0) && getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, oIndex, 1)) &&
499+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
500+
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
501+
(getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 4))))
502+
return "linalg.depthwise_conv_3d_ncdhw_cdhw";
503+
// depthwise_conv_3d_ndhwc_dhwc
504+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
505+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
506+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
507+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
508+
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
509+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
510+
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
511+
(getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
512+
return "linalg.depthwise_conv_3d_ndhwc_dhwc";
513+
return "";
514+
}
515+
516+
static std::string inferBasedOnRank9ConvIteratorTypes(GenericOp genericOp) {
517+
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
518+
if (indexingMaps.size() < 3) return "";
519+
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
520+
// conv_3d_ncdhw_fcdhw
521+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
522+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
523+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
524+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
525+
(getAffineMapDim(indexingMaps, iIndex, 1) == getAffineMapDim(indexingMaps, fIndex, 1)) &&
526+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
527+
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 3) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
528+
(getAffineMapDim(indexingMaps, iIndex, 4) == (getAffineMapDim(indexingMaps, fIndex, 4) + getAffineMapDim(indexingMaps, oIndex, 4))) &&
529+
(getAffineMapDim(indexingMaps, fIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 1)))
530+
return "linalg.conv_3d_ncdhw_fcdhw";
531+
// conv_3d_ndhwc_dhwcf
532+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
533+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
534+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
535+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
536+
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
537+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
538+
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
539+
(getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3)) &&
540+
(getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)))
541+
return "linalg.conv_3d_ndhwc_dhwcf";
542+
// depthwise_conv_3d_ndhwc_dhwcm
543+
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
544+
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
545+
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
546+
if ((getAffineMapDim(indexingMaps, iIndex, 0) == getAffineMapDim(indexingMaps, oIndex, 0)) &&
547+
(getAffineMapDim(indexingMaps, iIndex, 1) == (getAffineMapDim(indexingMaps, fIndex, 0) + getAffineMapDim(indexingMaps, oIndex, 1))) &&
548+
(getAffineMapDim(indexingMaps, iIndex, 2) == (getAffineMapDim(indexingMaps, fIndex, 1) + getAffineMapDim(indexingMaps, oIndex, 2))) &&
549+
(getAffineMapDim(indexingMaps, iIndex, 3) == (getAffineMapDim(indexingMaps, fIndex, 2) + getAffineMapDim(indexingMaps, oIndex, 3))) &&
550+
(getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, fIndex, 3) && getAffineMapDim(indexingMaps, iIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 4)) &&
551+
(getAffineMapDim(indexingMaps, fIndex, 4) == getAffineMapDim(indexingMaps, oIndex, 5)))
552+
return "linalg.depthwise_conv_3d_ndhwc_dhwcm";
439553
return "";
440554
}
441555

@@ -449,10 +563,14 @@ static std::string inferConvolutionKind(GenericOp genericOp) {
449563
return inferBasedOnRank4ConvIteratorTypes(genericOp);
450564
case 5:
451565
return inferBasedOnRank5ConvIteratorTypes(genericOp);
566+
case 6:
567+
return inferBasedOnRank6ConvIteratorTypes(genericOp);
452568
case 7:
453569
return inferBasedOnRank7ConvIteratorTypes(genericOp);
454570
case 8:
455571
return inferBasedOnRank8ConvIteratorTypes(genericOp);
572+
case 9:
573+
return inferBasedOnRank9ConvIteratorTypes(genericOp);
456574
}
457575
return "";
458576
}
@@ -501,6 +619,26 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
501619
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNgchwGfchwQOp>(genericOp, resultTypes, inputs, outputs);
502620
} else if (convKind == "linalg.conv_2d_nhwgc_gfhwc") {
503621
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv2DNhwgcGfhwcOp>(genericOp, resultTypes, inputs, outputs);
622+
} else if (convKind == "linalg.depthwise_conv_2d_nchw_chw") {
623+
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNchwChwOp>(genericOp, resultTypes, inputs, outputs);
624+
} else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwc") {
625+
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcOp>(genericOp, resultTypes, inputs, outputs);
626+
} else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm") {
627+
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmOp>(genericOp, resultTypes, inputs, outputs);
628+
} else if (convKind == "linalg.depthwise_conv_2d_nhwc_hwcm_q") {
629+
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv2DNhwcHwcmQOp>(genericOp, resultTypes, inputs, outputs);
630+
} else if (convKind == "linalg.conv_3d") {
631+
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DOp>(genericOp, resultTypes, inputs, outputs);
632+
} else if (convKind == "linalg.conv_3d_ncdhw_fcdhw") {
633+
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNcdhwFcdhwOp>(genericOp, resultTypes, inputs, outputs);
634+
} else if (convKind == "linalg.conv_3d_ndhwc_dhwcf") {
635+
namedOp = rewriter.replaceOpWithNewOp<linalg::Conv3DNdhwcDhwcfOp>(genericOp, resultTypes, inputs, outputs);
636+
} else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwcm") {
637+
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcmOp>(genericOp, resultTypes, inputs, outputs);
638+
} else if (convKind == "linalg.depthwise_conv_3d_ncdhw_cdhw") {
639+
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNcdhwCdhwOp>(genericOp, resultTypes, inputs, outputs);
640+
} else if (convKind == "linalg.depthwise_conv_3d_ndhwc_dhwc") {
641+
namedOp = rewriter.replaceOpWithNewOp<linalg::DepthwiseConv3DNdhwcDhwcOp>(genericOp, resultTypes, inputs, outputs);
504642
}
505643
return namedOp;
506644

0 commit comments

Comments
 (0)