@@ -6116,35 +6116,49 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
61166116 }
61176117 const int packing = src.packing ();
61186118 const int8_t bitwidth = src.bitwidth ();
6119- VectorLayout dst (src.bitwidth (), src.offsets (), dst_tiling,
6120- src.implicit_dim ());
6121- if (!dst.isValid (target_shape)) {
6122- return emitError (loc, " Not implemented: invalid offsets in tiling target" );
6123- }
6124- auto dst_tiles_shape =
6125- dst.tileArrayImplicitShape (vty.getShape (), target_shape);
61266119 // Handle retiling from (1, 128) to (8, 128) for 32-bit data with replicating
61276120 // sublanes.
61286121 if (try_replicate_rows && packing == 1 &&
61296122 *(vregs.dimensions ().end () - 2 ) == 1 &&
6130- src.offsets () == LayoutOffsets{0 , 0 } &&
61316123 src.tiling () == std::array<int64_t , 2 >{1 , ctx.target_shape [1 ]} &&
61326124 dst_tiling == ctx.target_shape ) {
6133- xla::Array<Value> retiled (dst_tiles_shape);
6125+ DCHECK_EQ (src.offsets ()[0 ].value_or (0 ), 0 );
6126+ const LayoutOffset dst_minor_offset =
6127+ src.offsets ()[1 ] ? LayoutOffset (*src.offsets ()[1 ] % target_shape[1 ])
6128+ : std::nullopt ;
6129+ const VectorLayout dst (bitwidth, {std::nullopt , dst_minor_offset},
6130+ dst_tiling, src.implicit_dim ());
6131+ xla::Array<Value> retiled (
6132+ dst.tileArrayImplicitShape (vty.getShape (), target_shape));
61346133 retiled.Each ([&](absl::Span<const int64_t > idx, Value *tile) {
61356134 SmallVector<int64_t > src_idx (idx.begin (), idx.end ());
61366135 *(src_idx.end () - 2 ) *= target_shape[0 ];
6137- *(src_idx.end () - 1 ) /= target_shape[0 ];
6138- const int64_t src_sl_idx = *(idx.end () - 1 ) % target_shape[0 ];
6139- CHECK_EQ (src.getImplicitTiledDims (vty.getShape (), 1 )[0 ], 1 );
6140- *tile =
6141- broadcastSublane (builder, vregs (src_idx), src_sl_idx, target_shape);
6136+ if (!src.offsets ()[1 ].has_value ()) {
6137+ // With (1, 128) tiling each vreg holds values from a single row. This
6138+ // means that if the columns are replicated, then the whole vreg is
6139+ // already replicated.
6140+ *(src_idx.end () - 1 ) = 0 ;
6141+ *tile = vregs (src_idx);
6142+ } else {
6143+ // The column (in units of sublanes) of the sublane we want:
6144+ const int64_t sublane_column =
6145+ *(src_idx.end () - 1 ) + *src.offsets ()[1 ] / target_shape[1 ];
6146+ *(src_idx.end () - 1 ) = sublane_column / target_shape[0 ];
6147+ const int64_t src_sl_idx = sublane_column % target_shape[0 ];
6148+ *tile =
6149+ broadcastSublane (builder, vregs (src_idx), src_sl_idx, target_shape);
6150+ }
61426151 });
6143- // We have successfully replicated sublanes.
6144- dst = VectorLayout (bitwidth, {std::nullopt , dst.offsets ()[1 ]}, dst_tiling,
6145- dst.implicit_dim ());
6152+ // We have successfully replicated sublanes
61466153 return std::pair (dst, std::move (retiled));
61476154 }
6155+ VectorLayout dst (src.bitwidth (), src.offsets (), dst_tiling,
6156+ src.implicit_dim ());
6157+ if (!dst.isValid (target_shape)) {
6158+ return emitError (loc, " Not implemented: invalid offsets in tiling target" );
6159+ }
6160+ auto dst_tiles_shape =
6161+ dst.tileArrayImplicitShape (vty.getShape (), target_shape);
61486162 // (8,128) -> (8 * packing,128) tiling change for packed type.
61496163 if (bitwidth < 32 && 32 % bitwidth == 0 && src_tiling == ctx.target_shape &&
61506164 dst_tiling == std::array<int64_t , 2 >{ctx.target_shape [0 ] * dst.packing (),
0 commit comments