@@ -457,12 +457,21 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
457457 LinearLayout::identity1D (comp.getInDimSize (kWarp ), kWarp , kWarp ) *
458458 LinearLayout::identity1D (comp.getInDimSize (kBlock ), kBlock , kBlock ));
459459 assert (conversion && " Expecting valid conversion" );
460- LinearLayout id =
461- LinearLayout::identity1D (conversion->getInDimSize (kRegister ), kRegister ,
462- kRegister ) *
463- LinearLayout::identity1D (conversion->getInDimSize (kLane ), kLane , kLane );
464- // Composing the transposition with itself should give us the identity.
465- return id == conversion->compose (*conversion);
460+ // Expected conversion is:
461+ // - register=1 -> (0, 1)
462+ // register=2 -> (0, 2)
463+ // register=4 -> (0, 4)
464+ // register=8 -> (0, 8)
465+ // - lane=1 -> (1, 0)
466+ // lane=2 -> (2, 0)
467+ // lane=4 -> (4, 0)
468+ // lane=8 -> (8, 0)
469+ // where out dims are: [register (size 16), lane (size 16)]
470+ std::array<std::pair<StringAttr, std::vector<std::vector<int32_t >>>, 2 >
471+ bases{{{kRegister , {{0 , 1 }, {0 , 2 }, {0 , 4 }, {0 , 8 }}},
472+ {kLane , {{1 , 0 }, {2 , 0 }, {4 , 0 }, {8 , 0 }}}}};
473+ std::array<StringAttr, 2 > outDimNames{kRegister , kLane };
474+ return conversion == LinearLayout (bases, outDimNames);
466475 }
467476
468477 LogicalResult
0 commit comments