@@ -3919,6 +3919,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
39193919 TPU_ASSERT_EQ_OP (layouts_out.size (), 1 );
39203920 TPU_ASSERT_OP (
39213921 llvm::all_of (layouts_in, [&](const Layout &l) { return l.has_value (); }));
3922+ const Location loc = op.getLoc ();
39223923 const VectorLayout &src_layout = *layouts_in[0 ];
39233924 const VectorLayout &acc_layout = *layouts_in[1 ];
39243925 const VectorLayout &dst_layout = *layouts_out[0 ];
@@ -4101,6 +4102,16 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
41014102 xla::Array<Value> reduced_vregs =
41024103 src_vregs.Slice (src_slice_start, src_slice_end);
41034104 std::optional<Value> acc_vreg;
// Local helper: combines two values with the elementwise arith op selected by
// `tpu_kind` and returns the newly created op's result.
//   SUM -> arith::AddFOp
//   MAX -> arith::MaximumFOp
//   MIN -> arith::MinimumFOp
// Every op is built through `builder` at the shared location `loc`. The lambda
// captures by reference, so it reads `tpu_kind`, `builder` and `loc` from the
// enclosing scope at call time.
// NOTE(review): there is no `default` case and no return after the switch. If
// `tpu_kind` can hold a value other than the three listed here, control would
// fall off the end of a value-returning lambda. Presumably the kind is
// validated earlier in the rule; confirm.
4105+ auto reduce_elementwise = [&](Value lhs, Value rhs) -> Value {
4106+ switch (tpu_kind) {
4107+ case tpu::ReductionKind::SUM:
4108+ return builder.create <arith::AddFOp>(loc, lhs, rhs);
4109+ case tpu::ReductionKind::MAX:
4110+ return builder.create <arith::MaximumFOp>(loc, lhs, rhs);
4111+ case tpu::ReductionKind::MIN:
4112+ return builder.create <arith::MinimumFOp>(loc, lhs, rhs);
4113+ }
4114+ };
41044115 auto reduction_status = reduced_vregs.EachStatus (
41054116 [&](const absl::Span<const int64_t > red_idx,
41064117 Value *const src_vreg) {
@@ -4130,20 +4141,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
41304141 if (!acc_vreg.has_value ()) {
41314142 acc_vreg = vreg;
41324143 } else {
4133- switch (tpu_kind) {
4134- case tpu::ReductionKind::SUM:
4135- acc_vreg = builder.create <arith::AddFOp>(vreg.getLoc (),
4136- *acc_vreg, vreg);
4137- break ;
4138- case tpu::ReductionKind::MAX:
4139- acc_vreg = builder.create <arith::MaximumFOp>(
4140- vreg.getLoc (), *acc_vreg, vreg);
4141- break ;
4142- case tpu::ReductionKind::MIN:
4143- acc_vreg = builder.create <arith::MinimumFOp>(
4144- vreg.getLoc (), *acc_vreg, vreg);
4145- break ;
4146- }
4144+ acc_vreg = reduce_elementwise (*acc_vreg, vreg);
41474145 }
41484146 return absl::OkStatus ();
41494147 });
@@ -4156,8 +4154,43 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
41564154 multi_reduction_op->getLoc (), *acc_vreg, 1 , tpu_kind);
41574155 }
41584156 if (reduces[0 ]) {
4157+ // Packed types are compressed along rows, so we need to reduce them
4158+ // within each 32-bit word. There's no performance penalty for doing
4159+ // this in 32-bit precision, so we take advantage of it.
4160+ Type acc_vreg_ty = acc_vreg->getType ();
4161+ if (acc_layout.packing () > 1 ) {
4162+ Type vreg_ty_32 = nullptr ;
4163+ if (acc.getType ().getElementType ().isBF16 ()) {
4164+ vreg_ty_32 =
4165+ getNativeVregType (builder.getF32Type (), ctx.target_shape );
4166+ } else {
4167+ multi_reduction_op.emitOpError (
4168+ " Not implemented: Unsupported reduction dtype" );
4169+ return absl::UnknownError (" " );
4170+ }
4171+ Value acc_vreg_32 = builder.create <tpu::UnpackSubelementsOp>(
4172+ loc, vreg_ty_32, *acc_vreg, 0 , tpu::PackFormat::kInterleaved );
4173+ for (int i = 1 ; i < acc_layout.packing (); ++i) {
4174+ Value acc_vreg_part_32 = builder.create <tpu::UnpackSubelementsOp>(
4175+ loc, vreg_ty_32, *acc_vreg, i, tpu::PackFormat::kInterleaved );
4176+ acc_vreg_32 = reduce_elementwise (acc_vreg_32, acc_vreg_part_32);
4177+ }
4178+ acc_vreg = acc_vreg_32;
4179+ }
4180+ // At this point acc_vreg is always 32-bit.
41594181 acc_vreg = builder.create <tpu::AllReduceOp>(
41604182 multi_reduction_op->getLoc (), *acc_vreg, 0 , tpu_kind);
4183+ // We pack the final result back into the original type.
4184+ if (acc_layout.packing () > 1 ) {
4185+ SmallVector<int32_t > positions (acc_layout.packing ());
4186+ std::iota (positions.begin (), positions.end (),
4187+ static_cast <int32_t >(0 ));
4188+ SmallVector<Value> parts (acc_layout.packing (), *acc_vreg);
4189+ acc_vreg = builder.create <tpu::PackSubelementsOp>(
4190+ loc, acc_vreg_ty, parts,
4191+ builder.getDenseI32ArrayAttr (positions),
4192+ tpu::PackFormat::kInterleaved );
4193+ }
41614194 }
41624195 *dst_vreg = *acc_vreg;
41634196 return absl::OkStatus ();
0 commit comments