Skip to content

Commit 23d5c10

Browse files
tlongeriGoogle-ML-Automation
authored andcommitted
[Mosaic:TPU] Fix fully replicated relayout
It was incorrect since batch dims are not replicated PiperOrigin-RevId: 703189919
1 parent 2a4a0e8 commit 23d5c10

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6588,15 +6588,20 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
65886588
/*use_implicit_shape=*/true);
65896589
}
65906590
if (src.layout_rank() >= dst.layout_rank() && !src.offsets()[0].has_value() &&
6591-
!src.offsets()[1].has_value() && src.tilesPerVreg(target_shape) == 1) {
6591+
!src.offsets()[1].has_value()) {
65926592
// A fully replicated value is always easy to relayout
6593-
// It would be nice to be able to assert this here, but given replicated
6594-
// values our rules can introduce equivalent expressions.
6595-
// assert all(t is src_tiles_list[0] for t in src_tiles_list)
65966593
xla::Array<Value> dst_tiles(
6597-
/*sizes=*/dst.tileArrayShape(vty.getShape(), target_shape),
6598-
/*value=*/src_tiles.data()[0]);
6599-
return assemble_with_mask_check(dst_tiles);
6594+
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
6595+
SmallVector<int64_t> idxs;
6596+
dst_tiles.Each([&](const absl::Span<const int64_t> src_idx, Value *vreg) {
6597+
idxs.assign(src_idx.begin(), src_idx.end());
6598+
dst.eraseImplicit(idxs);
6599+
src.insertImplicit<int64_t>(idxs, 0);
6600+
*(idxs.end() - 2) = 0;
6601+
*(idxs.end() - 1) = 0;
6602+
*vreg = src_tiles(idxs);
6603+
});
6604+
return assemble_with_mask_check(dst_tiles, /*use_implicit_shape=*/true);
66006605
}
66016606

66026607
// Consider (1,128),-2 -> (8,128). In this case we can change the implicit

0 commit comments

Comments
 (0)