Skip to content

Commit ad960fd

Browse files
WindQAQGoogle-ML-Automation
authored andcommitted
[Mosaic] Allow padding in small tiling row shuffle reshape.
We just need to make sure shape aligns to vreg-slice lane dim and only last vreg contains padding on tiled dims. Examples: 1: Reshape vector<10x128xi32> to vector<5x256xi32> can use the row shuffle reshape routine by inferring in tiling = (8, 128) and out tiling = (4, 128) because 1) vregs are still one-to-one mapping, ensured by vreg-slice lane aligned, and 2) only last vreg in tiled dims are padded, ensured by #elements are the same in tiled dims. 2: Reshape vector<16x512x56x128xbf16> to vector<16x512x7168xbf16> can use in tiling = (16, 128) and out tiling = (1, 256) and make it no-op. PiperOrigin-RevId: 833370307
1 parent 18a05ef commit ad960fd

File tree

3 files changed

+116
-48
lines changed

3 files changed

+116
-48
lines changed

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1846,12 +1846,14 @@ LogicalResult UnpackSubelementsOp::canonicalize(UnpackSubelementsOp op,
18461846
if (auto pack = dyn_cast<PackSubelementsOp>(op.getSource().getDefiningOp());
18471847
pack && pack.getPackFormat() == op.getPackFormat() &&
18481848
pack.getSources().front().getType() == op.getType()) {
1849-
rewriter.replaceAllOpUsesWith(
1850-
op, pack.getPaddedSources(
1851-
pack.getSources(), pack.getPositions(),
1852-
op.getType().getElementTypeBitWidth() /
1853-
pack.getType().getElementTypeBitWidth())[op.getIndex()]);
1854-
return success();
1849+
Value source = pack.getPaddedSources(
1850+
pack.getSources(), pack.getPositions(),
1851+
op.getType().getElementTypeBitWidth() /
1852+
pack.getType().getElementTypeBitWidth())[op.getIndex()];
1853+
if (source) {
1854+
rewriter.replaceAllOpUsesWith(op, source);
1855+
return success();
1856+
}
18551857
}
18561858
return failure();
18571859
}

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 57 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -5650,40 +5650,55 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
56505650
no_op = true;
56515651
}
56525652

5653-
auto can_use_row_shuffle = [&ctx](ArrayRef<int64_t> shape,
5654-
VectorLayout layout,
5655-
std::array<int64_t, 2> vreg_slice) {
5656-
if (shape.size() < 2) {
5653+
bool can_use_row_shuffle = [&]() {
5654+
if (!llvm::isPowerOf2_32(layout_in.bitwidth())) {
56575655
return false;
56585656
}
5659-
// vreg must not be padded.
5660-
if (shape.back() % vreg_slice[1] != 0 ||
5661-
shape[shape.size() - 2] % vreg_slice[0] != 0) {
5657+
if (layout_in.offsets() != LayoutOffsets{0, 0} ||
5658+
layout_out.offsets() != LayoutOffsets{0, 0}) {
56625659
return false;
56635660
}
5664-
if (!llvm::isPowerOf2_32(layout.bitwidth())) {
5665-
return false;
5666-
}
5667-
if (layout.offsets() != LayoutOffsets{0, 0}) {
5668-
return false;
5669-
}
5670-
if (layout.implicit_dim() != VectorLayout::ImplicitDim::kNone) {
5671-
return false;
5661+
bool src_is_1d_tiling =
5662+
layout_in.tiling() ==
5663+
std::array<int64_t, 2>{1, ctx.target_shape[1] * layout_in.packing()};
5664+
bool dst_is_1d_tiling =
5665+
layout_out.tiling() ==
5666+
std::array<int64_t, 2>{1, ctx.target_shape[1] * layout_out.packing()};
5667+
bool src_is_vreg_slice_lane_aligned =
5668+
(!src_is_1d_tiling && src_tiled_dims[1] == src_vreg_slice[1]) ||
5669+
(src_is_1d_tiling && src_tiled_dims[1] % src_vreg_slice[1] == 0);
5670+
bool dst_is_vreg_slice_lane_aligned =
5671+
(!dst_is_1d_tiling && dst_tiled_dims[1] == dst_vreg_slice[1]) ||
5672+
(dst_is_1d_tiling && dst_tiled_dims[1] % dst_vreg_slice[1] == 0);
5673+
bool src_is_vreg_slice_sublane_aligned =
5674+
src_tiled_dims[0] % src_vreg_slice[0] == 0;
5675+
bool dst_is_vreg_slice_sublane_aligned =
5676+
dst_tiled_dims[0] % dst_vreg_slice[0] == 0;
5677+
if (src_is_vreg_slice_lane_aligned && dst_is_vreg_slice_lane_aligned) {
5678+
if (src_is_vreg_slice_sublane_aligned &&
5679+
dst_is_vreg_slice_sublane_aligned) {
5680+
// Both src and dst are aligned to vreg slice sublanes.
5681+
return true;
5682+
}
5683+
if (!src_is_vreg_slice_sublane_aligned &&
5684+
!dst_is_vreg_slice_sublane_aligned &&
5685+
llvm::product_of(src_tiled_dims) ==
5686+
llvm::product_of(dst_tiled_dims)) {
5687+
// Neither src nor dst are aligned to vreg slice sublanes.
5688+
// Padding happens only on the last vreg in tiled dims.
5689+
return true;
5690+
}
56725691
}
5673-
// 2d tiling.
5674-
if (layout.tiling()[0] <= ctx.target_shape[0] * layout.packing() &&
5675-
layout.tiling()[1] == ctx.target_shape[1] &&
5676-
shape.back() == vreg_slice[1]) {
5692+
if (src_is_vreg_slice_lane_aligned && dst_is_1d_tiling &&
5693+
llvm::product_of(src_tiled_dims) == dst_tiled_dims[1]) {
56775694
return true;
56785695
}
5679-
// 1d tiling.
5680-
if (layout.tiling() ==
5681-
std::array<int64_t, 2>{1, ctx.target_shape[1] * layout.packing()} &&
5682-
shape.back() % vreg_slice[1] == 0) {
5696+
if (dst_is_vreg_slice_lane_aligned && src_is_1d_tiling &&
5697+
llvm::product_of(dst_tiled_dims) == src_tiled_dims[1]) {
56835698
return true;
56845699
}
56855700
return false;
5686-
};
5701+
}();
56875702

56885703
FAILUREOR_ASSIGN_OR_RETURN(
56895704
xla::Array<Value> src_vregs,
@@ -5715,15 +5730,13 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
57155730
layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape));
57165731
return dst_vregs_local;
57175732
} else if (
5718-
// Row shuffle within a vreg if there is no padding and each vreg holds
5719-
// a contiguous slice of the flattened data.
5720-
can_use_row_shuffle(src_shape, layout_in, src_vreg_slice) &&
5721-
can_use_row_shuffle(dst_shape, layout_out, dst_vreg_slice)) {
5733+
// Row shuffle within a vreg if each vreg holds a contiguous slice of
5734+
// the flattened data and each row is either fully occupied or is all
5735+
// padding.
5736+
can_use_row_shuffle) {
57225737
auto [sublane_count, lane_count] = ctx.target_shape;
5723-
auto dst_vregs_shape =
5724-
layout_out.tileArrayShape(false, false, dst_shape, ctx.target_shape);
5725-
auto src_vregs_shape =
5726-
layout_in.tileArrayShape(false, false, src_shape, ctx.target_shape);
5738+
src_vregs.Reshape(
5739+
layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape));
57275740
if (bitwidth == 32) {
57285741
// For 32 bit data, a sublane is effectively a physical row.
57295742
std::array<int64_t, 2> src_sublane_slice = {
@@ -5845,8 +5858,20 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
58455858
// with tiling (16, 128) and then to (8, 512) with tiling (8, 128).
58465859
const int64_t src_sublane_tiling = layout_in.tiling()[0];
58475860
const int64_t dst_sublane_tiling = layout_out.tiling()[0];
5861+
const int64_t native_sublane_tiling =
5862+
ctx.target_shape[0] * layout_in.packing();
58485863
CHECK(llvm::isPowerOf2_64(static_cast<uint64_t>(src_sublane_tiling)));
58495864
CHECK(llvm::isPowerOf2_64(static_cast<uint64_t>(dst_sublane_tiling)));
5865+
CHECK(
5866+
llvm::isPowerOf2_64(static_cast<uint64_t>(native_sublane_tiling)));
5867+
// (target_shape[0] * packing, target_shape[1]) <->
5868+
// (1, target_shape[1] * packing) is a no-op.
5869+
if ((src_sublane_tiling == 1 &&
5870+
dst_sublane_tiling == native_sublane_tiling) ||
5871+
(src_sublane_tiling == native_sublane_tiling &&
5872+
dst_sublane_tiling == 1)) {
5873+
return src_vregs;
5874+
}
58505875
tpu::PackFormat unpack_format, pack_format;
58515876
if (src_sublane_tiling > dst_sublane_tiling) {
58525877
unpack_format = tpu::PackFormat::kInterleaved;
@@ -5887,7 +5912,6 @@ LogicalResult reshape_rule(RewriteContext& ctx, Operation& op,
58875912
src_vreg->getLoc(), src_vreg->getType(), dst_vreg);
58885913
});
58895914
}
5890-
src_vregs.Reshape(dst_vregs_shape);
58915915
return src_vregs;
58925916
} else if (
58935917
// Lower shape_casts for {32/16/8}-bit types where the minor dimension

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,8 +1633,8 @@ class VectorLayoutInferer {
16331633
return success();
16341634
}
16351635

1636-
// Find the small tiling such that there is not padding and each vreg holds
1637-
// a continuous slice of the flatten data.
1636+
// Find the small tiling such that each vreg holds a continuous slice of the
1637+
// flatten data and each row is either fully occupied or is all padding.
16381638
auto small_second_minor_tiling_layout =
16391639
[&](ArrayRef<int64_t> shape) -> std::optional<VectorLayout> {
16401640
const int64_t elements_per_vreg = native_tiling[0] * native_tiling[1];
@@ -1659,16 +1659,15 @@ class VectorLayoutInferer {
16591659
// TODO(b/440370770): Preserve replicated offsets.
16601660
auto layout = VectorLayout(bitwidth, {0, 0}, tiling, ImplicitDim::kNone);
16611661
auto vreg_slice = layout.vregSlice(target_shape_);
1662-
if ((shape.back() != vreg_slice[1] && !can_use_1d_tiling) ||
1663-
shape[shape.size() - 2] % vreg_slice[0] != 0) {
1662+
if (shape.back() != vreg_slice[1] && !can_use_1d_tiling) {
16641663
return std::nullopt;
16651664
}
16661665
return layout;
16671666
};
16681667

1669-
// Use the small tiling if there's no padding and each vreg holds a
1670-
// contiguous slice of the flattened data. It makes reshape a row shuffle
1671-
// within a vreg.
1668+
// Use the small tiling if each vreg holds a contiguous slice of the
1669+
// flattened data and each row is either fully occupied or is all
1670+
// padding. It makes reshape a row shuffle within a vreg.
16721671
//
16731672
// For example,
16741673
// - (4, 256) with (4, 128) tiling to (1, 1024) with (1, 128) tiling is
@@ -1684,8 +1683,51 @@ class VectorLayoutInferer {
16841683

16851684
if (src_small_second_minor_tiling_layout.has_value() &&
16861685
res_small_second_minor_tiling_layout.has_value()) {
1687-
setLayout(op, *src_small_second_minor_tiling_layout,
1688-
*res_small_second_minor_tiling_layout);
1686+
auto src_vreg_slice =
1687+
src_small_second_minor_tiling_layout->vregSlice(target_shape_);
1688+
auto res_vreg_slice =
1689+
res_small_second_minor_tiling_layout->vregSlice(target_shape_);
1690+
bool src_vreg_slice_aligned =
1691+
src_shape[src_shape.size() - 2] % src_vreg_slice[0] == 0;
1692+
bool res_vreg_slice_aligned =
1693+
res_shape[res_shape.size() - 2] % res_vreg_slice[0] == 0;
1694+
if (
1695+
// Both input and output are aligned to its vreg slice.
1696+
(src_vreg_slice_aligned && res_vreg_slice_aligned) ||
1697+
// Because the last dims are equal vreg slice lane dim, we know that
1698+
// in 2D tiled dim, vregs are organized from top to bottom. By
1699+
// checking the product of the last two dims, we make sure the tiled
1700+
// dims have the same number of vregs/elements, and only the last or
1701+
// bottom-most vreg has padding.
1702+
// For example, it's valid to reshape i32 (12, 128) to (6, 256) with
1703+
// input tiling (8, 128) and output tiling (4, 128), but not valid to
1704+
// reshape i32 (12, 128) to (3, 2, 256) with input tiling (8, 128) and
1705+
// output tiling (4, 128).
1706+
(!src_vreg_slice_aligned && !res_vreg_slice_aligned &&
1707+
llvm::product_of(src_shape.take_back(2)) ==
1708+
llvm::product_of(res_shape.take_back(2)))) {
1709+
setLayout(op, *src_small_second_minor_tiling_layout,
1710+
*res_small_second_minor_tiling_layout);
1711+
return success();
1712+
}
1713+
}
1714+
if (src_small_second_minor_tiling_layout.has_value() &&
1715+
llvm::product_of(src_shape.take_back(2)) == res_shape.back()) {
1716+
// For example, reshape i32 (8, 10, 128) to (8, 1280) with input tiling
1717+
// (8, 128) and output tiling (1, 128).
1718+
setLayout(
1719+
op, *src_small_second_minor_tiling_layout,
1720+
VectorLayout(layout.bitwidth(), {0, 0},
1721+
{1, target_shape_[1] * packing}, ImplicitDim::kNone));
1722+
return success();
1723+
}
1724+
if (res_small_second_minor_tiling_layout.has_value() &&
1725+
llvm::product_of(res_shape.take_back(2)) == src_shape.back()) {
1726+
setLayout(
1727+
op,
1728+
VectorLayout(layout.bitwidth(), {0, 0},
1729+
{1, target_shape_[1] * packing}, ImplicitDim::kNone),
1730+
*res_small_second_minor_tiling_layout);
16891731
return success();
16901732
}
16911733

0 commit comments

Comments
 (0)