Skip to content

Commit 21f6b40

Browse files
WindQAQGoogle-ML-Automation
authored andcommitted
[Mosaic] Pad trailing transposes chunks with zeros.
PiperOrigin-RevId: 705310340
1 parent 39e4f7f commit 21f6b40

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4623,7 +4623,9 @@ LogicalResult vector_transpose_rule(RewriteContext &ctx, Operation &op,
46234623
incremented_batch_idx.end());
46244624
src_slice_ends.append({(src_row + 1) * vregs_per_tile, src_col_end});
46254625
xla::Array<Value> src_tile_vregs = src_vregs.Slice(
4626-
src_slice_starts, src_slice_ends, /*out_of_bounds_ok=*/true);
4626+
src_slice_starts, src_slice_ends,
4627+
builder.create<arith::ConstantOp>(
4628+
op.getLoc(), builder.getZeroAttr(src_vregs.begin()->getType())));
46274629
// Drop leading singleton (batch) dimensions to have a shape that conforms
46284630
// with the vreg array shape specified by layout_in, as expected by assemble
46294631
src_tile_vregs.Reshape(

0 commit comments

Comments
 (0)