Skip to content

Commit b0ea488

Browse files
committed
Address comments
1 parent 1c2f1a8 commit b0ea488

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -452,24 +452,41 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
452452
StringAttr kWarp = str_attr("warp");
453453
StringAttr kBlock = str_attr("block");
454454

455-
LinearLayout comp = srcLayout.invertAndCompose(dstLayout);
455+
LinearLayout comp = dstLayout.invertAndCompose(srcLayout);
456456
std::optional<LinearLayout> conversion = comp.divideRight(
457457
LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) *
458458
LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock));
459459
assert(conversion && "Expecting valid conversion");
460460
// Expected conversion is:
461461
// - register=1 -> (0, 1)
462-
// register=2 -> (0, 2)
463-
// register=4 -> (0, 4)
464-
// register=8 -> (0, 8)
465-
// - lane=1 -> (1, 0)
466-
// lane=2 -> (2, 0)
467-
// lane=4 -> (4, 0)
468-
// lane=8 -> (8, 0)
469-
// where out dims are: [register (size 16), lane (size 16)]
462+
// ...
463+
// - register=i -> (0, 2**(i-1))
464+
// ...
465+
// - register=N -> (0, 2**(N-1))
466+
// - lane=1 -> (0, 1)
467+
// ...
468+
// - lane=j -> (2**(j-1), 0)
469+
// ...
470+
// lane=M -> (2**(M-1), 0)
471+
// where out dims are: [register (size 2**(N-1)), lane (size 2**(M-1))]
472+
//
473+
// With N = M.
474+
const auto buildBasis = [&](int32_t size, std::size_t index) {
475+
std::vector<std::vector<int32_t>> basis;
476+
std::vector<int32_t> curr(2);
477+
for (int32_t i = 1; i < size; i *= 2) {
478+
curr[index] = i;
479+
basis.push_back(curr);
480+
}
481+
return basis;
482+
};
483+
484+
constexpr std::size_t laneIndex = 0;
485+
constexpr std::size_t registerIndex = 1;
486+
int32_t size = conversion->getInDimSize(kLane);
470487
std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
471-
bases{{{kRegister, {{0, 1}, {0, 2}, {0, 4}, {0, 8}}},
472-
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}}};
488+
bases{{{kRegister, buildBasis(size, registerIndex)},
489+
{kLane, buildBasis(size, laneIndex)}}};
473490
std::array<StringAttr, 2> outDimNames{kRegister, kLane};
474491
return conversion == LinearLayout(bases, outDimNames);
475492
}

0 commit comments

Comments
 (0)