Commit 3a5c4da

bythew3i authored and Google-ML-Automation committed
[Mosaic TPU] Support i32 vector multi reduction except cross lane.
PiperOrigin-RevId: 707708236
1 parent 6bcec91 commit 3a5c4da
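
In practice, this commit lets a Pallas TPU kernel reduce int32 data along a non-minormost axis (cross-lane reductions over the minormost axis remain unsupported). Below is a minimal sketch of the kind of kernel this enables, mirroring the test added in this commit; the kernel and variable names are illustrative, not part of the change itself.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
  # int32 sum over axis 0; reducing over the minormost (lane) axis
  # is still not supported for int32.
  o_ref[:] = jnp.sum(x_ref[:], axis=0, keepdims=True)

x = jnp.arange(2 * 16 * 128, dtype=jnp.int32).reshape(2, 16, 128)
out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((1, 16, 128), jnp.int32),
)(x)
```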

File tree

4 files changed: +91 lines, -22 lines


jax/_src/pallas/mosaic/lowering.py

Lines changed: 7 additions & 3 deletions
@@ -1459,10 +1459,14 @@ def _proxy_fun(val, *, axes):
       kind = type_to_kind[jnp.floating]
       val = type_to_identity[jnp.floating]
       val = ir.FloatAttr.get(aval_to_ir_type(x_aval, shape=()), val)
-    elif jnp.issubdtype(x_aval.dtype, jnp.signedinteger):
-      raise NotImplementedError("Reductions over integers not implemented.")
+    elif x_aval.dtype == jnp.int32:
+      kind = type_to_kind[jnp.signedinteger]
+      val = type_to_identity[jnp.signedinteger]
+      val = ir.IntegerAttr.get(ir.IntegerType.get_signless(32), val)
     elif jnp.issubdtype(x_aval.dtype, jnp.unsignedinteger):
-      raise NotImplementedError("Reductions over integers not implemented.")
+      raise NotImplementedError(
+          "Reductions over unsigned integers not implemented."
+      )
     else:
       raise NotImplementedError(
           f"Reductions over {x_aval.dtype} not implemented.")

jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 51 additions & 19 deletions
@@ -51,9 +51,11 @@
 #include "absl/log/log.h"
 #include "absl/status/status.h"
 #include "absl/types/span.h"
+#include "llvm/include/llvm/ADT/APInt.h"
 #include "mlir/include/mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/include/mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/include/mlir/IR/Attributes.h"
 #include "mlir/include/mlir/IR/Builders.h"
 #include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h"
 #include "mlir/include/mlir/IR/OperationSupport.h"
@@ -554,7 +556,7 @@ VectorType getNativeVregType(Type elem_ty,
 FailureOr<Value> maskOOB(RewriteContext &ctx, OpBuilder &builder,
                          TypedValue<VectorType> value,
                          const VRegDataBounds &bounds,
-                         const TypedAttr neutral) {
+                         const Attribute neutral) {
   auto native_vreg_ty =
       getNativeVregType(value.getType().getElementType(), ctx.target_shape);
   TPU_ASSERT_LOC(value.getLoc(), llvm::equal(value.getType().getShape(),
@@ -3926,6 +3928,7 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
   ImplicitLocOpBuilder builder(op.getLoc(), &op);
   auto multi_reduction_op = cast<vector::MultiDimReductionOp>(op);
   const VectorType src_ty = multi_reduction_op.getSourceVectorType();
+  auto element_type = src_ty.getElementType();
   int64_t src_rank = src_ty.getRank();
   const auto res_ty = dyn_cast<VectorType>(multi_reduction_op.getDestType());
   if (res_ty == nullptr) {
@@ -3953,44 +3956,56 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
     return multi_reduction_op.emitOpError(
         "Not implemented: Only constant accumulator supported");
   }
-  if (!src_ty.getElementType().isF32() && !src_ty.getElementType().isBF16()) {
+  if (!element_type.isF32() && !element_type.isBF16() &&
+      !element_type.isSignlessInteger((32))) {
     return multi_reduction_op.emitOpError(
-        "Not implemented: Only FP32 and BF16 reductions supported, but "
-        "got ")
-           << src_ty;
+        "Not implemented: unsupported element type");
   }
-  auto element_type = cast<FloatType>(src_ty.getElementType());
-  const auto acc_def_value = dyn_cast<DenseFPElementsAttr>(acc_def.getValue());
+  bool is_int = element_type.isSignlessInteger(32);
+  const auto acc_def_value = dyn_cast<DenseElementsAttr>(acc_def.getValue());
   if (acc_def_value == nullptr || !acc_def_value.isSplat()) {
     return multi_reduction_op.emitOpError("Expected a splat constant");
   }
   TPU_ASSERT_OP(acc_def_value.getElementType() == element_type);
-  const auto val = acc_def_value.getSplatValue<FloatAttr>();
-  FloatAttr neutral;
+  Attribute neutral;
   switch (multi_reduction_op.getKind()) {
     case vector::CombiningKind::ADD:
-      neutral = builder.getFloatAttr(element_type, 0);
+      neutral = builder.getZeroAttr(element_type);
       break;
     case vector::CombiningKind::MAXIMUMF: {
       // TODO(b/322836633): The semantics of maximumf don't match the lowering
       // for older TPU versions because older TPU versions don't respect the
       // -0.0 vs +0.0 ordering.
       neutral = builder.getFloatAttr(
-          element_type, APFloat::getInf(element_type.getFloatSemantics(),
-                                        /*Negative=*/true));
+          element_type,
+          APFloat::getInf(cast<FloatType>(element_type).getFloatSemantics(),
+                          /*Negative=*/true));
     } break;
     case vector::CombiningKind::MINIMUMF: {
       neutral = builder.getFloatAttr(
-          element_type, APFloat::getInf(element_type.getFloatSemantics(),
-                                        /*Negative=*/false));
+          element_type,
+          APFloat::getInf(cast<FloatType>(element_type).getFloatSemantics(),
+                          /*Negative=*/false));
+    } break;
+    case vector::CombiningKind::MAXSI: {
+      neutral = builder.getIntegerAttr(
+          element_type,
+          APInt::getSignedMinValue(element_type.getIntOrFloatBitWidth()));
+    } break;
+    case vector::CombiningKind::MINSI: {
+      neutral = builder.getIntegerAttr(
+          element_type,
+          APInt::getSignedMaxValue(element_type.getIntOrFloatBitWidth()));
     } break;
     default:
       return multi_reduction_op.emitOpError(
          "Not implemented: unsupported kind");
   }
-  if (val != neutral) {
+  if (auto val = acc_def_value.getSplatValue<Attribute>(); val != neutral) {
     return multi_reduction_op.emitOpError(
-        "Not implemented: Only neutral accumulator supported");
+        "Not implemented: Only neutral accumulator supported for "
+        "float reduction. Expected ")
+           << neutral << ", but got " << val;
   }
 
   std::array<bool, 2> reduces;
@@ -4074,9 +4089,11 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
       tpu_kind = tpu::ReductionKind::SUM;
       break;
     case vector::CombiningKind::MAXIMUMF:
+    case vector::CombiningKind::MAXSI:
       tpu_kind = tpu::ReductionKind::MAX;
       break;
     case vector::CombiningKind::MINIMUMF:
+    case vector::CombiningKind::MINSI:
       tpu_kind = tpu::ReductionKind::MIN;
       break;
     default:
@@ -4103,14 +4120,29 @@ LogicalResult vector_multi_reduction_rule(RewriteContext &ctx, Operation &op,
           src_vregs.Slice(src_slice_start, src_slice_end);
       std::optional<Value> acc_vreg;
       auto reduce_elementwise = [&](Value lhs, Value rhs) -> Value {
+        Value result;
         switch (tpu_kind) {
          case tpu::ReductionKind::SUM:
-            return builder.create<arith::AddFOp>(loc, lhs, rhs);
+            result =
+                is_int
+                    ? builder.create<arith::AddIOp>(loc, lhs, rhs).getResult()
+                    : builder.create<arith::AddFOp>(loc, lhs, rhs)
+                          .getResult();
+            break;
          case tpu::ReductionKind::MAX:
-            return builder.create<arith::MaximumFOp>(loc, lhs, rhs);
+            result = is_int ? builder.create<arith::MaxSIOp>(loc, lhs, rhs)
+                                  .getResult()
+                            : builder.create<arith::MaximumFOp>(loc, lhs, rhs)
+                                  .getResult();
+            break;
          case tpu::ReductionKind::MIN:
-            return builder.create<arith::MinimumFOp>(loc, lhs, rhs);
+            result = is_int ? builder.create<arith::MinSIOp>(loc, lhs, rhs)
+                                  .getResult()
+                            : builder.create<arith::MinimumFOp>(loc, lhs, rhs)
+                                  .getResult();
+            break;
         }
+        return result;
       };
       auto reduction_status = reduced_vregs.EachStatus(
           [&](const absl::Span<const int64_t> red_idx,

jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc

Lines changed: 7 additions & 0 deletions
@@ -382,7 +382,14 @@ LogicalResult canonicalize_multi_dim_reduction(int hardware_generation,
       op.erase();
     }
     return success();
+  } else if (element_type.isSignlessInteger(32) &&
+             // TODO(b/384774084): Add support for u32 reductions.
+             (op.getKind() == vector::CombiningKind::ADD ||
+              op.getKind() == vector::CombiningKind::MAXSI ||
+              op.getKind() == vector::CombiningKind::MINSI)) {
+    return success();
   }
+  op.emitOpError("Unsupported element type for the selected reduction");
   return failure();
 }

tests/pallas/tpu_ops_test.py

Lines changed: 26 additions & 0 deletions
@@ -282,6 +282,32 @@ def kernel(x_ref, mask_ref, o_ref):
     expected = jnp.where(mask, x, jnp.zeros_like(x))
     self.assertArraysEqual(out, expected)
 
+  @parameterized.product(
+      dtype = [jnp.float32, jnp.bfloat16, jnp.int32],
+      axis = [0, 1, 2],
+      reduce_func = [jnp.sum, jnp.max, jnp.min]
+  )
+  def test_reduction(self, dtype, axis, reduce_func):
+    if dtype == jnp.int32 and axis == 2:
+      self.skipTest("Int32 reduction on minor is not supported.")
+    # TODO(b/384127570): fix bfloat16 reduction.
+    if dtype == jnp.bfloat16 and reduce_func != jnp.sum:
+      self.skipTest("b/384127570")
+    in_shape = (2, 16, 128)
+    out_shape = list(in_shape)
+    out_shape[axis] = 1
+
+    def kernel(x, out):
+      out[:] = reduce_func(x[:], axis, keepdims=True)
+
+    x = jnp.arange(np.prod(in_shape), dtype=dtype).reshape(in_shape)
+    result = pl.pallas_call(
+        kernel,
+        out_shape=jax.ShapeDtypeStruct(out_shape, x.dtype),
+    )(x)
+    expected = reduce_func(x, axis, keepdims=True)
+    np.testing.assert_array_equal(result, expected)
+
 
 class OpsInterpretTest(OpsTest):
   INTERPRET = True
