@@ -655,6 +655,105 @@ FailureOr<SmallVector<Layout>> getInLayouts(
   return in_layouts;
 }
 
+// Insert a minor dimension to the implicit shape. The original minor dimension
+// becomes the new second minor dimension, laid out across sublanes.
+//
+// The returned vreg array uses the original tiling and the offsets specified in
+// new_offsets to hold the value with the new implicit shape.
+//
+// Args:
+//  vregs:       The vreg array with *implicit* array shape.
+//  ishape:      The implicit shape of the represented value.
+//  layout:      The layout used for the represented value. The implicit
+//               dimension is ignored, since this function operates directly at
+//               the level of the implicit shape.
+//  new_offsets: The offsets to use for the layout of the returned vreg array.
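+//
+// For example, with a sublane count of 8, a value of implicit shape (..., 20)
+// becomes (..., 20, 1): each of the 20 elements that previously sat along
+// lanes ends up in its own sublane, replicated across the lanes of its vreg,
+// so the old minor dimension now spans ceil(20 / 8) vregs along the new
+// second minor axis (subject to new_offsets).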
+FailureOr<xla::Array<Value>> insertImplicitMinorDimension(
+    RewriteContext &ctx, OpBuilder &builder, const Location loc,
+    const xla::Array<Value> &vregs, const ArrayRef<int64_t> ishape,
+    const VectorLayout &layout, const LayoutOffsets new_offsets) {
+  if (layout.bitwidth() != 32 || !layout.hasNativeTiling(ctx.target_shape)) {
+    return emitError(loc, "Not implemented: Unsupported bitwidth or tiling");
+  }
+  if (layout.offsets()[1].has_value()) {
+    if (!new_offsets[0]) {
+      // TODO(tlongeri): This can only be valid if the dim size is 1.
+      return emitError(loc, "Not implemented: Replication mismatch");
+    }
+    if (*new_offsets[0] != *layout.offsets()[1] % ctx.target_shape[0] &&
+        *layout.offsets()[1] + *(ishape.end() - 1) > ctx.target_shape[1]) {
+      // This requires blending data from different vregs.
+      return emitError(loc,
+                       "Not implemented: Misaligned offsets and shape does not "
+                       "fit in one vreg");
+    }
+  }
+  // new_layout is only to get the new vreg array shape, the implicit dim is
+  // irrelevant (since we already have the implicit shape):
+  const VectorLayout new_layout(layout.bitwidth(), new_offsets, layout.tiling(),
+                                VectorLayout::ImplicitDim::kNone);
+  SmallVector<int64_t> new_ishape(ishape);
+  new_ishape.push_back(1);
+  xla::Array<Value> new_vregs(new_layout.tileArrayShape(
+      /*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(new_ishape),
+      ctx.target_shape));
+  // Preallocate an indices vector to avoid repeated allocations:
+  SmallVector<int64_t> idxs;
+  new_vregs.Each([&](const absl::Span<const int64_t> dst_idx,
+                     Value *const dst_vreg) {
+    // Indices of the new vreg in the new vreg array:
+    const int64_t new_2nd_minor_idx = *(dst_idx.end() - 2);
+    const int64_t new_3rd_minor_idx = *(dst_idx.end() - 3);
+    idxs.assign(dst_idx.begin(), dst_idx.end());
+    if (!layout.offsets()[0].has_value() && new_3rd_minor_idx != 0) {
+      // All vregs along that dimension are the same
+      *(idxs.end() - 3) = 0;
+      *dst_vreg = new_vregs(idxs);
+    } else if (!layout.offsets()[1].has_value() && new_2nd_minor_idx != 0) {
+      // All vregs along that dimension are the same
+      *(idxs.end() - 2) = 0;
+      *dst_vreg = new_vregs(idxs);
+    } else {
+      // dst_vreg will hold slice [row_idx, col_idx:(col_idx + target_shape[0])]
+      // of the after-offsets source shape
+      const int64_t row_idx =
+          layout.offsets()[0] ? new_3rd_minor_idx + *layout.offsets()[0] : 0;
+      const int64_t col_idx = layout.offsets()[1]
+                                  ? new_2nd_minor_idx * ctx.target_shape[0] +
+                                        *layout.offsets()[1] - *new_offsets[0]
+                                  : 0;
+
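+      // Locate the source vreg that holds this slice: drop the index for the
+      // new trailing singleton dim and convert the after-offsets element
+      // coordinates (row_idx, col_idx) into source vreg indices.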
+      idxs.pop_back();
+      *(idxs.end() - 2) = row_idx / ctx.target_shape[0];
+      *(idxs.end() - 1) = col_idx / ctx.target_shape[1];
+      Value src_vreg = vregs(idxs);
+      // TODO(tlongeri): We can sometimes skip operations when dst_vreg will
+      // hold a single non-padding element (first or last) and we don't need
+      // replication in the output.
+      if (layout.offsets()[0].has_value()) {
+        // [ . . . . . . . . ]    [ . . . . a b c d ]
+        // [ . . . . a b c d ] => [ . . . . a b c d ]
+        // [ . . . . . . . . ]    [ . . . . a b c d ]
+        // [ . . . . . . . . ]    [ . . . . a b c d ]
+        src_vreg = broadcastSublane(
+            builder, src_vreg,
+            /*sublane_idx=*/row_idx % ctx.target_shape[0], ctx.target_shape);
+      }
+      if (layout.offsets()[1].has_value()) {
+        // [ . . . . a b c d ]    [ a a a a a a a a ]
+        // [ . . . . a b c d ] => [ b b b b b b b b ]
+        // [ . . . . a b c d ]    [ c c c c c c c c ]
+        // [ . . . . a b c d ]    [ d d d d d d d d ]
+        src_vreg = builder.create<BroadcastInSublanesOp>(
+            loc, src_vreg.getType(), src_vreg,
+            /*lane=*/col_idx % ctx.target_shape[1]);
+      }
+      *dst_vreg = src_vreg;
+    }
+  });
+  return new_vregs;
+}
+
 LogicalResult elementwise_op_rule(RewriteContext &ctx, Operation &op,
                                   const ArrayRef<Layout> layouts_in,
                                   const ArrayRef<Layout> layouts_out) {
@@ -4155,54 +4254,16 @@ LogicalResult vector_shape_cast_rule(RewriteContext &ctx, Operation &op,
            layout_in.bitwidth() == 32 &&
            layout_in.hasNativeTiling(ctx.target_shape) &&
            layout_in.tiling() == layout_out.tiling() &&
-           layout_in.offsets()[0].value_or(0) == 0 &&
-           layout_in.offsets()[1] == 0 && layout_out.offsets()[0] == 0
-           // layout_out.offsets[1] can be anything, as we produce a
-           // replicated result
-    ) {
-      // First, insert the new singleton lane dimension.
-      SmallVector<int64_t> s = layout_in.implicitShape(src_shape);
-      s.push_back(1);
-      xla::Array<Value> dst_vregs_local(layout_out.tileArrayShape(
-          /*src_is_implicit=*/true, /*res_is_implicit=*/true, std::move(s),
-          ctx.target_shape));
-      TPU_ASSERT_EQ_OP(dst_vregs_local.dimensions().back(),
-                       1);  // We're inserting a singleton dimension
-      dst_vregs_local.Each(
-          [&](const absl::Span<const int64_t> dst_idx, Value *const dst_vreg) {
-            const int64_t col_idx = *(dst_idx.end() - 2);
-            const int64_t row_idx = *(dst_idx.end() - 3);
-            auto [sublanes_in_lane, rem] =
-                std::div(ctx.target_shape[1], ctx.target_shape[0]);
-            CHECK_EQ(rem, 0);
-            if (!layout_in.offsets()[0].has_value() && row_idx != 0) {
-              return;  // All vregs along that dimension are the same.
-            }
-            SmallVector<int64_t> src_idx(toArrayRef(dst_idx));
-            src_idx.pop_back();
-            *(src_idx.end() - 2) /= ctx.target_shape[0];
-            *(src_idx.end() - 1) /= sublanes_in_lane;
-            Value col_vreg = src_vregs(src_idx);
-            // BroadcastInSublanesOp requires the sublanes to be replicated.
-            if (layout_in.offsets()[0].has_value()) {
-              const int32_t sublane = row_idx % ctx.target_shape[0];
-              col_vreg = broadcastSublane(builder, col_vreg, sublane,
-                                          ctx.target_shape);
-            }
-            *dst_vreg = builder.create<BroadcastInSublanesOp>(
-                col_vreg.getType(), col_vreg,
-                /*lane=*/(col_idx % sublanes_in_lane) * ctx.target_shape[0]);
-          });
-      if (!layout_in.offsets()[0].has_value()) {
-        // Broadcast the sublane vregs.
-        // TODO(tlongeri): This could be done more efficiently
-        dst_vregs_local.Each([&](const absl::Span<const int64_t> dst_idx,
-                                 Value *const dst_vreg) {
-          SmallVector<int64_t> first_row_idx(toArrayRef(dst_idx));
-          *(first_row_idx.end() - 3) = 0;
-          *dst_vreg = dst_vregs_local(first_row_idx);
-        });
-      }
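+           // The offsets must be compatible with insertImplicitMinorDimension:
+           // either the input's lane offset is replicated, matches the
+           // output's sublane offset (mod the sublane count), or the minor
+           // dimension fits in a single vreg along lanes; otherwise vregs
+           // would have to be blended.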
+           (!layout_in.offsets()[1].has_value() ||
+            *layout_in.offsets()[1] % ctx.target_shape[0] ==
+                layout_out.offsets()[0] ||
+            *layout_in.offsets()[1] + src_tiled_dims[1] <=
+                ctx.target_shape[1])) {
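+      // First, insert the new singleton minor dimension; the vreg array keeps
+      // the source's major shape and is reshaped below.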
+      FAILUREOR_ASSIGN_OR_RETURN(
+          xla::Array<Value> dst_vregs_local,
+          insertImplicitMinorDimension(ctx, builder, op.getLoc(), src_vregs,
+                                       layout_in.implicitShape(src_shape),
+                                       layout_in, layout_out.offsets()));
       // Now, reshape the major axes of the vreg array.
       dst_vregs_local.Reshape(
           layout_out.tileArrayImplicitShape(dst_shape, ctx.target_shape));
@@ -6370,6 +6431,26 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeImplicitDim(
     });
     return std::make_pair(dst, new_vregs);
   }
+  if (src.implicit_dim() == VectorLayout::ImplicitDim::kNone &&
+      dst_implicit_dim == VectorLayout::ImplicitDim::kMinor &&
+      src.bitwidth() == 32 && src.hasNativeTiling(ctx.target_shape)) {
+    // TODO(tlongeri): Make insertImplicitMinorDimension more flexible about
+    // offsets, then we can pass dst_offset_hints directly.
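+    // If the minor dim does not fit in one vreg along lanes, the new
+    // second-minor offset must equal the source's lane offset modulo the
+    // sublane count; otherwise the offset hint can be honored directly.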
+    const LayoutOffset dst_2nd_minor_offset =
+        !src.offsets()[1] || *src.offsets()[1] + *(vty.getShape().end() - 1) <=
+                                 ctx.target_shape[1]
+            ? dst_offset_hints[0]
+            : LayoutOffset(*src.offsets()[1] % ctx.target_shape[0]);
+    VectorLayout dst(src.bitwidth(),
+                     {dst_2nd_minor_offset, dst_offset_hints[1]}, src.tiling(),
+                     VectorLayout::ImplicitDim::kMinor);
+    FAILUREOR_ASSIGN_OR_RETURN(
+        xla::Array<Value> dst_vregs,
+        insertImplicitMinorDimension(ctx, builder, loc, vregs,
+                                     src.implicitShape(vty.getShape()), src,
+                                     dst.offsets()));
+    return std::make_pair(dst, std::move(dst_vregs));
+  }
   return emitError(loc,
                    "Not implemented: Unsupported implicit dim change: from ")
          << src << " to " << dst_implicit_dim;