Commit e88b578

ayaka14732 authored and Google-ML-Automation committed

[Pallas TPU] Add WeirdOp to TPU dialect and add lowering for lax.is_finite

PiperOrigin-RevId: 704888940
1 parent 3ca9f14 commit e88b578
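
For context: this commit makes lax.is_finite lowerable inside Pallas TPU kernels. A minimal sketch of what now works, adapted from the new test added below (we call pl.pallas_call directly where the test goes through its harness wrapper self.pallas_call, and omit grid/block details):

    import functools
    import jax
    import jax.numpy as jnp
    from jax import lax
    from jax.experimental import pallas as pl

    @functools.partial(
        pl.pallas_call,
        out_shape=jax.ShapeDtypeStruct((8,), jnp.bool_),
    )
    def kernel(x_ref, o_ref):
      o_ref[...] = lax.is_finite(x_ref[...])

    x = jnp.array([-0.2, jnp.inf, -jnp.inf, jnp.nan, 0.0, 1.0, -1.0, 0.5], jnp.float32)
    print(kernel(x))  # [ True False False False  True  True  True  True]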

4 files changed: 85 additions & 0 deletions

jax/_src/pallas/mosaic/lowering.py

Lines changed: 9 additions & 0 deletions

@@ -2385,6 +2385,15 @@ def _and_lowering_rule(ctx: LoweringRuleContext, x, y):
 skip_mlir_conversions.add(lax.and_p)
 
 
+def _is_finite_lowering_rule(ctx: LoweringRuleContext, x):
+  out_aval, = ctx.avals_out
+  out_type = aval_to_ir_type(out_aval)
+  return _not_lowering_rule(ctx, tpu.weird(out_type, x))
+
+
+lowering_rules[lax.is_finite_p] = _is_finite_lowering_rule
+
+
 def _or_lowering_rule(ctx: LoweringRuleContext, x, y):
   x, y = _bcast(x, y, *ctx.avals_in, *ctx.avals_out)
   return arith.ori(x, y)
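
Note the shape of the lowering: is_finite(x) becomes the logical negation of the new tpu.weird op, so tpu.weird is expected to be true exactly for NaN and +/-Inf. A reference model of that inferred semantics in plain jax.numpy (the name weird_reference is ours, purely for illustration):

    import jax.numpy as jnp

    def weird_reference(x):
      # tpu.weird(x) == not(is_finite(x)): True for NaN and +/-Inf, False otherwise.
      return jnp.logical_not(jnp.isfinite(x))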

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 7 additions & 0 deletions

@@ -473,6 +473,13 @@ def TPU_BitcastVregOp : TPU_Op<"bitcast_vreg", [Pure]> {
   let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
 }
 
+def TPU_WeirdOp : TPU_Op<"weird", [Pure, ElementwiseMappable]> {
+  let arguments = (ins AnyType:$input);    // F32 vector or scalar
+  let results = (outs AnyType:$output);    // I1 vector or scalar
+  let assemblyFormat = [{ $input attr-dict `:` type($input) `->` type($output) }];
+  let hasVerifier = 1;
+}
+
 def TPU_RollVectorsOp : TPU_Op<"roll_vectors", [Pure]> {
   let arguments = (ins Variadic<AnyVectorOfNonZeroRank>:$input);
   let results = (outs AnyVectorOfNonZeroRank:$output);
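
Given the declared assemblyFormat (identical in style to TPU_BitcastVregOp above), the op should print as, e.g., %1 = tpu.weird %0 : vector<8x128xf32> -> vector<8x128xi1>, with the shape here purely illustrative. The AnyType argument and result are deliberately loose; the actual F32-in/I1-out constraint is enforced by the custom verifier requested via hasVerifier and implemented in tpu_ops.cc below.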

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 26 additions & 0 deletions

@@ -1087,6 +1087,32 @@ LogicalResult LogOp::verify() {
       stringifyCoreType(logging_core_type_maybe->value())));
 }
 
+LogicalResult WeirdOp::verify() {
+  const mlir::Type in_type = getInput().getType();
+  if (const auto in_vec_type = dyn_cast<VectorType>(in_type)) {  // Vector case.
+    if (!in_vec_type.getElementType().isF32()) {
+      return emitOpError("Input type must be F32");
+    }
+    const mlir::Type out_type = getResult().getType();
+    const auto out_vec_type = dyn_cast<VectorType>(out_type);
+    if (!out_vec_type) {
+      return emitOpError("Output type must be a vector when input is a vector");
+    }
+    if (!out_vec_type.getElementType().isInteger(1)) {
+      return emitOpError("Output type must be I1");
+    }
+  } else {  // Scalar case.
+    if (!in_type.isF32()) {
+      return emitOpError("Input type must be F32");
+    }
+    const mlir::Type out_type = getResult().getType();
+    if (!out_type.isInteger(1)) {
+      return emitOpError("Output type must be I1 scalar");
+    }
+  }
+  return success();
+}
+
 }  // namespace tpu
 }  // namespace mlir
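
The two branches of the verifier thus admit exactly f32 -> i1, as scalars or as vectors; note that it checks element types only, so agreement between input and output vector shapes is presumably left to the ElementwiseMappable trait declared on the op in tpu.td.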

tests/pallas/ops_test.py

Lines changed: 43 additions & 0 deletions

@@ -757,6 +757,49 @@ def kernel(x_ref, o_ref):
     expected = lax.erf_inv(x)
     np.testing.assert_array_equal(out, expected)
 
+  IS_FINITE_TEST_VALUES = [
+      -0.2, jnp.inf, -jnp.inf, jnp.nan, 0.0, 1.0, -1.0, 0.5,
+  ]
+
+  def test_is_finite(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("Not supported on GPU")
+
+    size = len(self.IS_FINITE_TEST_VALUES)
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct((size,), jnp.bool_),
+    )
+    def kernel(x_ref, o_ref):
+      o_ref[...] = lax.is_finite(x_ref[...])
+
+    x = jnp.array(self.IS_FINITE_TEST_VALUES, dtype=jnp.float32)
+    out = kernel(x)
+    expected = lax.is_finite(x)
+    self.assertArraysEqual(out, expected)
+
+  def test_is_finite_scalar(self):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("Not supported on GPU")
+
+    size = len(self.IS_FINITE_TEST_VALUES)
+
+    @functools.partial(
+        self.pallas_call,
+        in_specs=(pl.BlockSpec(memory_space=smem_on_tpu()),),
+        out_specs=pl.BlockSpec(memory_space=smem_on_tpu()),
+        out_shape=jax.ShapeDtypeStruct((size,), jnp.bool_),
+    )
+    def kernel(x_ref, o_ref):
+      for i in range(8):
+        o_ref[i] = jnp.isfinite(x_ref[i])
+
+    x = jnp.array(self.IS_FINITE_TEST_VALUES, dtype=jnp.float32)
+    out = kernel(x)
+    expected = lax.is_finite(x)
+    self.assertArraysEqual(out, expected)
+
   ELEMENTWISE_OPS = [
       (
           [jnp.abs, jnp.negative],
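
Both new tests skip on GPU and cover the vector path and the scalar (SMEM) path respectively. Assuming a pytest-compatible setup on a TPU runner, they can be selected with something like: pytest tests/pallas/ops_test.py -k is_finite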
