Skip to content

Commit 053d912

Browse files
Clean a bit
1 parent a08247c commit 053d912

File tree

1 file changed

+5
-92
lines changed

1 file changed

+5
-92
lines changed

mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp

Lines changed: 5 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -343,27 +343,14 @@ static std::string inferBasedOnRank2ConvIteratorTypes(GenericOp genericOp) {
343343
}
344344

345345
static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
346-
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
347-
if (indexingMaps.size() != 3) return "";
348-
// depthwise_conv_1d_ncw_cw
349-
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1 + d3)>
350-
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
351-
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d2, d1)>
352346
if (isaDepthwiseConv1DNcwCwOp(genericOp))
353347
return "linalg.depthwise_conv_1d_ncw_cw";
354-
// depthwise_conv_1d_nwc_wc
355-
// #map = affine_map<(d0, d1, d2, d3) -> (d0, d1 + d3, d2)>
356-
// #map1 = affine_map<(d0, d1, d2, d3) -> (d3, d2)>
357-
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
358348
if (isaDepthwiseConv1DNwcWcOp(genericOp))
359349
return "linalg.depthwise_conv_1d_nwc_wc";
360-
// conv_2d
361-
// #map = affine_map<(d0, d1, d2, d3) -> (d0 + d2, d1 + d3)>
362-
// #map1 = affine_map<(d0, d1, d2, d3) -> (d2, d3)>
363-
// #map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1)>
364350
if (isaConv2DOp(genericOp))
365351
return "linalg.conv_2d";
366352

353+
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
367354
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
368355
Block *body = genericOp.getBlock();
369356
auto yieldOp = cast<linalg::YieldOp>(body->getTerminator());
@@ -401,45 +388,24 @@ static std::string inferBasedOnRank4ConvIteratorTypes(GenericOp genericOp) {
401388
}
402389

403390
static std::string inferBasedOnRank5ConvIteratorTypes(GenericOp genericOp) {
404-
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
405-
if (indexingMaps.size() != 3) return "";
406-
// depthwise_conv_1d_nwc_wcm
407-
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d4, d2)>
408-
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d4, d2, d3)>
409-
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3)>
410391
if (isaDepthwiseConv1DNwcWcmOp(genericOp))
411392
return "linalg.depthwise_conv_1d_nwc_wcm";
412-
// conv_1d_nwc_wcf
413-
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1 + d3, d4)>
414-
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d3, d4, d2)>
415-
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
416393
if (isaConv1DNwcWcfOp(genericOp))
417394
return "linalg.conv_1d_nwc_wcf";
418-
// conv_1d_ncw_fcw
419-
// #map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2 + d4)>
420-
// #map1 = affine_map<(d0, d1, d2, d3, d4) -> (d1, d3, d4)>
421-
// #map2 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>
422395
if (isaConv1DNcwFcwOp(genericOp))
423396
return "linalg.conv_1d_ncw_fcw";
424397
return "";
425398
}
426399

427400
static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
428-
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
429-
if (indexingMaps.size() < 3) return "";
430-
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
431-
// depthwise_conv_2d_nchw_chw
432-
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1 + d4, d2 + d5)>
433-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
434-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d3, d1, d2)>
435401
if (isaDepthwiseConv2DNchwChwOp(genericOp))
436402
return "linalg.depthwise_conv_2d_nchw_chw";
437-
// depthwise_conv_2d_nhwc_hwc
438-
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1 + d4, d2 + d5, d3)>
439-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d4, d5, d3)>
440-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>
441403
if (isaDepthwiseConv2DNhwcHwcOp(genericOp))
442404
return "linalg.depthwise_conv_2d_nhwc_hwc";
405+
406+
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
407+
if (indexingMaps.size() < 3) return "";
408+
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
443409
// conv_3d
444410
// #map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0 + d3, d1 + d4, d2 + d5)>
445411
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>
@@ -501,83 +467,30 @@ static std::string inferBasedOnRank6ConvIteratorTypes(GenericOp genericOp) {
501467
}
502468

503469
static std::string inferBasedOnRank7ConvIteratorTypes(GenericOp genericOp) {
504-
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
505-
if (indexingMaps.size() < 3) return "";
506-
// conv_2d_nhwc_fhwc
507-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
508-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
509-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
510470
if (isaConv2DNhwcFhwcOp(genericOp))
511471
return "linalg.conv_2d_nhwc_fhwc";
512-
// conv_2d_nhwc_hwcf
513-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
514-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d3)>
515-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
516472
if (isaConv2DNhwcHwcfOp(genericOp))
517473
return "linalg.conv_2d_nhwc_hwcf";
518-
// conv_2d_nchw_fchw
519-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
520-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
521-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
522474
if (isaConv2DNchwFchwOp(genericOp))
523475
return "linalg.conv_2d_nchw_fchw";
524-
// conv_2d_nhwc_fhwc_q (same as conv_2d_nhwc_fhwc + check total 4 indexing maps)
525-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d4, d2 + d5, d6)>
526-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d3, d4, d5, d6)>
527-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
528-
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
529476
if (isaConv2DNhwcFhwcQOp(genericOp))
530477
return "linalg.conv_2d_nhwc_fhwc_q";
531-
// conv_2d_nchw_fchw_q (same as conv_2d_nchw_fchw + check total 4 indexing maps)
532-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>
533-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>
534-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
535-
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3)>
536478
if (isaConv2DNchwFchwQOp(genericOp))
537479
return "linalg.conv_2d_nchw_fchw_q";
538-
// depthwise_conv_2d_nhwc_hwcm
539-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
540-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
541-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
542480
if (isaDepthwiseConv2DNhwcHwcmOp(genericOp))
543481
return "linalg.depthwise_conv_2d_nhwc_hwcm";
544-
// depthwise_conv_2d_nhwc_hwcm_q
545-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1 + d5, d2 + d6, d3)>
546-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d3, d4)>
547-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> ()>
548-
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4)>
549482
if (isaDepthwiseConv2DNhwcHwcmQOp(genericOp))
550483
return "linalg.depthwise_conv_2d_nhwc_hwcm_q";
551484
return "";
552485
}
553486

554487
static std::string inferBasedOnRank8ConvIteratorTypes(GenericOp genericOp) {
555-
ArrayAttr indexingMaps = genericOp.getIndexingMaps();
556-
if (indexingMaps.size() < 3) return "";
557-
unsigned iIndex = 0, fIndex = 1, oIndex = indexingMaps.size() - 1;
558-
// conv_2d_ngchw_fgchw
559-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
560-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d2, d1, d5, d6, d7)>
561-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
562488
if (isaConv2DNgchwFgchwOp(genericOp))
563489
return "linalg.conv_2d_ngchw_fgchw";
564-
// conv_2d_ngchw_gfchw
565-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
566-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
567-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
568490
if (isaConv2DNgchwGfchwOp(genericOp))
569491
return "linalg.conv_2d_ngchw_gfchw";
570-
// conv_2d_ngchw_gfchw_q
571-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d5, d3 + d6, d4 + d7)>
572-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d1, d2, d5, d6, d7)>
573-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> ()>
574-
// #map3 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
575492
if (isaConv2DNgchwGfchwQOp(genericOp))
576493
return "linalg.conv_2d_ngchw_gfchw_q";
577-
// conv_2d_nhwgc_gfhwc
578-
// #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1 + d5, d2 + d6, d3, d7)>
579-
// #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d3, d4, d5, d6, d7)>
580-
// #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)>
581494
if (isaConv2DNhwgcGfhwcOp(genericOp))
582495
return "linalg.conv_2d_nhwgc_gfhwc";
583496
// depthwise_conv_3d_ncdhw_cdhw

0 commit comments

Comments
 (0)