Skip to content

Commit c60bafc

Browse files
ayaka14732Google-ML-Automation
authored andcommitted
[Pallas TPU] Fix lowering for jnp.remainder
Fixes jax-ml#24027 PiperOrigin-RevId: 688614799
1 parent 2b7b074 commit c60bafc

File tree

2 files changed

+9
-16
lines changed

2 files changed

+9
-16
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1945,11 +1945,11 @@ def _div_lowering_rule(ctx: LoweringRuleContext, x, y):
19451945
def _rem_lowering_rule(ctx: LoweringRuleContext, x, y):
19461946
x, y = _bcast(x, y, ctx.avals_in[0], ctx.avals_in[1], ctx.avals_out[0])
19471947
(aval_out,) = ctx.avals_out
1948-
if jnp.issubdtype(aval_out.dtype, jnp.integer):
1948+
if jnp.issubdtype(aval_out.dtype, jnp.signedinteger):
19491949
return arith.remsi(x, y)
19501950
if jnp.issubdtype(aval_out.dtype, jnp.unsignedinteger):
19511951
return arith.remui(x, y)
1952-
elif jnp.issubdtype(aval_out.dtype, jnp.floating):
1952+
if jnp.issubdtype(aval_out.dtype, jnp.floating):
19531953
return arith.remf(x, y)
19541954
raise NotImplementedError(aval_out.dtype)
19551955

tests/pallas/ops_test.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,16 +1056,12 @@ def test_binary(self, f, dtype):
10561056
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
10571057
self.skipTest("16-bit types are not supported on TPU")
10581058

1059-
# TODO: skipped due to https://github.com/jax-ml/jax/issues/24027
1059+
# TODO(ayx): Fix these operations on TPU
10601060
if (
10611061
jtu.test_device_matches(["tpu"])
1062-
and f == jnp.remainder
1063-
and not self.INTERPRET
1062+
and f in (jnp.floor_divide, jnp.subtract)
1063+
and dtype == "uint32"
10641064
):
1065-
self.skipTest("jnp.remainder on TPU is only supported in interpret mode")
1066-
1067-
# TODO(ayx): fix this on TPU
1068-
if jtu.test_device_matches(["tpu"]) and dtype == "uint32":
10691065
self.skipTest("Not supported on TPU")
10701066

10711067
@functools.partial(
@@ -1092,16 +1088,13 @@ def test_binary_scalar(self, f, dtype):
10921088
self.skipTest("Test only supported on TPU.")
10931089
if jtu.test_device_matches(["tpu"]) and jnp.dtype(dtype).itemsize == 2:
10941090
self.skipTest("16-bit types are not supported on TPU")
1095-
# TODO: skipped due to https://github.com/jax-ml/jax/issues/24027
1091+
1092+
# TODO(ayx): Fix these operations on TPU
10961093
if (
10971094
jtu.test_device_matches(["tpu"])
1098-
and f == jnp.remainder
1099-
and not self.INTERPRET
1095+
and f in (jnp.floor_divide, jnp.subtract)
1096+
and dtype == "uint32"
11001097
):
1101-
self.skipTest("jnp.remainder on TPU is only supported in interpret mode")
1102-
1103-
# TODO: skipped due to https://github.com/jax-ml/jax/issues/23972
1104-
if jtu.test_device_matches(["tpu"]) and dtype == "uint32":
11051098
self.skipTest("Not supported on TPU")
11061099

11071100
@functools.partial(

0 commit comments

Comments
 (0)