Skip to content

Commit 1011687

Browse files
tlongeri authored and Google-ML-Automation committed
[Mosaic:TPU] Lift offset restrictions on single-row (1, 128) -> (8, 128) 32-bit replicated retiling
PiperOrigin-RevId: 702966495
1 parent f160df0 commit 1011687

File tree

1 file changed

+31
-17
lines changed

1 file changed

+31
-17
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 31 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -6116,35 +6116,49 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
61166116
}
61176117
const int packing = src.packing();
61186118
const int8_t bitwidth = src.bitwidth();
6119-
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
6120-
src.implicit_dim());
6121-
if (!dst.isValid(target_shape)) {
6122-
return emitError(loc, "Not implemented: invalid offsets in tiling target");
6123-
}
6124-
auto dst_tiles_shape =
6125-
dst.tileArrayImplicitShape(vty.getShape(), target_shape);
61266119
// Handle retiling from (1, 128) to (8, 128) for 32-bit data with replicating
61276120
// sublanes.
61286121
if (try_replicate_rows && packing == 1 &&
61296122
*(vregs.dimensions().end() - 2) == 1 &&
6130-
src.offsets() == LayoutOffsets{0, 0} &&
61316123
src.tiling() == std::array<int64_t, 2>{1, ctx.target_shape[1]} &&
61326124
dst_tiling == ctx.target_shape) {
6133-
xla::Array<Value> retiled(dst_tiles_shape);
6125+
DCHECK_EQ(src.offsets()[0].value_or(0), 0);
6126+
const LayoutOffset dst_minor_offset =
6127+
src.offsets()[1] ? LayoutOffset(*src.offsets()[1] % target_shape[1])
6128+
: std::nullopt;
6129+
const VectorLayout dst(bitwidth, {std::nullopt, dst_minor_offset},
6130+
dst_tiling, src.implicit_dim());
6131+
xla::Array<Value> retiled(
6132+
dst.tileArrayImplicitShape(vty.getShape(), target_shape));
61346133
retiled.Each([&](absl::Span<const int64_t> idx, Value *tile) {
61356134
SmallVector<int64_t> src_idx(idx.begin(), idx.end());
61366135
*(src_idx.end() - 2) *= target_shape[0];
6137-
*(src_idx.end() - 1) /= target_shape[0];
6138-
const int64_t src_sl_idx = *(idx.end() - 1) % target_shape[0];
6139-
CHECK_EQ(src.getImplicitTiledDims(vty.getShape(), 1)[0], 1);
6140-
*tile =
6141-
broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
6136+
if (!src.offsets()[1].has_value()) {
6137+
// With (1, 128) tiling each vreg holds values from a single row. This
6138+
// means that if the columns are replicated, then the whole vreg is
6139+
// already replicated.
6140+
*(src_idx.end() - 1) = 0;
6141+
*tile = vregs(src_idx);
6142+
} else {
6143+
// The column (in units of sublanes) of the sublane we want:
6144+
const int64_t sublane_column =
6145+
*(src_idx.end() - 1) + *src.offsets()[1] / target_shape[1];
6146+
*(src_idx.end() - 1) = sublane_column / target_shape[0];
6147+
const int64_t src_sl_idx = sublane_column % target_shape[0];
6148+
*tile =
6149+
broadcastSublane(builder, vregs(src_idx), src_sl_idx, target_shape);
6150+
}
61426151
});
6143-
// We have successfully replicated sublanes.
6144-
dst = VectorLayout(bitwidth, {std::nullopt, dst.offsets()[1]}, dst_tiling,
6145-
dst.implicit_dim());
6152+
// We have successfully replicated sublanes
61466153
return std::pair(dst, std::move(retiled));
61476154
}
6155+
VectorLayout dst(src.bitwidth(), src.offsets(), dst_tiling,
6156+
src.implicit_dim());
6157+
if (!dst.isValid(target_shape)) {
6158+
return emitError(loc, "Not implemented: invalid offsets in tiling target");
6159+
}
6160+
auto dst_tiles_shape =
6161+
dst.tileArrayImplicitShape(vty.getShape(), target_shape);
61486162
// (8,128) -> (8 * packing,128) tiling change for packed type.
61496163
if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
61506164
dst_tiling == std::array<int64_t, 2>{ctx.target_shape[0] * dst.packing(),

0 commit comments

Comments
 (0)