@@ -343,27 +343,14 @@ static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
343343}
344344
345345static std::string inferBasedOnRank4ConvIteratorTypes (GenericOp genericOp) {
346- ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
347- if (indexingMaps.size () != 3 ) return " " ;
348- // depthwise_conv_1d_ncw_cw
349- // #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
350- // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
351- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
352346 if (isaDepthwiseConv1DNcwCwOp (genericOp))
353347 return " linalg.depthwise_conv_1d_ncw_cw" ;
354- // depthwise_conv_1d_nwc_wc
355- // #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
356- // #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
357- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
358348 if (isaDepthwiseConv1DNwcWcOp (genericOp))
359349 return " linalg.depthwise_conv_1d_nwc_wc" ;
360- // conv_2d
361- // #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
362- // #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
363- // #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
364350 if (isaConv2DOp (genericOp))
365351 return " linalg.conv_2d" ;
366352
353+ ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
367354 unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
368355 Block *body = genericOp.getBlock ();
369356 auto yieldOp = cast<linalg::YieldOp>(body->getTerminator ());
@@ -401,45 +388,24 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
401388}
402389
403390static std::string inferBasedOnRank5ConvIteratorTypes (GenericOp genericOp) {
404- ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
405- if (indexingMaps.size () != 3 ) return " " ;
406- // depthwise_conv_1d_nwc_wcm
407- // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
408- // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
409- // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
410391 if (isaDepthwiseConv1DNwcWcmOp (genericOp))
411392 return " linalg.depthwise_conv_1d_nwc_wcm" ;
412- // conv_1d_nwc_wcf
413- // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
414- // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
415- // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
416393 if (isaConv1DNwcWcfOp (genericOp))
417394 return " linalg.conv_1d_nwc_wcf" ;
418- // conv_1d_ncw_fcw
419- // #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
420- // #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
421- // #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
422395 if (isaConv1DNcwFcwOp (genericOp))
423396 return " linalg.conv_1d_ncw_fcw" ;
424397 return " " ;
425398}
426399
427400static std::string inferBasedOnRank6ConvIteratorTypes (GenericOp genericOp) {
428- ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
429- if (indexingMaps.size () < 3 ) return " " ;
430- unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
431- // depthwise_conv_2d_nchw_chw
432- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
433- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
434- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
435401 if (isaDepthwiseConv2DNchwChwOp (genericOp))
436402 return " linalg.depthwise_conv_2d_nchw_chw" ;
437- // depthwise_conv_2d_nhwc_hwc
438- // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
439- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
440- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
441403 if (isaDepthwiseConv2DNhwcHwcOp (genericOp))
442404 return " linalg.depthwise_conv_2d_nhwc_hwc" ;
405+
406+ ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
407+ if (indexingMaps.size () < 3 ) return " " ;
408+ unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
443409 // conv_3d
444410 // #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
445411 // #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
@@ -501,83 +467,30 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
501467}
502468
503469static std::string inferBasedOnRank7ConvIteratorTypes (GenericOp genericOp) {
504- ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
505- if (indexingMaps.size () < 3 ) return " " ;
506- // conv_2d_nhwc_fhwc
507- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
508- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
509- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
510470 if (isaConv2DNhwcFhwcOp (genericOp))
511471 return " linalg.conv_2d_nhwc_fhwc" ;
512- // conv_2d_nhwc_hwcf
513- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
514- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
515- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
516472 if (isaConv2DNhwcHwcfOp (genericOp))
517473 return " linalg.conv_2d_nhwc_hwcf" ;
518- // conv_2d_nchw_fchw
519- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
520- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
521- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
522474 if (isaConv2DNchwFchwOp (genericOp))
523475 return " linalg.conv_2d_nchw_fchw" ;
524- // conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
525- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
526- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
527- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
528- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
529476 if (isaConv2DNhwcFhwcQOp (genericOp))
530477 return " linalg.conv_2d_nhwc_fhwc_q" ;
531- // conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps)
532- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
533- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
534- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
535- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
536478 if (isaConv2DNchwFchwQOp (genericOp))
537479 return " linalg.conv_2d_nchw_fchw_q" ;
538- // depthwise_conv_2d_nhwc_hwcm
539- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
540- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
541- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
542480 if (isaDepthwiseConv2DNhwcHwcmOp (genericOp))
543481 return " linalg.depthwise_conv_2d_nhwc_hwcm" ;
544- // depthwise_conv_2d_nhwc_hwcm_q
545- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
546- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
547- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
548- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
549482 if (isaDepthwiseConv2DNhwcHwcmQOp (genericOp))
550483 return " linalg.depthwise_conv_2d_nhwc_hwcm_q" ;
551484 return " " ;
552485}
553486
554487static std::string inferBasedOnRank8ConvIteratorTypes (GenericOp genericOp) {
555- ArrayAttr indexingMaps = genericOp.getIndexingMaps ();
556- if (indexingMaps.size () < 3 ) return " " ;
557- unsigned iIndex = 0 , fIndex = 1 , oIndex = indexingMaps.size () - 1 ;
558- // conv_2d_ngchw_fgchw
559- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
560- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
561- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
562488 if (isaConv2DNgchwFgchwOp (genericOp))
563489 return " linalg.conv_2d_ngchw_fgchw" ;
564- // conv_2d_ngchw_gfchw
565- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
566- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
567- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
568490 if (isaConv2DNgchwGfchwOp (genericOp))
569491 return " linalg.conv_2d_ngchw_gfchw" ;
570- // conv_2d_ngchw_gfchw_q
571- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
572- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
573- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
574- // #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
575492 if (isaConv2DNgchwGfchwQOp (genericOp))
576493 return " linalg.conv_2d_ngchw_gfchw_q" ;
577- // conv_2d_nhwgc_gfhwc
578- // #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
579- // #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
580- // #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
581494 if (isaConv2DNhwgcGfhwcOp (genericOp))
582495 return " linalg.conv_2d_nhwgc_gfhwc" ;
583496 // depthwise_conv_3d_ncdhw_cdhw
0 commit comments