diff --git a/jaxlib/mosaic/dialect/tpu/tpu.td b/jaxlib/mosaic/dialect/tpu/tpu.td index e557de2976c0..6521bea5f18a 100644 --- a/jaxlib/mosaic/dialect/tpu/tpu.td +++ b/jaxlib/mosaic/dialect/tpu/tpu.td @@ -593,11 +593,35 @@ def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> { } // Integer unpacks are always signed at the moment. +// +// When unpacking integers to integers, setting `sign_extended` to false will +// leave bits higher than source bitwidth as undefined. +// +// Take int4 to int16 interleaved unpacking and `index = 1` as an example: +// +// Source: +// +// Bits 28 24 20 16 12 8 4 0 +// --------abcd------------efgh---- +// +// where "a" and "e" are the sign bits of the values to be unpacked, and "-" are +// bits to be ignored. +// +// Unpacked, sign_extend = true: +// +// Bits 28 24 20 16 12 8 4 0 +// aaaaaaaaaaaaabcdeeeeeeeeeeeeefgh +// +// Unpacked, sign_extend = false: +// +// Bits 28 24 20 16 12 8 4 0 +// ------------abcd------------efgh def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> { let arguments = (ins AnyVectorOfNonZeroRank:$source, I32Attr:$index, - TPU_PackFormatEnum:$pack_format + TPU_PackFormatEnum:$pack_format, + DefaultValuedAttr:$sign_extended ); let results = (outs AnyVectorOfNonZeroRank:$output); let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }]; diff --git a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc index b72bf9ea6b0b..dc9534e47556 100644 --- a/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc +++ b/jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc @@ -1632,8 +1632,12 @@ FailureOr> packVregs(RewriteContext &ctx, OpBuilder &builder, for (Value part : parts) { if (part) { for (int i = 0; i < packing_factor; ++i) { + // Note that input bitwidth is larger than result bitwidth. We + // don't need sign extension here because the following packing + // ends up truncating sign-extended bits. unpacks.push_back(builder.create( - loc, unpacked_vty, part, i, pack_format)); + loc, unpacked_vty, part, i, pack_format, + /*sign_extended=*/false)); } } else { unpacks.append(packing_factor, nullptr); @@ -5704,28 +5708,31 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op, VectorType unpacked_vty = getNativeVregType( builder.getIntegerType(src_ty.getElementTypeBitWidth() * 2), ctx.target_shape); - src_vregs.Each([&](absl::Span src_vreg_indices, - Value* src_vreg) { - Value dst_vreg = builder.create( - src_vreg->getLoc(), packed_vty, *src_vreg); - int64_t from_sublane_tiling = src_sublane_tiling; - while (from_sublane_tiling != dst_sublane_tiling) { - std::array src_parts; - for (int i = 0; i < src_parts.size(); ++i) { - src_parts[i] = builder.create( - src_vreg->getLoc(), unpacked_vty, dst_vreg, i, unpack_format); - } - dst_vreg = builder.create( - src_vreg->getLoc(), packed_vty, src_parts, pack_format); - if (from_sublane_tiling > dst_sublane_tiling) { - from_sublane_tiling /= 2; - } else { - from_sublane_tiling *= 2; - } - } - *src_vreg = builder.create( - src_vreg->getLoc(), src_vreg->getType(), dst_vreg); - }); + src_vregs.Each( + [&](absl::Span src_vreg_indices, Value* src_vreg) { + Value dst_vreg = builder.create( + src_vreg->getLoc(), packed_vty, *src_vreg); + int64_t from_sublane_tiling = src_sublane_tiling; + while (from_sublane_tiling != dst_sublane_tiling) { + std::array src_parts; + for (int i = 0; i < src_parts.size(); ++i) { + // We don't need sign extension here because the following + // packing ends up truncating sign-extended bits. + src_parts[i] = builder.create( + src_vreg->getLoc(), unpacked_vty, dst_vreg, i, + unpack_format, /*sign_extended=*/false); + } + dst_vreg = builder.create( + src_vreg->getLoc(), packed_vty, src_parts, pack_format); + if (from_sublane_tiling > dst_sublane_tiling) { + from_sublane_tiling /= 2; + } else { + from_sublane_tiling *= 2; + } + } + *src_vreg = builder.create( + src_vreg->getLoc(), src_vreg->getType(), dst_vreg); + }); } src_vregs.Reshape(dst_vregs_shape); return src_vregs;