Skip to content

Commit b011666

Browse files
committed
Address comments
1 parent 341b1fc commit b011666

File tree

1 file changed

+27
-23
lines changed

1 file changed

+27
-23
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 27 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -452,37 +452,41 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
452452
StringAttr kWarp = str_attr("warp");
453453
StringAttr kBlock = str_attr("block");
454454

455-
LinearLayout comp = srcLayout.invertAndCompose(dstLayout);
455+
LinearLayout comp = dstLayout.invertAndCompose(srcLayout);
456456
std::optional<LinearLayout> conversion = comp.divideRight(
457457
LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) *
458458
LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock));
459459
assert(conversion && "Expecting valid conversion");
460460
// Expected conversion is:
461461
// - register=1 -> (0, 1)
462-
// register=2 -> (0, 2)
463-
// register=4 -> (0, 4)
464-
// register=8 -> (0, 8)
465-
// register=N -> (N, 0)
466-
// ...
467-
// - lane=1 -> (1, 0)
468-
// lane=2 -> (2, 0)
469-
// lane=4 -> (4, 0)
470-
// lane=8 -> (8, 0)
471-
// where out dims are: [register (size 2*N), lane (size 16)]
472-
std::vector<std::vector<int32_t>> registerBases{
473-
{0, 1}, {0, 2}, {0, 4}, {0, 8}};
474-
{
475-
// Populate register bases for N > 8.
476-
std::vector<int32_t> base(2);
477-
for (int32_t i = 16, n = conversion->getInDimSize(kRegister); i < n;
478-
i *= 2) {
479-
base.front() = i;
480-
registerBases.push_back(base);
462+
// ...
463+
// - register=i -> (0, 2**(i-1))
464+
// ...
465+
// - register=N -> (0, 2**(N-1))
466+
// - lane=1 -> (1, 0)
467+
// ...
468+
// - lane=j -> (2**(j-1), 0)
469+
// ...
470+
// - lane=M -> (2**(M-1), 0)
471+
// where out dims are: [register (size 2**N), lane (size 2**M)]
472+
//
473+
// With N = M.
474+
const auto buildBasis = [&](int32_t size, std::size_t index) {
475+
std::vector<std::vector<int32_t>> basis;
476+
std::vector<int32_t> curr(2);
477+
for (int32_t i = 1; i < size; i *= 2) {
478+
curr[index] = i;
479+
basis.push_back(curr);
481480
}
482-
}
481+
return basis;
482+
};
483+
484+
constexpr std::size_t laneIndex = 0;
485+
constexpr std::size_t registerIndex = 1;
486+
int32_t size = conversion->getInDimSize(kLane);
483487
std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
484-
bases{{{kRegister, std::move(registerBases)},
485-
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}}};
488+
bases{{{kRegister, buildBasis(size, registerIndex)},
489+
{kLane, buildBasis(size, laneIndex)}}};
486490
std::array<StringAttr, 2> outDimNames{kRegister, kLane};
487491
return conversion == LinearLayout(bases, outDimNames);
488492
}

0 commit comments

Comments
 (0)