
Commit 8163e74

tlongeri authored and Google-ML-Automation committed
[Mosaic:TPU] Add relayout for adding minor implicit dim and relax some offset restrictions on similar shape cast
This factors out the logic in the apply-vector-layout shape cast rule that inserts a minor dimension, relaxes some of its offset restrictions, and uses it for the relayout.

PiperOrigin-RevId: 702993092
1 parent 1011687 commit 8163e74

File tree

2 files changed: +135 −49 lines changed


jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 6 additions & 1 deletion
@@ -342,8 +342,13 @@ def TPU_RepeatOp : TPU_Op<"repeat", [Pure]> {
 }
 
 def TPU_BroadcastInSublanesOp : TPU_Op<"broadcast_in_sublanes", [Pure]> {
+  let description = [{
+    For each sublane `i`, broadcasts the value in lane `lane + i` along the entire
+    sublane. If `lane + i` is not in [0, lane_count), then the value in sublane `i`
+    is not defined (can be anything).
+  }];
   let arguments = (ins
-    AnyVectorOfNonZeroRank:$source,  // All sublanes should be equal.
+    TPU_Vreg:$source,  // All sublanes should be equal.
     I32Attr:$lane  // Coordinates of the first element to take.
   );
   // Output shape should be the same, except for position dim which contains
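
As an aside for readers of this diff, here is a minimal standalone model of the semantics in the new `description`, treating a vreg as a plain sublanes-by-lanes matrix of int32_t. The function name and types below are illustrative and not part of the Mosaic code; per the arguments comment, the source is assumed to have all sublanes equal.

```cpp
#include <cstdint>
#include <vector>

// Illustrative model of broadcast_in_sublanes: for each sublane i of the
// result, take the value in lane (lane + i) of the source and broadcast it
// along the entire sublane. Rows where lane + i falls outside [0, lane_count)
// are undefined; this sketch simply leaves them zeroed.
std::vector<std::vector<int32_t>> BroadcastInSublanesModel(
    const std::vector<std::vector<int32_t>>& vreg, int lane) {
  const int sublane_count = static_cast<int>(vreg.size());
  const int lane_count = static_cast<int>(vreg[0].size());
  std::vector<std::vector<int32_t>> out(
      sublane_count, std::vector<int32_t>(lane_count, 0));
  for (int i = 0; i < sublane_count; ++i) {
    if (lane + i < 0 || lane + i >= lane_count) continue;  // undefined row
    for (int j = 0; j < lane_count; ++j) {
      // All sublanes of the source are assumed equal, so row 0 is as good as
      // any other row.
      out[i][j] = vreg[0][lane + i];
    }
  }
  return out;
}
```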

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 129 additions & 48 deletions
@@ -655,6 +655,105 @@ FailureOr<SmallVector<Layout>> getInLayouts(
   return in_layouts;
 }
 
+// Insert a minor dimension to the implicit shape. The original minor dimension
+// becomes the new second minor dimension, laid out across sublanes.
+//
+// The returned vreg array uses the original tiling and the offsets specified in
+// new_offsets to hold the value with the new implicit shape.
+//
+// Args:
+//  vregs:       The vreg array with *implicit* array shape.
+//  ishape:      The implicit shape of the represented value.
+//  layout:      The layout used for the represented value. The implicit
+//               dimension is ignored, since this function operates directly at
+//               the level of the implicit shape.
+//  new_offsets: The offsets to use for the layout of the returned vreg array.
+FailureOr<xla::Array<Value>> insertImplicitMinorDimension(
+    RewriteContext &ctx, OpBuilder &builder, const Location loc,
+    const xla::Array<Value> &vregs, const ArrayRef<int64_t> ishape,
+    const VectorLayout &layout, const LayoutOffsets new_offsets) {
+  if (layout.bitwidth() != 32 || !layout.hasNativeTiling(ctx.target_shape)) {
+    return emitError(loc, "Not implemented: Unsupported bitwidth or tiling");
+  }
+  if (layout.offsets()[1].has_value()) {
+    if (!new_offsets[0]) {
+      // TODO(tlongeri): This can only be valid if the dim size is 1.
+      return emitError(loc, "Not implemented: Replication mismatch");
+    }
+    if (*new_offsets[0] != *layout.offsets()[1] % ctx.target_shape[0] &&
+        *layout.offsets()[1] + *(ishape.end() - 1) > ctx.target_shape[1]) {
+      // This requires blending data from different vregs.
+      return emitError(loc,
+                       "Not implemented: Misaligned offsets and shape does not "
+                       "fit in one vreg");
+    }
+  }
+  // new_layout is only to get the new vreg array shape, the implicit dim is
+  // irrelevant (since we already have the implicit shape):
+  const VectorLayout new_layout(layout.bitwidth(), new_offsets, layout.tiling(),
+                                VectorLayout::ImplicitDim::kNone);
+  SmallVector<int64_t> new_ishape(ishape);
+  new_ishape.push_back(1);
+  xla::Array<Value> new_vregs(new_layout.tileArrayShape(
+      /*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(new_ishape),
+      ctx.target_shape));
+  // Preallocate an indices vector to avoid repeated allocations:
+  SmallVector<int64_t> idxs;
+  new_vregs.Each([&](const absl::Span<const int64_t> dst_idx,
+                     Value *const dst_vreg) {
+    // Indices of the new vreg in the new vreg array:
+    const int64_t new_2nd_minor_idx = *(dst_idx.end() - 2);
+    const int64_t new_3rd_minor_idx = *(dst_idx.end() - 3);
+    idxs.assign(dst_idx.begin(), dst_idx.end());
+    if (!layout.offsets()[0].has_value() && new_3rd_minor_idx != 0) {
+      // All vregs along that dimension are the same
+      *(idxs.end() - 3) = 0;
+      *dst_vreg = new_vregs(idxs);
+    } else if (!layout.offsets()[1].has_value() && new_2nd_minor_idx != 0) {
+      // All vregs along that dimension are the same
+      *(idxs.end() - 2) = 0;
+      *dst_vreg = new_vregs(idxs);
+    } else {
+      // dst_vreg will hold slice [row_idx, col_idx:(col_idx + target_shape[0])]
+      // of the after-offsets source shape
+      const int64_t row_idx =
+          layout.offsets()[0] ? new_3rd_minor_idx + *layout.offsets()[0] : 0;
+      const int64_t col_idx = layout.offsets()[1]
+                                  ? new_2nd_minor_idx * ctx.target_shape[0] +
+                                        *layout.offsets()[1] - *new_offsets[0]
+                                  : 0;
+
+      idxs.pop_back();
+      *(idxs.end() - 2) = row_idx / ctx.target_shape[0];
+      *(idxs.end() - 1) = col_idx / ctx.target_shape[1];
+      Value src_vreg = vregs(idxs);
+      // TODO(tlongeri): We can sometimes skip operations when dst_vreg will
+      // hold a single non-padding element (first or last) and we don't need
+      // replication in the output.
+      if (layout.offsets()[0].has_value()) {
+        // [ . . . . . . . . ]    [ . . . . a b c d ]
+        // [ . . . . a b c d ] => [ . . . . a b c d ]
+        // [ . . . . . . . . ]    [ . . . . a b c d ]
+        // [ . . . . . . . . ]    [ . . . . a b c d ]
+        src_vreg = broadcastSublane(
+            builder, src_vreg,
+            /*sublane_idx=*/row_idx % ctx.target_shape[0], ctx.target_shape);
+      }
+      if (layout.offsets()[1].has_value()) {
+        // [ . . . . a b c d ]    [ a a a a a a a a ]
+        // [ . . . . a b c d ] => [ b b b b b b b b ]
+        // [ . . . . a b c d ]    [ c c c c c c c c ]
+        // [ . . . . a b c d ]    [ d d d d d d d d ]
+        src_vreg = builder.create<BroadcastInSublanesOp>(
+            loc, src_vreg.getType(), src_vreg,
+            /*lane=*/col_idx % ctx.target_shape[1]);
+      }
+      *dst_vreg = src_vreg;
+    }
+  });
+  return new_vregs;
+}
+
 LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
                                   const ArrayRef<Layout> layouts_in,
                                   const ArrayRef<Layout> layouts_out) {
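
To make the index arithmetic in the new `insertImplicitMinorDimension` easier to follow, here is a small standalone sketch of which source vreg, sublane, and lane feed a given destination vreg, assuming 32-bit native tiling with `target_shape = {8, 128}` and non-replicated offsets. The names and the example in `main` are illustrative only, not part of the Mosaic code.

```cpp
#include <array>
#include <cstdint>
#include <iostream>

struct SrcPosition {
  int64_t vreg_row;  // 2nd-minor index into the source vreg array
  int64_t vreg_col;  // minor index into the source vreg array
  int64_t sublane;   // sublane of that vreg holding the source row
  int64_t lane;      // first lane of that vreg taken by the sublane broadcast
};

// Mirrors the row_idx/col_idx computation above: the destination vreg at
// (dst_3rd_minor, dst_2nd_minor) in the new vreg array reads one source vreg
// and spreads target_shape[0] consecutive lanes of it across its sublanes.
SrcPosition SourceFor(int64_t dst_3rd_minor, int64_t dst_2nd_minor,
                      int64_t src_row_offset, int64_t src_col_offset,
                      int64_t new_2nd_minor_offset,
                      std::array<int64_t, 2> target_shape) {
  const int64_t row_idx = dst_3rd_minor + src_row_offset;
  const int64_t col_idx = dst_2nd_minor * target_shape[0] + src_col_offset -
                          new_2nd_minor_offset;
  return {row_idx / target_shape[0], col_idx / target_shape[1],
          row_idx % target_shape[0], col_idx % target_shape[1]};
}

int main() {
  // With input offsets (0, 3) and an output 2nd-minor offset of 3, the second
  // destination vreg along the new 2nd-minor dimension takes lanes 8..15 of
  // source vreg (0, 0), one lane per result sublane.
  const SrcPosition p = SourceFor(/*dst_3rd_minor=*/0, /*dst_2nd_minor=*/1,
                                  /*src_row_offset=*/0, /*src_col_offset=*/3,
                                  /*new_2nd_minor_offset=*/3, {8, 128});
  std::cout << "vreg (" << p.vreg_row << ", " << p.vreg_col << "), sublane "
            << p.sublane << ", lane " << p.lane << "\n";
}
```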
@@ -4155,54 +4254,16 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
         layout_in.bitwidth() == 32 &&
         layout_in.hasNativeTiling(ctx.target_shape) &&
         layout_in.tiling() == layout_out.tiling() &&
-        layout_in.offsets()[0].value_or(0) == 0 &&
-        layout_in.offsets()[1] == 0 && layout_out.offsets()[0] == 0
-        // layout_out.offsets[1] can be anything, as we produce a
-        // replicated result
-    ) {
-      // First, insert the new singleton lane dimension.
-      SmallVector<int64_t> s = layout_in.implicitShape(src_shape);
-      s.push_back(1);
-      xla::Array<Value> dst_vregs_local(layout_out.tileArrayShape(
-          /*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(s),
-          ctx.target_shape));
-      TPU_ASSERT_EQ_OP(dst_vregs_local.dimensions().back(),
-                       1);  // We're inserting a singleton dimension
-      dst_vregs_local.Each(
-          [&](const absl::Span<const int64_t> dst_idx, Value *const dst_vreg) {
-            const int64_t col_idx = *(dst_idx.end() - 2);
-            const int64_t row_idx = *(dst_idx.end() - 3);
-            auto [sublanes_in_lane, rem] =
-                std::div(ctx.target_shape[1], ctx.target_shape[0]);
-            CHECK_EQ(rem, 0);
-            if (!layout_in.offsets()[0].has_value() && row_idx != 0) {
-              return;  // All vregs along that dimension are the same.
-            }
-            SmallVector<int64_t> src_idx(toArrayRef(dst_idx));
-            src_idx.pop_back();
-            *(src_idx.end() - 2) /= ctx.target_shape[0];
-            *(src_idx.end() - 1) /= sublanes_in_lane;
-            Value col_vreg = src_vregs(src_idx);
-            // BroadcastInSublanesOp requires the sublanes to be replicated.
-            if (layout_in.offsets()[0].has_value()) {
-              const int32_t sublane = row_idx % ctx.target_shape[0];
-              col_vreg = broadcastSublane(builder, col_vreg, sublane,
-                                          ctx.target_shape);
-            }
-            *dst_vreg = builder.create<BroadcastInSublanesOp>(
-                col_vreg.getType(), col_vreg,
-                /*lane=*/(col_idx % sublanes_in_lane) * ctx.target_shape[0]);
-          });
-      if (!layout_in.offsets()[0].has_value()) {
-        // Broadcast the sublane vregs.
-        // TODO(tlongeri): This could be done more efficiently
-        dst_vregs_local.Each([&](const absl::Span<const int64_t> dst_idx,
-                                 Value *const dst_vreg) {
-          SmallVector<int64_t> first_row_idx(toArrayRef(dst_idx));
-          *(first_row_idx.end() - 3) = 0;
-          *dst_vreg = dst_vregs_local(first_row_idx);
-        });
-      }
+        (!layout_in.offsets()[1].has_value() ||
+         *layout_in.offsets()[1] % ctx.target_shape[0] ==
+             layout_out.offsets()[0] ||
+         *layout_in.offsets()[1] + src_tiled_dims[1] <=
+             ctx.target_shape[1])) {
+      FAILUREOR_ASSIGN_OR_RETURN(
+          xla::Array<Value> dst_vregs_local,
+          insertImplicitMinorDimension(ctx, builder, op.getLoc(), src_vregs,
+                                       layout_in.implicitShape(src_shape),
+                                       layout_in, layout_out.offsets()));
       // Now, reshape the major axes of the vreg array.
       dst_vregs_local.Reshape(
           layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape));
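
The relaxed guard above mirrors the check inside `insertImplicitMinorDimension`: the minor input offset must either be replicated, line up with the output second-minor offset modulo the sublane count, or the offset-adjusted minor dimension must fit in a single vreg, since anything else would require blending data across vregs. A minimal sketch of that condition, assuming `target_shape = {8, 128}`; the helper name is illustrative only.

```cpp
#include <cstdint>
#include <optional>

// Sketch of the offset-compatibility condition: inserting the implicit minor
// dimension without blending vregs is possible when the minor input offset is
// replicated, or lines up with the requested output 2nd-minor offset modulo
// the sublane count, or the whole (offset-adjusted) minor dimension fits in
// one vreg.
bool CanInsertMinorDimWithoutBlending(
    std::optional<int64_t> in_minor_offset,
    std::optional<int64_t> out_2nd_minor_offset, int64_t minor_dim_size,
    int64_t sublane_count, int64_t lane_count) {
  if (!in_minor_offset.has_value()) return true;  // replicated input offset
  const bool aligned =
      out_2nd_minor_offset.has_value() &&
      *in_minor_offset % sublane_count == *out_2nd_minor_offset;
  const bool fits_in_one_vreg = *in_minor_offset + minor_dim_size <= lane_count;
  return aligned || fits_in_one_vreg;
}
```

For example, with an input minor offset of 9 and 200 elements, only an output second-minor offset of 1 (9 mod 8) is accepted, while with 100 elements everything fits in one vreg and any output offset works.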
@@ -6370,6 +6431,26 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
     });
     return std::make_pair(dst, new_vregs);
   }
+  if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
+      dst_implicit_dim == VectorLayout::ImplicitDim::kMinor &&
+      src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) {
+    // TODO(tlongeri): Make insertImplicitMinorDimension more flexible about
+    // offsets, then we can pass dst_offset_hints directly.
+    const LayoutOffset dst_2nd_minor_offset =
+        !src.offsets()[1] || *src.offsets()[1] + *(vty.getShape().end() - 1) <=
+                                 ctx.target_shape[1]
+            ? dst_offset_hints[0]
+            : LayoutOffset(*src.offsets()[1] % ctx.target_shape[0]);
+    VectorLayout dst(src.bitwidth(),
+                     {dst_2nd_minor_offset, dst_offset_hints[1]}, src.tiling(),
+                     VectorLayout::ImplicitDim::kMinor);
+    FAILUREOR_ASSIGN_OR_RETURN(
+        xla::Array<Value> dst_vregs,
+        insertImplicitMinorDimension(ctx, builder, loc, vregs,
+                                     src.implicitShape(vty.getShape()), src,
+                                     dst.offsets()));
+    return std::make_pair(dst, std::move(dst_vregs));
+  }
   return emitError(loc,
                    "Not implemented: Unsupported implicit dim change: from ")
          << src << " to " << dst_implicit_dim;
