@@ -4723,6 +4723,11 @@ FailureOr<xla::Array<Value>> disassemble(
47234723 TPU_ASSERT_LOC (val.getLoc (), def_layout.has_value ());
47244724 TPU_ASSERT_LOC (val.getLoc (),
47254725 def_layout->generalizes (layout, vty.getShape (), target_shape));
4726+ auto layout_product =
4727+ xla::Product (layout.tileArrayShape (vty.getShape (), target_shape));
4728+ auto def_layout_product =
4729+ xla::Product (def_layout->tileArrayShape (vty.getShape (), target_shape));
4730+ TPU_ASSERT_LOC (val.getLoc (), layout_product == def_layout_product);
47264731 // TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of
47274732 // having `tileArrayShape` and `tileArrayImplicitShape`.
47284733 SmallVector<int64_t > layout_shape =
@@ -6324,11 +6329,50 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
63246329 if (src.generalizes (dst, vty.getShape (), target_shape)) {
63256330 // A value with a replicated offset might use fewer vregs than a value with
63266331 // a non-zero offset.
6327- if (xla::Product (src.tileArrayShape (vty.getShape (), target_shape)) !=
6328- xla::Product (dst.tileArrayShape (vty.getShape (), target_shape))) {
6329- return emitError (v.getLoc (),
6330- " Not implemented: source layout is more general, but "
6331- " vreg count changes" );
6332+ auto src_product =
6333+ xla::Product (src.tileArrayShape (vty.getShape (), target_shape));
6334+ auto dst_product =
6335+ xla::Product (dst.tileArrayShape (vty.getShape (), target_shape));
6336+ if (src_product != dst_product) {
6337+ TPU_ASSERT_LOC (v.getLoc (), dst_product > src_product);
6338+ auto src_offsets = src.offsets ();
6339+
6340+ TPU_ASSERT_LOC (v.getLoc (), src_offsets != dst.offsets ());
6341+ TPU_ASSERT_LOC (v.getLoc (), src.bitwidth () == dst.bitwidth ());
6342+
6343+ if (src.implicit_dim () != dst.implicit_dim ()) {
6344+ return emitError (v.getLoc (),
6345+ " Not implemented: Source layout is more general, but "
6346+ " vreg count changes and implicit dims are mismatched" );
6347+ }
6348+
6349+ if (src.tiling () != dst.tiling ()) {
6350+ return emitError (v.getLoc (),
6351+ " Not implemented: Source layout is more general, but "
6352+ " vreg count changes and tiling are mismatched" );
6353+ }
6354+
6355+ // This case is moving from a replicated to a non replicated layout.
6356+ // As such, we need to make a new destination shape that is the
6357+ // materialization of the src shape with replication.
6358+ FAILUREOR_ASSIGN_OR_RETURN (auto src_vregs,
6359+ disassemble (builder, src, v, target_shape,
6360+ /* use_implicit_shape=*/ true ));
6361+ auto dst_vregs_shape = dst.tileArrayShape (vty.getShape (), target_shape);
6362+ xla::Array<Value> dst_vregs (dst_vregs_shape);
6363+ dst_vregs.Each ([&](const absl::Span<const int64_t > idx, Value *vreg) {
6364+ SmallVector<int64_t > local_idx (idx.begin (), idx.end ());
6365+ if (!src_offsets[0 ].has_value ()) {
6366+ local_idx[local_idx.size () - 2 ] = 0 ;
6367+ }
6368+ if (!src_offsets[1 ].has_value ()) {
6369+ local_idx[local_idx.size () - 1 ] = 0 ;
6370+ }
6371+ *vreg = src_vregs (local_idx);
6372+ });
6373+ return assemble (builder, vty, dst, std::move (dst_vregs), target_shape,
6374+ /* use_implicit_shape=*/ true )
6375+ .getResult ();
63326376 }
63336377 src_tiles.Reshape (dst.tileArrayImplicitShape (vty.getShape (), target_shape));
63346378 return assemble (builder, vty, dst, std::move (src_tiles), target_shape,
@@ -6411,8 +6455,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
64116455 if (vector_operand == nullptr ) {
64126456 continue ;
64136457 }
6414- auto vty = vector_operand.getType ();
6415-
64166458 // The operand should always be an Operation (and not a BlockArgument)
64176459 // since we expect the FuncOp to have only memrefs and semaphores as
64186460 // arguments.
@@ -6427,9 +6469,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
64276469 getOutLayouts (*def_op, ctx.target_shape ));
64286470 const Layout lo = def_layouts[res_idx];
64296471 TPU_ASSERT_OP (lo.has_value ());
6430- if (lo->generalizes (*li, vty.getShape (), ctx.target_shape )) {
6431- continue ;
6432- }
64336472 OpBuilder builder (&op);
64346473 FAILUREOR_ASSIGN_OR_RETURN (
64356474 Value new_v, relayout (ctx, builder, vector_operand, /* src=*/ *lo,
0 commit comments