@@ -452,37 +452,41 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
452452 StringAttr kWarp = str_attr (" warp" );
453453 StringAttr kBlock = str_attr (" block" );
454454
455- LinearLayout comp = srcLayout .invertAndCompose (dstLayout );
455+ LinearLayout comp = dstLayout .invertAndCompose (srcLayout );
456456 std::optional<LinearLayout> conversion = comp.divideRight (
457457 LinearLayout::identity1D (comp.getInDimSize (kWarp ), kWarp , kWarp ) *
458458 LinearLayout::identity1D (comp.getInDimSize (kBlock ), kBlock , kBlock ));
459459 assert (conversion && " Expecting valid conversion" );
460460 // Expected conversion is:
461461 // - register=1 -> (0, 1)
462- // register=2 -> (0, 2)
463- // register=4 -> (0, 4)
464- // register=8 -> (0, 8)
465- // register=N -> (N, 0)
466- // ...
467- // - lane=1 -> (1, 0)
468- // lane=2 -> (2, 0)
469- // lane=4 -> (4, 0)
470- // lane=8 -> (8, 0)
471- // where out dims are: [register (size 2*N), lane (size 16)]
472- std::vector<std::vector<int32_t >> registerBases{
473- {0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }};
474- {
475- // Populate register bases for N > 8.
476- std::vector<int32_t > base (2 );
477- for (int32_t i = 16 , n = conversion->getInDimSize (kRegister ); i < n;
478- i *= 2 ) {
479- base.front () = i;
480- registerBases.push_back (base);
462+ // ...
463+ // - register=i -> (0, 2**(i-1))
464+ // ...
465+ // - register=N -> (0, 2**(N-1))
466+ // - lane=1 -> (0, 1)
467+ // ...
468+ // - lane=j -> (2**(j-1), 0)
469+ // ...
470+ // lane=M -> (2**(M-1), 0)
471+ // where out dims are: [register (size 2**(N-1)), lane (size 2**(M-1))]
472+ //
473+ // With N = M.
474+ const auto buildBasis = [&](int32_t size, std::size_t index) {
475+ std::vector<std::vector<int32_t >> basis;
476+ std::vector<int32_t > curr (2 );
477+ for (int32_t i = 1 ; i < size; i *= 2 ) {
478+ curr[index] = i;
479+ basis.push_back (curr);
481480 }
482- }
481+ return basis;
482+ };
483+
484+ constexpr std::size_t laneIndex = 0 ;
485+ constexpr std::size_t registerIndex = 1 ;
486+ int32_t size = conversion->getInDimSize (kLane );
483487 std::array<std::pair<StringAttr, std::vector<std::vector<int32_t >>>, 2 >
484- bases{{{kRegister , std::move (registerBases )},
485- {kLane , {{ 1 , 0 }, { 2 , 0 }, { 4 , 0 }, { 8 , 0 }} }}};
488+ bases{{{kRegister , buildBasis (size, registerIndex )},
489+ {kLane , buildBasis (size, laneIndex) }}};
486490 std::array<StringAttr, 2 > outDimNames{kRegister , kLane };
487491 return conversion == LinearLayout (bases, outDimNames);
488492 }
0 commit comments