@@ -316,6 +316,39 @@ static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
316316 return " " ;
317317}
318318
319+ static std::string inferBasedOnRank6ConvIteratorTypes (GenericOp genericOp) {
320+ ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
321+ if (indexingMaps.size () < 3 ) return " " ;
322+ unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
323+ // depthwise_conv_2d_nchw_chw
324+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
325+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
326+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
327+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
328+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == getAffineMapDim (indexingMaps, fIndex , 0 ) && getAffineMapDim (indexingMaps, iIndex, 1 ) == getAffineMapDim (indexingMaps, oIndex, 1 )) &&
329+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
330+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == (getAffineMapDim (indexingMaps, fIndex , 2 ) + getAffineMapDim (indexingMaps, oIndex, 3 ))))
331+ return " linalg.depthwise_conv_2d_nchw_chw" ;
332+ // depthwise_conv_2d_nhwc_hwc
333+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
334+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
335+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
336+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
337+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))) &&
338+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
339+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == getAffineMapDim (indexingMaps, fIndex , 2 ) && getAffineMapDim (indexingMaps, iIndex, 3 ) == getAffineMapDim (indexingMaps, oIndex, 3 )))
340+ return " linalg.depthwise_conv_2d_nhwc_hwc" ;
341+ // conv_3d
342+ // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
343+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
344+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2)>
345+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 0 ))) &&
346+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))) &&
347+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 2 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))))
348+ return " linalg.conv_3d" ;
349+ return " " ;
350+ }
351+
319352static std::string inferBasedOnRank7ConvIteratorTypes (GenericOp genericOp) {
320353 ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
321354 if (indexingMaps.size () < 3 ) return " " ;
@@ -370,9 +403,6 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
370403 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
371404 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
372405 // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
373- llvm::outs ()<<" Indexing map size = " <<indexingMaps.size ()<<" \n " ;
374- llvm::outs ()<<" (indexingMaps[2] == indexingMaps[3]) == " <<(indexingMaps[2 ] == indexingMaps[3 ])<<" \n " ;
375- llvm::outs ()<<" cast<AffineMapAttr>(indexingMaps[2]).getValue().getNumResults() = " <<cast<AffineMapAttr>(indexingMaps[2 ]).getValue ().getNumResults ()<<" \n " ;
376406 if (indexingMaps.size () == 5 &&
377407 (indexingMaps[2 ] == indexingMaps[3 ] && cast<AffineMapAttr>(indexingMaps[2 ]).getValue ().getNumResults () == 0 ) &&
378408 (getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
@@ -381,6 +411,30 @@ static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
381411 (getAffineMapDim (indexingMaps, iIndex, 3 ) == (getAffineMapDim (indexingMaps, fIndex , 3 ) + getAffineMapDim (indexingMaps, oIndex, 3 ))) &&
382412 (getAffineMapDim (indexingMaps, fIndex , 0 ) == getAffineMapDim (indexingMaps, oIndex, 1 )))
383413 return " linalg.conv_2d_nchw_fchw_q" ;
414+ // depthwise_conv_2d_nhwc_hwcm
415+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
416+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
417+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
418+ if (indexingMaps.size () == 3 &&
419+ (getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
420+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))) &&
421+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
422+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == getAffineMapDim (indexingMaps, fIndex , 2 ) && getAffineMapDim (indexingMaps, iIndex, 3 ) == getAffineMapDim (indexingMaps, oIndex, 3 )) &&
423+ (getAffineMapDim (indexingMaps, fIndex , 3 ) == getAffineMapDim (indexingMaps, oIndex, 4 )))
424+ return " linalg.depthwise_conv_2d_nhwc_hwcm" ;
425+ // depthwise_conv_2d_nhwc_hwcm_q
426+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
427+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
428+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
429+ // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
430+ if (indexingMaps.size () == 5 &&
431+ (indexingMaps[2 ] == indexingMaps[3 ] && cast<AffineMapAttr>(indexingMaps[2 ]).getValue ().getNumResults () == 0 ) &&
432+ (getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
433+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))) &&
434+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
435+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == getAffineMapDim (indexingMaps, fIndex , 2 ) && getAffineMapDim (indexingMaps, iIndex, 3 ) == getAffineMapDim (indexingMaps, oIndex, 3 )) &&
436+ (getAffineMapDim (indexingMaps, fIndex , 3 ) == getAffineMapDim (indexingMaps, oIndex, 4 )))
437+ return " linalg.depthwise_conv_2d_nhwc_hwcm_q" ;
384438 return " " ;
385439}
386440
@@ -397,7 +451,7 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
397451 (getAffineMapDim (indexingMaps, iIndex, 2 ) == getAffineMapDim (indexingMaps, fIndex , 2 )) &&
398452 (getAffineMapDim (indexingMaps, iIndex, 3 ) == (getAffineMapDim (indexingMaps, fIndex , 3 ) + getAffineMapDim (indexingMaps, oIndex, 3 ))) &&
399453 (getAffineMapDim (indexingMaps, iIndex, 4 ) == (getAffineMapDim (indexingMaps, fIndex , 4 ) + getAffineMapDim (indexingMaps, oIndex, 4 ))) &&
400- (getAffineMapDim (indexingMaps, fIndex , 0 ) == getAffineMapDim (indexingMaps, oIndex, 1 )))
454+ (getAffineMapDim (indexingMaps, fIndex , 0 ) == getAffineMapDim (indexingMaps, oIndex, 2 )))
401455 return " linalg.conv_2d_ngchw_fgchw" ;
402456 // conv_2d_ngchw_gfchw
403457 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
@@ -436,6 +490,66 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
436490 (getAffineMapDim (indexingMaps, iIndex, 4 ) == getAffineMapDim (indexingMaps, fIndex , 4 )) &&
437491 (getAffineMapDim (indexingMaps, fIndex , 1 ) == getAffineMapDim (indexingMaps, oIndex, 4 )))
438492 return " linalg.conv_2d_nhwgc_gfhwc" ;
493+ // depthwise_conv_3d_ncdhw_cdhw
494+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
495+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d7, d4, d5, d6)>
496+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1, d2, d3)>
497+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
498+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == getAffineMapDim (indexingMaps, fIndex , 0 ) && getAffineMapDim (indexingMaps, iIndex, 1 ) == getAffineMapDim (indexingMaps, oIndex, 1 )) &&
499+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
500+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == (getAffineMapDim (indexingMaps, fIndex , 2 ) + getAffineMapDim (indexingMaps, oIndex, 3 ))) &&
501+ (getAffineMapDim (indexingMaps, iIndex, 4 ) == (getAffineMapDim (indexingMaps, fIndex , 3 ) + getAffineMapDim (indexingMaps, oIndex, 4 ))))
502+ return " linalg.depthwise_conv_3d_ncdhw_cdhw" ;
503+ // depthwise_conv_3d_ndhwc_dhwc
504+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d4, d2 + d5, d3 + d6, d7)>
505+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d4, d5, d6, d7)>
506+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d7)>
507+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
508+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))) &&
509+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
510+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == (getAffineMapDim (indexingMaps, fIndex , 2 ) + getAffineMapDim (indexingMaps, oIndex, 3 ))) &&
511+ (getAffineMapDim (indexingMaps, iIndex, 4 ) == getAffineMapDim (indexingMaps, fIndex , 3 ) && getAffineMapDim (indexingMaps, iIndex, 4 ) == getAffineMapDim (indexingMaps, oIndex, 4 )))
512+ return " linalg.depthwise_conv_3d_ndhwc_dhwc" ;
513+ return " " ;
514+ }
515+
516+ static std::string inferBasedOnRank9ConvIteratorTypes (GenericOp genericOp) {
517+ ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
518+ if (indexingMaps.size () < 3 ) return " " ;
519+ unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
520+ // conv_3d_ncdhw_fcdhw
521+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d5, d2 + d6, d3 + d7, d4 + d8)>
522+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d1, d5, d6, d7, d8)>
523+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
524+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
525+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == getAffineMapDim (indexingMaps, fIndex , 1 )) &&
526+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 2 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
527+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == (getAffineMapDim (indexingMaps, fIndex , 3 ) + getAffineMapDim (indexingMaps, oIndex, 3 ))) &&
528+ (getAffineMapDim (indexingMaps, iIndex, 4 ) == (getAffineMapDim (indexingMaps, fIndex , 4 ) + getAffineMapDim (indexingMaps, oIndex, 4 ))) &&
529+ (getAffineMapDim (indexingMaps, fIndex , 0 ) == getAffineMapDim (indexingMaps, oIndex, 1 )))
530+ return " linalg.conv_3d_ncdhw_fcdhw" ;
531+ // conv_3d_ndhwc_dhwcf
532+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
533+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
534+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d4)>
535+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
536+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))) &&
537+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
538+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == (getAffineMapDim (indexingMaps, fIndex , 2 ) + getAffineMapDim (indexingMaps, oIndex, 3 ))) &&
539+ (getAffineMapDim (indexingMaps, iIndex, 4 ) == getAffineMapDim (indexingMaps, fIndex , 3 )) &&
540+ (getAffineMapDim (indexingMaps, fIndex , 4 ) == getAffineMapDim (indexingMaps, oIndex, 4 )))
541+ return " linalg.conv_3d_ndhwc_dhwcf" ;
542+ // depthwise_conv_3d_ndhwc_dhwcm
543+ // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1 + d5, d2 + d6, d3 + d7, d8)>
544+ // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d5, d6, d7, d8, d4)>
545+ // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d2, d3, d8, d4)>
546+ if ((getAffineMapDim (indexingMaps, iIndex, 0 ) == getAffineMapDim (indexingMaps, oIndex, 0 )) &&
547+ (getAffineMapDim (indexingMaps, iIndex, 1 ) == (getAffineMapDim (indexingMaps, fIndex , 0 ) + getAffineMapDim (indexingMaps, oIndex, 1 ))) &&
548+ (getAffineMapDim (indexingMaps, iIndex, 2 ) == (getAffineMapDim (indexingMaps, fIndex , 1 ) + getAffineMapDim (indexingMaps, oIndex, 2 ))) &&
549+ (getAffineMapDim (indexingMaps, iIndex, 3 ) == (getAffineMapDim (indexingMaps, fIndex , 2 ) + getAffineMapDim (indexingMaps, oIndex, 3 ))) &&
550+ (getAffineMapDim (indexingMaps, iIndex, 4 ) == getAffineMapDim (indexingMaps, fIndex , 3 ) && getAffineMapDim (indexingMaps, iIndex, 4 ) == getAffineMapDim (indexingMaps, oIndex, 4 )) &&
551+ (getAffineMapDim (indexingMaps, fIndex , 4 ) == getAffineMapDim (indexingMaps, oIndex, 5 )))
552+ return " linalg.depthwise_conv_3d_ndhwc_dhwcm" ;
439553 return " " ;
440554}
441555
@@ -449,10 +563,14 @@ static std::string inferConvolutionKind(GenericOp genericOp) {
449563 return inferBasedOnRank4ConvIteratorTypes (genericOp);
450564 case 5 :
451565 return inferBasedOnRank5ConvIteratorTypes (genericOp);
566+ case 6 :
567+ return inferBasedOnRank6ConvIteratorTypes (genericOp);
452568 case 7 :
453569 return inferBasedOnRank7ConvIteratorTypes (genericOp);
454570 case 8 :
455571 return inferBasedOnRank8ConvIteratorTypes (genericOp);
572+ case 9 :
573+ return inferBasedOnRank9ConvIteratorTypes (genericOp);
456574 }
457575 return " " ;
458576}
@@ -501,6 +619,26 @@ static FailureOr<LinalgOp> specializeLinalgConvolutions(RewriterBase &rewriter,
501619 namedOp = rewriter.replaceOpWithNewOp <linalg::Conv2DNgchwGfchwQOp>(genericOp, resultTypes, inputs, outputs);
502620 } else if (convKind == " linalg.conv_2d_nhwgc_gfhwc" ) {
503621 namedOp = rewriter.replaceOpWithNewOp <linalg::Conv2DNhwgcGfhwcOp>(genericOp, resultTypes, inputs, outputs);
622+ } else if (convKind == " linalg.depthwise_conv_2d_nchw_chw" ) {
623+ namedOp = rewriter.replaceOpWithNewOp <linalg::DepthwiseConv2DNchwChwOp>(genericOp, resultTypes, inputs, outputs);
624+ } else if (convKind == " linalg.depthwise_conv_2d_nhwc_hwc" ) {
625+ namedOp = rewriter.replaceOpWithNewOp <linalg::DepthwiseConv2DNhwcHwcOp>(genericOp, resultTypes, inputs, outputs);
626+ } else if (convKind == " linalg.depthwise_conv_2d_nhwc_hwcm" ) {
627+ namedOp = rewriter.replaceOpWithNewOp <linalg::DepthwiseConv2DNhwcHwcmOp>(genericOp, resultTypes, inputs, outputs);
628+ } else if (convKind == " linalg.depthwise_conv_2d_nhwc_hwcm_q" ) {
629+ namedOp = rewriter.replaceOpWithNewOp <linalg::DepthwiseConv2DNhwcHwcmQOp>(genericOp, resultTypes, inputs, outputs);
630+ } else if (convKind == " linalg.conv_3d" ) {
631+ namedOp = rewriter.replaceOpWithNewOp <linalg::Conv3DOp>(genericOp, resultTypes, inputs, outputs);
632+ } else if (convKind == " linalg.conv_3d_ncdhw_fcdhw" ) {
633+ namedOp = rewriter.replaceOpWithNewOp <linalg::Conv3DNcdhwFcdhwOp>(genericOp, resultTypes, inputs, outputs);
634+ } else if (convKind == " linalg.conv_3d_ndhwc_dhwcf" ) {
635+ namedOp = rewriter.replaceOpWithNewOp <linalg::Conv3DNdhwcDhwcfOp>(genericOp, resultTypes, inputs, outputs);
636+ } else if (convKind == " linalg.depthwise_conv_3d_ndhwc_dhwcm" ) {
637+ namedOp = rewriter.replaceOpWithNewOp <linalg::DepthwiseConv3DNdhwcDhwcmOp>(genericOp, resultTypes, inputs, outputs);
638+ } else if (convKind == " linalg.depthwise_conv_3d_ncdhw_cdhw" ) {
639+ namedOp = rewriter.replaceOpWithNewOp <linalg::DepthwiseConv3DNcdhwCdhwOp>(genericOp, resultTypes, inputs, outputs);
640+ } else if (convKind == " linalg.depthwise_conv_3d_ndhwc_dhwc" ) {
641+ namedOp = rewriter.replaceOpWithNewOp <linalg::DepthwiseConv3DNdhwcDhwcOp>(genericOp, resultTypes, inputs, outputs);
504642 }
505643 return namedOp;
506644
0 commit comments