@@ -452,24 +452,41 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
452452 StringAttr kWarp = str_attr (" warp" );
453453 StringAttr kBlock = str_attr (" block" );
454454
455- LinearLayout comp = srcLayout .invertAndCompose (dstLayout );
455+ LinearLayout comp = dstLayout .invertAndCompose (srcLayout );
456456 std::optional<LinearLayout> conversion = comp.divideRight (
457457 LinearLayout::identity1D (comp.getInDimSize (kWarp ), kWarp , kWarp ) *
458458 LinearLayout::identity1D (comp.getInDimSize (kBlock ), kBlock , kBlock ));
459459 assert (conversion && " Expecting valid conversion" );
460460 // Expected conversion is:
461461 // - register=1 -> (0, 1)
462- // register=2 -> (0, 2)
463- // register=4 -> (0, 4)
464- // register=8 -> (0, 8)
465- // - lane=1 -> (1, 0)
466- // lane=2 -> (2, 0)
467- // lane=4 -> (4, 0)
468- // lane=8 -> (8, 0)
469- // where out dims are: [register (size 16), lane (size 16)]
462+ // ...
463+ // - register=i -> (0, 2**(i-1))
464+ // ...
465+ // - register=N -> (0, 2**(N-1))
466+ // - lane=1 -> (0, 1)
467+ // ...
468+ // - lane=j -> (2**(j-1), 0)
469+ // ...
470+ // lane=M -> (2**(M-1), 0)
471+ // where out dims are: [register (size 2**(N-1)), lane (size 2**(M-1))]
472+ //
473+ // With N = M.
474+ const auto buildBasis = [&](int32_t size, std::size_t index) {
475+ std::vector<std::vector<int32_t >> basis;
476+ std::vector<int32_t > curr (2 );
477+ for (int32_t i = 1 ; i < size; i *= 2 ) {
478+ curr[index] = i;
479+ basis.push_back (curr);
480+ }
481+ return basis;
482+ };
483+
484+ constexpr std::size_t laneIndex = 0 ;
485+ constexpr std::size_t registerIndex = 1 ;
486+ int32_t size = conversion->getInDimSize (kLane );
470487 std::array<std::pair<StringAttr, std::vector<std::vector<int32_t >>>, 2 >
471- bases{{{kRegister , {{ 0 , 1 }, { 0 , 2 }, { 0 , 4 }, { 0 , 8 }} },
472- {kLane , {{ 1 , 0 }, { 2 , 0 }, { 4 , 0 }, { 8 , 0 }} }}};
488+ bases{{{kRegister , buildBasis (size, registerIndex) },
489+ {kLane , buildBasis (size, laneIndex) }}};
473490 std::array<StringAttr, 2 > outDimNames{kRegister , kLane };
474491 return conversion == LinearLayout (bases, outDimNames);
475492 }
0 commit comments