@@ -361,10 +361,10 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
361361 // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
362362 // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
363363 // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
364- if (matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 0 , /* fDim=*/ 0 , /* oDim=*/ 0 ) &&
365- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 1 , /* oDim=*/ 1 ))
364+ if (isaConv2DOp (genericOp))
366365 return " linalg.conv_2d" ;
367366
367+ unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
368368 Block *body = genericOp.getBlock ();
369369 auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
370370 Value yieldVal = yieldOp.getOperand (0 );
@@ -432,21 +432,13 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
432432 // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
433433 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
434434 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
435- if (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
436- matchConvDimExprPattern (indexingMaps, iIndex, 1 , fIndex , 0 ) &&
437- matchConvDimExprPattern (indexingMaps, iIndex, 1 , oIndex, 1 ) &&
438- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
439- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 2 , /* oDim=*/ 3 ))
435+ if (isaDepthwiseConv2DNchwChwOp (genericOp))
440436 return " linalg.depthwise_conv_2d_nchw_chw" ;
441437 // depthwise_conv_2d_nhwc_hwc
442438 // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
443439 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
444440 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
445- if (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
446- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
447- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
448- matchConvDimExprPattern (indexingMaps, iIndex, 3 , fIndex , 2 ) &&
449- matchConvDimExprPattern (indexingMaps, iIndex, 3 , oIndex, 3 ))
441+ if (isaDepthwiseConv2DNhwcHwcOp (genericOp))
450442 return " linalg.depthwise_conv_2d_nhwc_hwc" ;
451443 // conv_3d
452444 // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
@@ -511,90 +503,50 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
511503static std::string inferBasedOnRank7ConvIteratorTypes (GenericOp genericOp) {
512504 ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
513505 if (indexingMaps.size () < 3 ) return " " ;
514- unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
515506 // conv_2d_nhwc_fhwc
516507 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
517508 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
518509 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
519- if (indexingMaps.size () == 3 &&
520- matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
521- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 1 , /* oDim=*/ 1 ) &&
522- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 2 , /* oDim=*/ 2 ) &&
523- matchConvDimExprPattern (indexingMaps, iIndex, 3 , fIndex , 3 ) &&
524- matchConvDimExprPattern (indexingMaps, fIndex , 0 , oIndex, 3 ))
510+ if (isaConv2DNhwcFhwcOp (genericOp))
525511 return " linalg.conv_2d_nhwc_fhwc" ;
526512 // conv_2d_nhwc_hwcf
527513 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
528514 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
529515 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
530- if (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
531- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
532- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
533- matchConvDimExprPattern (indexingMaps, iIndex, 3 , fIndex , 2 ) &&
534- matchConvDimExprPattern (indexingMaps, fIndex , 3 , oIndex, 3 ))
516+ if (isaConv2DNhwcHwcfOp (genericOp))
535517 return " linalg.conv_2d_nhwc_hwcf" ;
536518 // conv_2d_nchw_fchw
537519 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
538520 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
539521 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
540- if (indexingMaps.size () == 3 &&
541- matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
542- matchConvDimExprPattern (indexingMaps, iIndex, 1 , fIndex , 1 ) &&
543- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 2 , /* oDim=*/ 2 ) &&
544- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 3 , /* oDim=*/ 3 ) &&
545- matchConvDimExprPattern (indexingMaps, fIndex , 0 , oIndex, 1 ))
522+ if (isaConv2DNchwFchwOp (genericOp))
546523 return " linalg.conv_2d_nchw_fchw" ;
547524 // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
548525 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
549526 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
550527 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
551528 // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
552- if (indexingMaps.size () == 5 &&
553- (indexingMaps[2 ] == indexingMaps[3 ] && cast<AffineMapAttr>(indexingMaps[2 ]).getValue ().getNumResults () == 0 ) &&
554- matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
555- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 1 , /* oDim=*/ 1 ) &&
556- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 2 , /* oDim=*/ 2 ) &&
557- matchConvDimExprPattern (indexingMaps, iIndex, 3 , fIndex , 3 ) &&
558- matchConvDimExprPattern (indexingMaps, fIndex , 0 , oIndex, 3 ))
529+ if (isaConv2DNhwcFhwcQOp (genericOp))
559530 return " linalg.conv_2d_nhwc_fhwc_q" ;
560531 // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps)
561532 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
562533 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
563534 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
564535 // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
565- if (indexingMaps.size () == 5 &&
566- (indexingMaps[2 ] == indexingMaps[3 ] && cast<AffineMapAttr>(indexingMaps[2 ]).getValue ().getNumResults () == 0 ) &&
567- matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
568- matchConvDimExprPattern (indexingMaps, iIndex, 1 , fIndex , 1 ) &&
569- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 2 , /* oDim=*/ 2 ) &&
570- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 3 , /* oDim=*/ 3 ) &&
571- matchConvDimExprPattern (indexingMaps, fIndex , 0 , oIndex, 1 ))
536+ if (isaConv2DNchwFchwQOp (genericOp))
572537 return " linalg.conv_2d_nchw_fchw_q" ;
573538 // depthwise_conv_2d_nhwc_hwcm
574539 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
575540 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
576541 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
577- if (indexingMaps.size () == 3 &&
578- matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
579- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
580- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
581- matchConvDimExprPattern (indexingMaps, iIndex, 3 , fIndex , 2 ) &&
582- matchConvDimExprPattern (indexingMaps, iIndex, 3 , oIndex, 3 ) &&
583- matchConvDimExprPattern (indexingMaps, fIndex , 3 , oIndex, 4 ))
542+ if (isaDepthwiseConv2DNhwcHwcmOp (genericOp))
584543 return " linalg.depthwise_conv_2d_nhwc_hwcm" ;
585544 // depthwise_conv_2d_nhwc_hwcm_q
586545 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
587546 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
588547 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
589548 // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
590- if (indexingMaps.size () == 5 &&
591- (indexingMaps[2 ] == indexingMaps[3 ] && cast<AffineMapAttr>(indexingMaps[2 ]).getValue ().getNumResults () == 0 ) &&
592- matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
593- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 0 , /* oDim=*/ 1 ) &&
594- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 1 , /* oDim=*/ 2 ) &&
595- matchConvDimExprPattern (indexingMaps, iIndex, 3 , fIndex , 2 ) &&
596- matchConvDimExprPattern (indexingMaps, iIndex, 3 , oIndex, 3 ) &&
597- matchConvDimExprPattern (indexingMaps, fIndex , 3 , oIndex, 4 ))
549+ if (isaDepthwiseConv2DNhwcHwcmQOp (genericOp))
598550 return " linalg.depthwise_conv_2d_nhwc_hwcm_q" ;
599551 return " " ;
600552}
@@ -607,53 +559,26 @@ static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
607559 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
608560 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
609561 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
610- if (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
611- matchConvDimExprPattern (indexingMaps, iIndex, 1 , fIndex , 1 ) &&
612- matchConvDimExprPattern (indexingMaps, iIndex, 1 , oIndex, 1 ) &&
613- matchConvDimExprPattern (indexingMaps, iIndex, 2 , fIndex , 2 ) &&
614- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 3 , /* oDim=*/ 3 ) &&
615- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 4 , /* fDim=*/ 4 , /* oDim=*/ 4 ) &&
616- matchConvDimExprPattern (indexingMaps, fIndex , 0 , oIndex, 2 ))
562+ if (isaConv2DNgchwFgchwOp (genericOp))
617563 return " linalg.conv_2d_ngchw_fgchw" ;
618564 // conv_2d_ngchw_gfchw
619565 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
620566 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
621567 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
622- if (indexingMaps.size () == 3 &&
623- matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
624- matchConvDimExprPattern (indexingMaps, iIndex, 1 , fIndex , 0 ) &&
625- matchConvDimExprPattern (indexingMaps, iIndex, 1 , oIndex, 1 ) &&
626- matchConvDimExprPattern (indexingMaps, iIndex, 2 , fIndex , 2 ) &&
627- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 3 , /* oDim=*/ 3 ) &&
628- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 4 , /* fDim=*/ 4 , /* oDim=*/ 4 ) &&
629- matchConvDimExprPattern (indexingMaps, fIndex , 1 , oIndex, 2 ))
568+ if (isaConv2DNgchwGfchwOp (genericOp))
630569 return " linalg.conv_2d_ngchw_gfchw" ;
631570 // conv_2d_ngchw_gfchw_q
632571 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
633572 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
634573 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
635574 // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
636- if (indexingMaps.size () == 5 &&
637- (indexingMaps[2 ] == indexingMaps[3 ] && cast<AffineMapAttr>(indexingMaps[2 ]).getValue ().getNumResults () == 0 ) &&
638- matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
639- matchConvDimExprPattern (indexingMaps, iIndex, 1 , fIndex , 0 ) &&
640- matchConvDimExprPattern (indexingMaps, iIndex, 1 , oIndex, 1 ) &&
641- matchConvDimExprPattern (indexingMaps, iIndex, 2 , fIndex , 2 ) &&
642- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 3 , /* fDim=*/ 3 , /* oDim=*/ 3 ) &&
643- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 4 , /* fDim=*/ 4 , /* oDim=*/ 4 ) &&
644- matchConvDimExprPattern (indexingMaps, fIndex , 1 , oIndex, 2 ))
575+ if (isaConv2DNgchwGfchwQOp (genericOp))
645576 return " linalg.conv_2d_ngchw_gfchw_q" ;
646577 // conv_2d_nhwgc_gfhwc
647578 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
648579 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
649580 // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
650- if (matchConvDimExprPattern (indexingMaps, iIndex, 0 , oIndex, 0 ) &&
651- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 1 , /* fDim=*/ 2 , /* oDim=*/ 1 ) &&
652- matchConvDimAddExprPattern (indexingMaps, /* iDim=*/ 2 , /* fDim=*/ 3 , /* oDim=*/ 2 ) &&
653- matchConvDimExprPattern (indexingMaps, iIndex, 3 , fIndex , 0 ) &&
654- matchConvDimExprPattern (indexingMaps, iIndex, 3 , oIndex, 3 ) &&
655- matchConvDimExprPattern (indexingMaps, iIndex, 4 , fIndex , 4 ) &&
656- matchConvDimExprPattern (indexingMaps, fIndex , 1 , oIndex, 4 ))
581+ if (isaConv2DNhwgcGfhwcOp (genericOp))
657582 return " linalg.conv_2d_nhwgc_gfhwc" ;
658583 // depthwise_conv_3d_ncdhw_cdhw
659584 // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d7, d1 + d4, d2 + d5, d3 + d6)>
0 commit comments