Skip to content

Commit 4911a39

Browse files
apaszke and Google-ML-Automation
authored and committed
[Mosaic TPU] Add support for the interleaved pack format to tpu.unpack_subelements
PiperOrigin-RevId: 707142562
1 parent 36b12d5 commit 4911a39

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -363,7 +363,8 @@ def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> {
363363
def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
364364
let arguments = (ins
365365
AnyVectorOfNonZeroRank:$source,
366-
I32Attr:$index
366+
I32Attr:$index,
367+
TPU_PackFormatEnum:$pack_format
367368
);
368369
let results = (outs AnyVectorOfNonZeroRank:$output);
369370
let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }];

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -874,7 +874,8 @@ FailureOr<xla::Array<Value>> ext_op_rule_impl(RewriteContext &ctx,
874874
int64_t vreg_part = *(input_vreg_idxs.end() - 2) % packing;
875875
*(input_vreg_idxs.end() - 2) /= packing;
876876
*v = builder.create<UnpackSubelementsOp>(
877-
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
877+
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part,
878+
tpu::PackFormat::kCompressed);
878879
});
879880
} else {
880881
if (layout_in.tiling() != layout_out.tiling()) {
@@ -890,7 +891,8 @@ FailureOr<xla::Array<Value>> ext_op_rule_impl(RewriteContext &ctx,
890891
input_vreg_idxs.back() /= packing;
891892
const int64_t vreg_part = idxs.back() % packing;
892893
*v = builder.create<UnpackSubelementsOp>(
893-
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part);
894+
op.getLoc(), res_vreg_ty, input_vregs(input_vreg_idxs), vreg_part,
895+
tpu::PackFormat::kCompressed);
894896
});
895897
}
896898
return output_vregs;
@@ -6265,7 +6267,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
62656267
src_idx[src_idx.size() - 1] /= packing;
62666268
for (int i = 0; i < packing; ++i) {
62676269
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
6268-
loc, vreg_x32, vregs(src_idx), vreg_part));
6270+
loc, vreg_x32, vregs(src_idx), vreg_part,
6271+
tpu::PackFormat::kCompressed));
62696272
if (src_idx[src_idx.size() - 2] <
62706273
vregs.dim(vregs.num_dimensions() - 2) - 1) {
62716274
++src_idx[src_idx.size() - 2];
@@ -6345,7 +6348,8 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
63456348
*(src_idx.end() - 1) /= packing;
63466349
for (int i = 0; i < packing; ++i) {
63476350
parts.push_back(builder.create<tpu::UnpackSubelementsOp>(
6348-
loc, vreg_x32, vregs(src_idx), vreg_part));
6351+
loc, vreg_x32, vregs(src_idx), vreg_part,
6352+
tpu::PackFormat::kCompressed));
63496353
if (*(src_idx.end() - 2) < *(vregs.dimensions().end() - 2) - 1) {
63506354
++*(src_idx.end() - 2);
63516355
} // The rest is padding, so just pick any of the input parts (but not

0 commit comments

Comments
 (0)