@@ -6588,15 +6588,20 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
65886588 /* use_implicit_shape=*/ true );
65896589 }
65906590 if (src.layout_rank () >= dst.layout_rank () && !src.offsets ()[0 ].has_value () &&
6591- !src.offsets ()[1 ].has_value () && src. tilesPerVreg (target_shape) == 1 ) {
6591+ !src.offsets ()[1 ].has_value ()) {
65926592 // A fully replicated value is always easy to relayout
6593- // It would be nice to be able to assert this here, but given replicated
6594- // values our rules can introduce equivalent expressions.
6595- // assert all(t is src_tiles_list[0] for t in src_tiles_list)
65966593 xla::Array<Value> dst_tiles (
6597- /* sizes=*/ dst.tileArrayShape (vty.getShape (), target_shape),
6598- /* value=*/ src_tiles.data ()[0 ]);
6599- return assemble_with_mask_check (dst_tiles);
6594+ dst.tileArrayImplicitShape (vty.getShape (), target_shape));
6595+ SmallVector<int64_t > idxs;
6596+ dst_tiles.Each ([&](const absl::Span<const int64_t > src_idx, Value *vreg) {
6597+ idxs.assign (src_idx.begin (), src_idx.end ());
6598+ dst.eraseImplicit (idxs);
6599+ src.insertImplicit <int64_t >(idxs, 0 );
6600+ *(idxs.end () - 2 ) = 0 ;
6601+ *(idxs.end () - 1 ) = 0 ;
6602+ *vreg = src_tiles (idxs);
6603+ });
6604+ return assemble_with_mask_check (dst_tiles, /* use_implicit_shape=*/ true );
66006605 }
66016606
66026607 // Consider (1,128),-2 -> (8,128). In this case we can change the implicit
0 commit comments