Skip to content

Commit b37b6c0

Browse files
naummoGoogle-ML-Automation
authored andcommitted
[Mosaic] Enable jnp.exp lowering on SparseCore
PiperOrigin-RevId: 834500194
1 parent 0b7818b commit b37b6c0

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2834,7 +2834,7 @@ def _square_lowering_rule(ctx: LoweringRuleContext, x):
28342834
return arith.mulf(x, x)
28352835

28362836

2837-
@register_lowering_rule(lax.exp_p)
2837+
@register_lowering_rule(lax.exp_p, kernel_types=[*tpu_core.KernelType])
28382838
def _exp_lowering_rule(ctx: LoweringRuleContext, x, accuracy):
28392839
if accuracy is not None:
28402840
raise NotImplementedError("Not implemented: accuracy")

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,24 @@ def kernel(in_ref, o_ref, scratch_ref):
14951495

14961496
np.testing.assert_array_equal(f(x), x)
14971497

1498+
def test_exp(self):
1499+
if not jtu.if_cloud_tpu_at_least(2025, 11, 21):
1500+
self.skipTest("Test requires a newer libtpu")
1501+
1502+
x = jnp.arange(8, dtype=jnp.float32)
1503+
1504+
def sc_exp_kernel(x_hbm_ref, out_ref):
1505+
out_ref[...] = jnp.exp(x_hbm_ref[...])
1506+
1507+
result = pl.pallas_call(
1508+
sc_exp_kernel,
1509+
compiler_params=pltpu.CompilerParams(
1510+
kernel_type=pltpu.KernelType.SC_VECTOR_SUBCORE
1511+
),
1512+
out_shape=x,
1513+
)(x)
1514+
np.testing.assert_array_equal(result, jnp.exp(x))
1515+
14981516

14991517
class VectorSubcoreTestWithTCTiling(TCTilingMixin, VectorSubcoreTest):
15001518
pass

0 commit comments

Comments
 (0)