Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 25 additions & 1 deletion jaxlib/mosaic/dialect/tpu/tpu.td
Original file line number Diff line number Diff line change
Expand Up @@ -593,11 +593,35 @@ def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> {
}

// Integer unpacks are always signed at the moment.
//
// When unpacking integers to integers, setting `sign_extended` to false leaves
// the bits above the source bitwidth undefined in each unpacked element.
//
// Take int4 to int16 interleaved unpacking and `index = 1` as an example:
//
// Source:
//
// Bits 28  24  20  16  12  8   4   0
//      --------abcd------------efgh----
//
// where "a" and "e" are the sign bits of the values to be unpacked, and "-" are
// bits to be ignored.
//
// Unpacked, sign_extended = true:
//
// Bits 28  24  20  16  12  8   4   0
//      aaaaaaaaaaaaabcdeeeeeeeeeeeeefgh
//
// Unpacked, sign_extended = false:
//
// Bits 28  24  20  16  12  8   4   0
//      ------------abcd------------efgh
def TPU_UnpackSubelementsOp : TPU_Op<"unpack_subelements", [Pure]> {
let arguments = (ins
AnyVectorOfNonZeroRank:$source,
I32Attr:$index,
TPU_PackFormatEnum:$pack_format
TPU_PackFormatEnum:$pack_format,
DefaultValuedAttr<BoolAttr, "true">:$sign_extended
);
let results = (outs AnyVectorOfNonZeroRank:$output);
let assemblyFormat = [{ $source `,` $index attr-dict `:` type($source) `->` type($output) }];
Expand Down
53 changes: 30 additions & 23 deletions jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1632,8 +1632,12 @@ FailureOr<xla::Array<Value>> packVregs(RewriteContext &ctx, OpBuilder &builder,
for (Value part : parts) {
if (part) {
for (int i = 0; i < packing_factor; ++i) {
// Note that input bitwidth is larger than result bitwidth. We
// don't need sign extension here because the following packing
// ends up truncating sign-extended bits.
unpacks.push_back(builder.create<UnpackSubelementsOp>(
loc, unpacked_vty, part, i, pack_format));
loc, unpacked_vty, part, i, pack_format,
/*sign_extended=*/false));
}
} else {
unpacks.append(packing_factor, nullptr);
Expand Down Expand Up @@ -5704,28 +5708,31 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
VectorType unpacked_vty = getNativeVregType(
builder.getIntegerType(src_ty.getElementTypeBitWidth() * 2),
ctx.target_shape);
src_vregs.Each([&](absl::Span<const int64_t> src_vreg_indices,
Value* src_vreg) {
Value dst_vreg = builder.create<tpu::BitcastVregOp>(
src_vreg->getLoc(), packed_vty, *src_vreg);
int64_t from_sublane_tiling = src_sublane_tiling;
while (from_sublane_tiling != dst_sublane_tiling) {
std::array<Value, 2> src_parts;
for (int i = 0; i < src_parts.size(); ++i) {
src_parts[i] = builder.create<tpu::UnpackSubelementsOp>(
src_vreg->getLoc(), unpacked_vty, dst_vreg, i, unpack_format);
}
dst_vreg = builder.create<tpu::PackSubelementsOp>(
src_vreg->getLoc(), packed_vty, src_parts, pack_format);
if (from_sublane_tiling > dst_sublane_tiling) {
from_sublane_tiling /= 2;
} else {
from_sublane_tiling *= 2;
}
}
*src_vreg = builder.create<tpu::BitcastVregOp>(
src_vreg->getLoc(), src_vreg->getType(), dst_vreg);
});
src_vregs.Each(
[&](absl::Span<const int64_t> src_vreg_indices, Value* src_vreg) {
Value dst_vreg = builder.create<tpu::BitcastVregOp>(
src_vreg->getLoc(), packed_vty, *src_vreg);
int64_t from_sublane_tiling = src_sublane_tiling;
while (from_sublane_tiling != dst_sublane_tiling) {
std::array<Value, 2> src_parts;
for (int i = 0; i < src_parts.size(); ++i) {
// We don't need sign extension here because the following
// packing ends up truncating sign-extended bits.
src_parts[i] = builder.create<tpu::UnpackSubelementsOp>(
src_vreg->getLoc(), unpacked_vty, dst_vreg, i,
unpack_format, /*sign_extended=*/false);
}
dst_vreg = builder.create<tpu::PackSubelementsOp>(
src_vreg->getLoc(), packed_vty, src_parts, pack_format);
if (from_sublane_tiling > dst_sublane_tiling) {
from_sublane_tiling /= 2;
} else {
from_sublane_tiling *= 2;
}
}
*src_vreg = builder.create<tpu::BitcastVregOp>(
src_vreg->getLoc(), src_vreg->getType(), dst_vreg);
});
}
src_vregs.Reshape(dst_vregs_shape);
return src_vregs;
Expand Down
Loading