@@ -461,30 +461,49 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
461461 assert (conversion && " Expecting valid conversion" );
462462 // Expected conversion is:
463463 // - register=1 -> (0, 1)
464- // register=2 -> (0, 2)
465- // register=4 -> (0, 4)
466- // register=8 -> (0, 8)
467- // register=N -> (N, 0)
468- // ...
469- // - lane=1 -> (1, 0)
470- // lane=2 -> (2, 0)
471- // lane=4 -> (4, 0)
472- // lane=8 -> (8, 0)
473- // where out dims are: [register (size 2*N), lane (size 16)]
474- std::vector<std::vector<int32_t >> registerBases{
475- {0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }};
464+ // ...
465+ // - register=2**i -> (0, 2**i)
466+ // ...
467+ // - register=2**M -> (0, 2**M)
468+ // ...
469+ // - register=2**k -> (2**k, 0)
470+ // ...
471+ // - register=2**N -> (2**N, 0)
472+ // - lane=1 -> (1, 0)
473+ // ...
474+ // - lane=2**j -> (2**j, 0)
475+ // ...
476+ // lane=2**M -> (2**M, 0)
477+ // where out dims are: [register (size 2**(N + 1)), lane (size 2**(M + 1))]
478+ //
479+ // With N >= M.
480+ const auto buildBasis = [&](int32_t size, std::size_t index) {
481+ std::vector<std::vector<int32_t >> basis;
482+ std::vector<int32_t > curr (2 );
483+ for (int32_t i = 1 ; i < size; i *= 2 ) {
484+ curr[index] = i;
485+ basis.push_back (curr);
486+ }
487+ return basis;
488+ };
489+ constexpr std::size_t laneIndex = 0 ;
490+ constexpr std::size_t registerIndex = 1 ;
491+ int32_t laneSize = conversion->getInDimSize (kLane );
492+ std::vector<std::vector<int32_t >> registerBases =
493+ buildBasis (laneSize, registerIndex);
476494 {
477- // Populate register bases for N > 8 .
495+ // Populate register bases for N > M .
478496 std::vector<int32_t > base (2 );
479- for (int32_t i = 16 , n = conversion->getInDimSize (kRegister ); i < n;
480- i *= 2 ) {
481- base.front () = i;
497+ for (int32_t i = laneSize,
498+ registerSize = conversion->getInDimSize (kRegister );
499+ i < registerSize; i *= 2 ) {
500+ base[laneIndex] = i;
482501 registerBases.push_back (base);
483502 }
484503 }
485504 std::array<std::pair<StringAttr, std::vector<std::vector<int32_t >>>, 2 >
486505 bases{{{kRegister , std::move (registerBases)},
487- {kLane , {{ 1 , 0 }, { 2 , 0 }, { 4 , 0 }, { 8 , 0 }} }}};
506+ {kLane , buildBasis (laneSize, laneIndex) }}};
488507 std::array<StringAttr, 2 > outDimNames{kRegister , kLane };
489508 return conversion == LinearLayout (bases, outDimNames);
490509 }
@@ -853,18 +872,7 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
853872 llvm::transform (
854873 outVals, std::begin (outVals),
855874 [&](Value val) -> Value { return inttoptr (ptrTy, val); });
856- })As a follow up to #2266 , extend work in #2531 to detect more complex broadcast shuffles.
857-
858- Cases with more than 1 warp in the "sliced" dimension are problematic here, e.g.:
859-
860- ```mlir
861- #blocked = #triton_gpu.blocked <{sizePerThread = [16 , 1 , 1 , 1 , 1 , 1 , 1 ], threadsPerWarp = [1 , 1 , 1 , 16 , 1 , 1 , 1 ], warpsPerCTA = [1 , 1 , 16 , 1 , 1 , 1 , 1 ], order = [3 , 4 , 5 , 6 , 0 , 1 , 2 ]}>
862- #blocked1 = #triton_gpu.blocked <{sizePerThread = [1 , 1 , 1 , 16 , 1 ], threadsPerWarp = [16 , 1 , 1 , 1 , 1 ], warpsPerCTA = [1 , 1 , 16 , 1 , 1 ], order = [3 , 4 , 0 , 1 , 2 ]}>
863- // ...
864- triton_gpu.convert_layout %arg : tensor<16x1x16x16x1xf32, #triton_gpu.slice <{dim = 4 , parent = #triton_gpu.slice <{dim = 6 , parent = #blocked}>}>> -> tensor<16x1x16x16x1xf32, #blocked1>
865- ```
866-
867- Is lowered to a shuffle via
875+ })
868876 .Default ([](auto ) { llvm_unreachable (" Unsupported type" ); });
869877
870878 Value result = packLLElements (loc, getTypeConverter (), outVals, rewriter,
@@ -967,9 +975,6 @@ Is lowered to a shuffle via
967975 // TODO(jlebar): Implement me.
968976 return failure ();
969977 }
970-
971- private:
972- const triton::intel::TargetInfo &targetInfo;
973978};
974979
975980} // namespace
0 commit comments