@@ -1035,12 +1035,19 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
10351035 output_vregs.Each ([&](absl::Span<const int64_t > idxs, Value *v) {
10361036 SmallVector<Value> parts;
10371037 SmallVector<int64_t > idxs_local (toArrayRef (idxs));
1038- idxs_local.back () *= packing;
1039- for (int64_t i = 0 ; i < packing; ++i) {
1040- parts.push_back (input_vregs (idxs_local));
1041- // Pack any data lying around if OOB
1042- if (idxs_local.back () < input_vregs.dimensions ().back () - 1 ) {
1043- ++idxs_local.back ();
1038+ if (!layout_out.offsets ()[1 ].has_value ()) {
1039+ idxs_local.back () = 0 ;
1040+ // Make sure we set all parts of the output vreg to make it replicated
1041+ parts.append (packing, input_vregs (idxs_local));
1042+ } else {
1043+ idxs_local.back () *= packing;
1044+ for (int64_t i = 0 ; i < packing; ++i) {
1045+ if (idxs_local.back () < input_vregs.dimensions ().back ()) {
1046+ parts.push_back (input_vregs (idxs_local));
1047+ ++idxs_local.back ();
1048+ } else {
1049+ parts.push_back (nullptr );
1050+ }
10441051 }
10451052 }
10461053 *v = builder.create <PackSubelementsOp>(res_vreg_ty, parts,
@@ -1053,16 +1060,19 @@ LogicalResult trunc_op_rule_impl(RewriteContext &ctx, OpTy op,
10531060 output_vregs.Each ([&](absl::Span<const int64_t > idxs, Value *v) {
10541061 CHECK_GE (idxs.size (), 2 );
10551062 SmallVector<int64_t > idxs_local (toArrayRef (idxs));
1056- idxs_local[idxs.size () - 2 ] *= packing;
1057- parts.push_back (input_vregs (idxs_local));
1058- idxs_local[idxs.size () - 2 ]++;
1059- while (parts.size () < packing) {
1060- if (*(idxs_local.end () - 2 ) < *(input_vregs.dimensions ().end () - 2 )) {
1061- parts.push_back (input_vregs (idxs_local));
1062- idxs_local[idxs.size () - 2 ]++;
1063- } else {
1064- // Once we run out of tiles, we can pick any one we like.
1065- parts.push_back (parts.back ());
1063+ if (!layout_out.offsets ()[0 ].has_value ()) {
1064+ *(idxs_local.end () - 2 ) = 0 ;
1065+ // Make sure we set all parts of the output vreg to make it replicated
1066+ parts.append (packing, input_vregs (idxs_local));
1067+ } else {
1068+ *(idxs_local.end () - 2 ) *= packing;
1069+ for (int64_t i = 0 ; i < packing; ++i) {
1070+ if (*(idxs_local.end () - 2 ) < *(input_vregs.dimensions ().end () - 2 )) {
1071+ parts.push_back (input_vregs (idxs_local));
1072+ ++*(idxs_local.end () - 2 );
1073+ } else {
1074+ parts.push_back (nullptr );
1075+ }
10661076 }
10671077 }
10681078 *v = builder.create <PackSubelementsOp>(res_vreg_ty, parts,
@@ -6253,6 +6263,11 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
62536263 ctx.target_shape [1 ]}) {
62546264 // Note: for int4, retiling with scratch is always faster.
62556265 if (bitwidth != 4 || !has_enough_scratch) {
6266+ // Note: The code below does not work when src is replicated and dst is
6267+ // not, since it relies on the src vreg array shape to know how many tiles
6268+ // to pack in dst, and vreg array shapes with materialized offsets are
6269+ // unfortunately not equal to vreg array shapes with replicated offsets.
6270+ CHECK (dst.offsets () == src_offsets);
62566271 xla::Array<Value> retiled (dst_tiles_shape);
62576272 VectorType vreg_x32 =
62586273 vty.getElementType ().isSignlessInteger ()
@@ -6263,19 +6278,29 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
62636278 SmallVector<Value, 8 > parts;
62646279 parts.reserve (packing);
62656280 SmallVector<int64_t > src_idx (idx.begin (), idx.end ());
6266- src_idx[src_idx.size () - 2 ] *= packing;
6267- src_idx[src_idx.size () - 1 ] /= packing;
6268- for (int i = 0 ; i < packing; ++i) {
6269- parts.push_back (builder.create <tpu::UnpackSubelementsOp>(
6270- loc, vreg_x32, vregs (src_idx), vreg_part,
6271- tpu::PackFormat::kCompressed ));
6272- if (src_idx[src_idx.size () - 2 ] <
6273- vregs.dim (vregs.num_dimensions () - 2 ) - 1 ) {
6274- ++src_idx[src_idx.size () - 2 ];
6281+ *(src_idx.end () - 1 ) /= packing;
6282+ if (!dst.offsets ()[0 ].has_value ()) {
6283+ *(src_idx.end () - 2 ) = 0 ;
6284+ // Make sure we set all parts of the output vreg to make it replicated
6285+ parts.append (packing, builder.create <tpu::UnpackSubelementsOp>(
6286+ loc, vreg_x32, vregs (src_idx), vreg_part,
6287+ tpu::PackFormat::kCompressed ));
6288+ } else {
6289+ *(src_idx.end () - 2 ) *= packing;
6290+ for (int i = 0 ; i < packing; ++i) {
6291+ if (*(src_idx.end () - 2 ) < *(vregs.dimensions ().end () - 2 )) {
6292+ parts.push_back (builder.create <tpu::UnpackSubelementsOp>(
6293+ loc, vreg_x32, vregs (src_idx), vreg_part,
6294+ tpu::PackFormat::kCompressed ));
6295+ ++*(src_idx.end () - 2 );
6296+ } else {
6297+ parts.push_back (nullptr );
6298+ }
62756299 }
62766300 }
62776301 *tile = builder.create <tpu::PackSubelementsOp>(
6278- loc, vregs.begin ()->getType (), parts, tpu::PackFormat::kCompressed );
6302+ loc, cast<VectorType>(vregs.begin ()->getType ()), parts,
6303+ tpu::PackFormat::kCompressed );
62796304 });
62806305 return std::pair (dst, std::move (retiled));
62816306 }
@@ -6334,6 +6359,12 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
63346359 // [(a b) (A B) (c d) (C D) ...]. That is, traverse down each column before
63356360 // moving to the next one. This is exactly an interleaving of the sublanes
63366361 // of the vreg parts.
6362+
6363+ // Note: The code below does not work when src is replicated and dst is
6364+ // not, since it relies on the src vreg array shape to know how many tiles
6365+ // to pack in dst, and vreg array shapes with materialized offsets are
6366+ // unfortunately not equal to vreg array shapes with replicated offsets.
6367+ CHECK (dst.offsets () == src.offsets ());
63376368 xla::Array<Value> retiled (dst_tiles_shape);
63386369 const VectorType vreg_x32 =
63396370 vty.getElementType ().isSignlessInteger ()
@@ -6343,20 +6374,30 @@ FailureOr<std::pair<VectorLayout, xla::Array<Value>>> changeTiling(
63436374 SmallVector<Value> parts;
63446375 parts.reserve (packing);
63456376 SmallVector<int64_t > src_idx (toArrayRef (idx));
6346- *(src_idx.end () - 2 ) *= packing;
63476377 const int64_t vreg_part = *(src_idx.end () - 1 ) % packing;
63486378 *(src_idx.end () - 1 ) /= packing;
6349- for (int i = 0 ; i < packing; ++i) {
6350- parts.push_back (builder.create <tpu::UnpackSubelementsOp>(
6351- loc, vreg_x32, vregs (src_idx), vreg_part,
6352- tpu::PackFormat::kCompressed ));
6353- if (*(src_idx.end () - 2 ) < *(vregs.dimensions ().end () - 2 ) - 1 ) {
6354- ++*(src_idx.end () - 2 );
6355- } // The rest is padding, so just pick any of the input parts (but not
6356- // an arbitrary vreg so we don't add an extra dependency).
6379+ if (!dst.offsets ()[0 ].has_value ()) {
6380+ *(src_idx.end () - 2 ) = 0 ;
6381+ // Make sure we set all parts of the output vreg to make it replicated
6382+ parts.append (packing, builder.create <tpu::UnpackSubelementsOp>(
6383+ loc, vreg_x32, vregs (src_idx), vreg_part,
6384+ tpu::PackFormat::kCompressed ));
6385+ } else {
6386+ *(src_idx.end () - 2 ) *= packing;
6387+ for (int i = 0 ; i < packing; ++i) {
6388+ if (*(src_idx.end () - 2 ) < *(vregs.dimensions ().end () - 2 )) {
6389+ parts.push_back (builder.create <tpu::UnpackSubelementsOp>(
6390+ loc, vreg_x32, vregs (src_idx), vreg_part,
6391+ tpu::PackFormat::kCompressed ));
6392+ ++*(src_idx.end () - 2 );
6393+ } else {
6394+ parts.push_back (nullptr );
6395+ }
6396+ }
63576397 }
63586398 *tile = builder.create <tpu::PackSubelementsOp>(
6359- loc, vregs.begin ()->getType (), parts, tpu::PackFormat::kInterleaved );
6399+ loc, cast<VectorType>(vregs.begin ()->getType ()), parts,
6400+ tpu::PackFormat::kInterleaved );
63606401 });
63616402 return std::pair (dst, std::move (retiled));
63626403 }
0 commit comments