Commit f899d51

bythew3i authored and Google-ML-Automation committed
[Mosaic TPU] Fold sublane offset to indices when storing to untiled ref.
This optimization avoids unnecessary retiling when storing to untiled ref but adds at most one extra store op for sublane offset (since sublane offset is limited to < VregSlice[0]).

PiperOrigin-RevId: 698896373
1 parent f3e7e68 commit f899d51

1 file changed: +5 -5 lines changed


jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 5 additions & 5 deletions
@@ -1640,14 +1640,14 @@ class VectorLayoutInferer {
         // Since it is untiled, we can store to any arbitrary address which
         // means the sublane offset can be any value and we can fold it to
         // 2nd minor index.
-        // TODO(jevinjiang): We can fold the sublane offset into the 2nd minor
-        // index. But we need to handle negative index in lower-to-llo. For
-        // now, we just force the sublane offset to be 0.
+        auto prev_store_layout = getLayout(op.getValueToStore());
+        TPU_CHECK_OP(prev_store_layout.has_value(), "missing vector layout");
+        offsets[0] = prev_store_layout->offsets()[0].value_or(0);
         if (offsets[1].value_or(0) >= tiling[1]) {
           offsets[1] = 0;
         }
-        store_layout = VectorLayout(bitwidth, {0, offsets[1]},
-                                    nativeTiling(bitwidth), ImplicitDim::kNone);
+        store_layout = VectorLayout(bitwidth, offsets, nativeTiling(bitwidth),
+                                    ImplicitDim::kNone);
       } else {
         store_layout = VectorLayout(bitwidth, offsets, {tiling[0], tiling[1]},
                                     ImplicitDim::kNone);
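
The offset selection in the untiled branch of this change can be sketched in isolation as follows. This is a simplified model, not the code in infer_vector_layout.cc: the `Offsets` alias and the `UntiledStoreOffsets` function are hypothetical names introduced here, and `tile_lanes` stands in for `tiling[1]`. It assumes the incoming `offsets` were computed earlier in the function, which the diff does not show.

```cpp
#include <array>
#include <cstdint>
#include <optional>

// Hypothetical stand-in for a layout's (sublane, lane) offsets. An empty
// optional models an offset that has no concrete value.
using Offsets = std::array<std::optional<int64_t>, 2>;

// Simplified model of the new untiled-store branch (names are assumptions):
//  - the sublane offset is taken from the layout of the value being stored
//    instead of being forced to 0;
//  - the lane offset is reset to 0 when it does not fit within the tile.
Offsets UntiledStoreOffsets(Offsets offsets, const Offsets& value_offsets,
                            int64_t tile_lanes) {
  // Keep whatever sublane offset the stored value already has.
  offsets[0] = value_offsets[0].value_or(0);
  // A lane offset at or beyond the tile's lane extent falls back to 0.
  if (offsets[1].value_or(0) >= tile_lanes) {
    offsets[1] = 0;
  }
  return offsets;
}
```

Before this change, the equivalent of `offsets[0]` was always 0 in this branch, which is what could force a retiling of the stored value when its layout had a nonzero sublane offset.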
