Skip to content

Commit f6649b9

Browse files
committed
Fix layout check
1 parent 5b59ae4 commit f6649b9

File tree

1 file changed

+15
-6
lines changed

1 file changed

+15
-6
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -457,12 +457,21 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
457457
LinearLayout::identity1D(comp.getInDimSize(kWarp), kWarp, kWarp) *
458458
LinearLayout::identity1D(comp.getInDimSize(kBlock), kBlock, kBlock));
459459
assert(conversion && "Expecting valid conversion");
460-
LinearLayout id =
461-
LinearLayout::identity1D(conversion->getInDimSize(kRegister), kRegister,
462-
kRegister) *
463-
LinearLayout::identity1D(conversion->getInDimSize(kLane), kLane, kLane);
464-
// Composing the transposition with itself should give us the identity.
465-
return id == conversion->compose(*conversion);
460+
// Expected conversion is:
461+
// - register=1 -> (0, 1)
462+
// register=2 -> (0, 2)
463+
// register=4 -> (0, 4)
464+
// register=8 -> (0, 8)
465+
// - lane=1 -> (1, 0)
466+
// lane=2 -> (2, 0)
467+
// lane=4 -> (4, 0)
468+
// lane=8 -> (8, 0)
469+
// where out dims are: [register (size 16), lane (size 16)]
470+
std::array<std::pair<StringAttr, std::vector<std::vector<int32_t>>>, 2>
471+
bases{{{kRegister, {{0, 1}, {0, 2}, {0, 4}, {0, 8}}},
472+
{kLane, {{1, 0}, {2, 0}, {4, 0}, {8, 0}}}}};
473+
std::array<StringAttr, 2> outDimNames{kRegister, kLane};
474+
return conversion == LinearLayout(bases, outDimNames);
466475
}
467476

468477
LogicalResult

0 commit comments

Comments
 (0)