Skip to content

Commit 6edfe9e

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic TPU] Add support for bf16 second minor reductions in TPUv6
PiperOrigin-RevId: 707557416
1 parent bcca77c commit 6edfe9e

File tree

3 files changed

+50
-17
lines changed

3 files changed

+50
-17
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def ir_constant(x, mlir_type=None):
241241
x = np.array(x, np.float32)
242242
if not mlir_type:
243243
mlir_type = _dtype_to_ir_type(x.dtype)
244-
if isinstance(x, int) or x.dtype in (np.int32, np.uint32, np.int8):
244+
if isinstance(x, int) or np.issubdtype(x.dtype, np.integer):
245245
return arith.constant(mlir_type, ir.IntegerAttr.get(mlir_type, int(x)))
246246
elif isinstance(x, float) or x.dtype == np.float32:
247247
return arith.constant(mlir_type, ir.FloatAttr.get(mlir_type, float(x)))
@@ -1458,7 +1458,7 @@ def _proxy_fun(val, *, axes):
14581458
if jnp.issubdtype(x_aval.dtype, jnp.floating):
14591459
kind = type_to_kind[jnp.floating]
14601460
val = type_to_identity[jnp.floating]
1461-
val = ir.FloatAttr.get(ir.F32Type.get(), val)
1461+
val = ir.FloatAttr.get(aval_to_ir_type(x_aval, shape=()), val)
14621462
elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
14631463
raise NotImplementedError("Reductions over integers not implemented.")
14641464
elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3919,6 +3919,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
39193919
TPU_ASSERT_EQ_OP(layouts_out.size(), 1);
39203920
TPU_ASSERT_OP(
39213921
llvm::all_of(layouts_in, [&](const Layout &l) { return l.has_value(); }));
3922+
const Location loc = op.getLoc();
39223923
const VectorLayout &src_layout = *layouts_in[0];
39233924
const VectorLayout &acc_layout = *layouts_in[1];
39243925
const VectorLayout &dst_layout = *layouts_out[0];
@@ -4101,6 +4102,16 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
41014102
xla::Array<Value> reduced_vregs =
41024103
src_vregs.Slice(src_slice_start, src_slice_end);
41034104
std::optional<Value> acc_vreg;
4105+
auto reduce_elementwise = [&](Value lhs, Value rhs) -> Value {
4106+
switch (tpu_kind) {
4107+
case tpu::ReductionKind::SUM:
4108+
return builder.create<arith::AddFOp>(loc, lhs, rhs);
4109+
case tpu::ReductionKind::MAX:
4110+
return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
4111+
case tpu::ReductionKind::MIN:
4112+
return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
4113+
}
4114+
};
41044115
auto reduction_status = reduced_vregs.EachStatus(
41054116
[&](const absl::Span<const int64_t> red_idx,
41064117
Value *const src_vreg) {
@@ -4130,20 +4141,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
41304141
if (!acc_vreg.has_value()) {
41314142
acc_vreg = vreg;
41324143
} else {
4133-
switch (tpu_kind) {
4134-
case tpu::ReductionKind::SUM:
4135-
acc_vreg = builder.create<arith::AddFOp>(vreg.getLoc(),
4136-
*acc_vreg, vreg);
4137-
break;
4138-
case tpu::ReductionKind::MAX:
4139-
acc_vreg = builder.create<arith::MaximumFOp>(
4140-
vreg.getLoc(), *acc_vreg, vreg);
4141-
break;
4142-
case tpu::ReductionKind::MIN:
4143-
acc_vreg = builder.create<arith::MinimumFOp>(
4144-
vreg.getLoc(), *acc_vreg, vreg);
4145-
break;
4146-
}
4144+
acc_vreg = reduce_elementwise(*acc_vreg, vreg);
41474145
}
41484146
return absl::OkStatus();
41494147
});
@@ -4156,8 +4154,43 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
41564154
multi_reduction_op->getLoc(), *acc_vreg, 1, tpu_kind);
41574155
}
41584156
if (reduces[0]) {
4157+
// Packed types are compressed along rows, so we need to reduce them
4158+
// within each 32-bit word. There's no performance penalty for doing
4159+
// this in 32-bit precision, so we take advantage of it.
4160+
Type acc_vreg_ty = acc_vreg->getType();
4161+
if (acc_layout.packing() > 1) {
4162+
Type vreg_ty_32 = nullptr;
4163+
if (acc.getType().getElementType().isBF16()) {
4164+
vreg_ty_32 =
4165+
getNativeVregType(builder.getF32Type(), ctx.target_shape);
4166+
} else {
4167+
multi_reduction_op.emitOpError(
4168+
"Not implemented: Unsupported reduction dtype");
4169+
return absl::UnknownError("");
4170+
}
4171+
Value acc_vreg_32 = builder.create<tpu::UnpackSubelementsOp>(
4172+
loc, vreg_ty_32, *acc_vreg, 0, tpu::PackFormat::kInterleaved);
4173+
for (int i = 1; i < acc_layout.packing(); ++i) {
4174+
Value acc_vreg_part_32 = builder.create<tpu::UnpackSubelementsOp>(
4175+
loc, vreg_ty_32, *acc_vreg, i, tpu::PackFormat::kInterleaved);
4176+
acc_vreg_32 = reduce_elementwise(acc_vreg_32, acc_vreg_part_32);
4177+
}
4178+
acc_vreg = acc_vreg_32;
4179+
}
4180+
// At this point acc_vreg is always 32-bit.
41594181
acc_vreg = builder.create<tpu::AllReduceOp>(
41604182
multi_reduction_op->getLoc(), *acc_vreg, 0, tpu_kind);
4183+
// We pack the final result back into the original type.
4184+
if (acc_layout.packing() > 1) {
4185+
SmallVector<int32_t> positions(acc_layout.packing());
4186+
std::iota(positions.begin(), positions.end(),
4187+
static_cast<int32_t>(0));
4188+
SmallVector<Value> parts(acc_layout.packing(), *acc_vreg);
4189+
acc_vreg = builder.create<tpu::PackSubelementsOp>(
4190+
loc, acc_vreg_ty, parts,
4191+
builder.getDenseI32ArrayAttr(positions),
4192+
tpu::PackFormat::kInterleaved);
4193+
}
41614194
}
41624195
*dst_vreg = *acc_vreg;
41634196
return absl::OkStatus();

jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
353353
reduces_sublanes = true;
354354
}
355355
}
356-
if (hardware_generation <= 5 || reduces_sublanes) {
356+
if (hardware_generation <= 5) {
357357
auto new_source = builder.create<arith::ExtFOp>(
358358
VectorType::get(source_ty.getShape(), builder.getF32Type()),
359359
op.getSource());

0 commit comments

Comments
 (0)