Skip to content

Commit a8464ce

Browse files
[Mosaic][TPU] Omit short-circuiting of relayout (we should always relayout!) and implement the product-mismatch case, where we relayout from a replicated offset to a non-replicated offset and the number of vregs changes.
PiperOrigin-RevId: 696557463
1 parent 89f411a commit a8464ce

File tree

1 file changed

+49
-10
lines changed

1 file changed

+49
-10
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4723,6 +4723,11 @@ FailureOr<xla::Array<Value>> disassemble(
47234723
TPU_ASSERT_LOC(val.getLoc(), def_layout.has_value());
47244724
TPU_ASSERT_LOC(val.getLoc(),
47254725
def_layout->generalizes(layout, vty.getShape(), target_shape));
4726+
auto layout_product =
4727+
xla::Product(layout.tileArrayShape(vty.getShape(), target_shape));
4728+
auto def_layout_product =
4729+
xla::Product(def_layout->tileArrayShape(vty.getShape(), target_shape));
4730+
TPU_ASSERT_LOC(val.getLoc(), layout_product == def_layout_product);
47264731
// TODO(tlongeri): Maybe just add a parameter to tileArrayShape instead of
47274732
// having `tileArrayShape` and `tileArrayImplicitShape`.
47284733
SmallVector<int64_t> layout_shape =
@@ -6324,11 +6329,50 @@ FailureOr<TypedValue<VectorType>> relayout(RewriteContext &ctx,
63246329
if (src.generalizes(dst, vty.getShape(), target_shape)) {
63256330
// A value with a replicated offset might use fewer vregs than a value with
63266331
// a non-zero offset.
6327-
if (xla::Product(src.tileArrayShape(vty.getShape(), target_shape)) !=
6328-
xla::Product(dst.tileArrayShape(vty.getShape(), target_shape))) {
6329-
return emitError(v.getLoc(),
6330-
"Not implemented: source layout is more general, but "
6331-
"vreg count changes");
6332+
auto src_product =
6333+
xla::Product(src.tileArrayShape(vty.getShape(), target_shape));
6334+
auto dst_product =
6335+
xla::Product(dst.tileArrayShape(vty.getShape(), target_shape));
6336+
if (src_product != dst_product) {
6337+
TPU_ASSERT_LOC(v.getLoc(), dst_product > src_product);
6338+
auto src_offsets = src.offsets();
6339+
6340+
TPU_ASSERT_LOC(v.getLoc(), src_offsets != dst.offsets());
6341+
TPU_ASSERT_LOC(v.getLoc(), src.bitwidth() == dst.bitwidth());
6342+
6343+
if (src.implicit_dim() != dst.implicit_dim()) {
6344+
return emitError(v.getLoc(),
6345+
"Not implemented: Source layout is more general, but "
6346+
"vreg count changes and implicit dims are mismatched");
6347+
}
6348+
6349+
if (src.tiling() != dst.tiling()) {
6350+
return emitError(v.getLoc(),
6351+
"Not implemented: Source layout is more general, but "
6352+
"vreg count changes and tiling are mismatched");
6353+
}
6354+
6355+
// This case is moving from a replicated to a non-replicated layout.
6356+
// As such, we need to make a new destination shape that is the
6357+
// materialization of the src shape with replication.
6358+
FAILUREOR_ASSIGN_OR_RETURN(auto src_vregs,
6359+
disassemble(builder, src, v, target_shape,
6360+
/*use_implicit_shape=*/true));
6361+
auto dst_vregs_shape = dst.tileArrayShape(vty.getShape(), target_shape);
6362+
xla::Array<Value> dst_vregs(dst_vregs_shape);
6363+
dst_vregs.Each([&](const absl::Span<const int64_t> idx, Value *vreg) {
6364+
SmallVector<int64_t> local_idx(idx.begin(), idx.end());
6365+
if (!src_offsets[0].has_value()) {
6366+
local_idx[local_idx.size() - 2] = 0;
6367+
}
6368+
if (!src_offsets[1].has_value()) {
6369+
local_idx[local_idx.size() - 1] = 0;
6370+
}
6371+
*vreg = src_vregs(local_idx);
6372+
});
6373+
return assemble(builder, vty, dst, std::move(dst_vregs), target_shape,
6374+
/*use_implicit_shape=*/true)
6375+
.getResult();
63326376
}
63336377
src_tiles.Reshape(dst.tileArrayImplicitShape(vty.getShape(), target_shape));
63346378
return assemble(builder, vty, dst, std::move(src_tiles), target_shape,
@@ -6411,8 +6455,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
64116455
if (vector_operand == nullptr) {
64126456
continue;
64136457
}
6414-
auto vty = vector_operand.getType();
6415-
64166458
// The operand should always be an Operation (and not a BlockArgument)
64176459
// since we expect the FuncOp to have only memrefs and semaphores as
64186460
// arguments.
@@ -6427,9 +6469,6 @@ LogicalResult applyLayoutOp(RewriteContext &ctx, Operation &op) {
64276469
getOutLayouts(*def_op, ctx.target_shape));
64286470
const Layout lo = def_layouts[res_idx];
64296471
TPU_ASSERT_OP(lo.has_value());
6430-
if (lo->generalizes(*li, vty.getShape(), ctx.target_shape)) {
6431-
continue;
6432-
}
64336472
OpBuilder builder(&op);
64346473
FAILUREOR_ASSIGN_OR_RETURN(
64356474
Value new_v, relayout(ctx, builder, vector_operand, /*src=*/*lo,

0 commit comments

Comments
 (0)