Skip to content

Commit c5dc980

Browse files
cperivol authored and Google-ML-Automation committed
[mgpu/pallas_mgpu] Pointwise tanh support
PiperOrigin-RevId: 700158250
1 parent ef7df1a commit c5dc980

File tree

4 files changed

+21
-2
lines changed

4 files changed

+21
-2
lines changed

jax/_src/pallas/mosaic_gpu/lowering.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1179,6 +1179,12 @@ def _rsqrt_lowering_rule(ctx: LoweringRuleContext, x):
11791179
[x_aval] = ctx.avals_in
11801180
return _ensure_fa(x, x_aval.dtype).rsqrt(approx=ctx.module_ctx.approx_math)
11811181

1182+
@register_lowering_rule(lax.tanh_p)
def _tanh_lowering_rule(ctx: LoweringRuleContext, x):
  """Lowers ``lax.tanh_p`` to a pointwise tanh on the operand.

  Honors the module-wide ``approx_math`` setting: when enabled, the
  fragmented array emits the fast approximate tanh instead of the exact one.
  """
  [x_aval] = ctx.avals_in
  value = _ensure_fa(x, x_aval.dtype)
  return value.tanh(approx=ctx.module_ctx.approx_math)
11821188
@register_lowering_rule(lax.logistic_p)
11831189
def _logistic_lowering_rule(ctx: LoweringRuleContext, x):
11841190
[x_aval] = ctx.avals_in

jax/experimental/mosaic/gpu/fragmented_array.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -918,6 +918,15 @@ def cos(self, *, approx: bool = False):
918918
self._lift_fast_unary("cos.approx.f32") if approx else mlir_math.cos
919919
)
920920

921+
def tanh(self, *, approx: bool = False):
  """Elementwise hyperbolic tangent.

  Args:
    approx: when True, lower to the fast PTX ``tanh.approx.f32``
      instruction; only supported for f32 operands.

  Raises:
    NotImplementedError: if the dtype is not a float type, or if
      ``approx`` is requested for a dtype other than f32.
  """
  if not ir.FloatType.isinstance(self.mlir_dtype):
    raise NotImplementedError
  if approx and self.mlir_dtype != ir.F32Type.get():
    raise NotImplementedError
  op = self._lift_fast_unary("tanh.approx.f32") if approx else mlir_math.tanh
  return self._pointwise(op)
929+
921930
def rsqrt(self, *, approx: bool = False):
922931
if not ir.FloatType.isinstance(self.mlir_dtype):
923932
raise NotImplementedError

tests/mosaic/gpu_test.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,11 +1354,14 @@ def kernel(ctx, dst, _):
13541354
ops=(
13551355
(lambda x: -x, jax.lax.neg),
13561356
(lambda x: x + 42, lambda x: x + 42),
1357+
(lambda x: x.tanh(), jax.lax.tanh),
13571358
),
13581359
dtype=[jnp.float32, jnp.int32, jnp.uint32],
13591360
)
13601361
def test_unary(self, ops, dtype, m=64, n=32):
13611362
op, np_op = ops
1363+
if np_op is jax.lax.tanh and jnp.issubdtype(dtype, jnp.integer):
1364+
raise self.skipTest("Tanh not supported for integer types")
13621365

13631366
def kernel(ctx, dst, _):
13641367
iota = iota_tensor(m, n, dtype)

tests/pallas/mosaic_gpu_test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,9 @@ class PallasCallTest(PallasTest):
7070
("exp", jax.lax.exp),
7171
("square", lambda x: x ** 2),
7272
("rsqrt", jax.lax.rsqrt),
73+
("tanh", jax.lax.tanh, 1e-6),
7374
)
74-
def test_unary_ops(self, unary):
75+
def test_unary_ops(self, unary, rtol=1e-7):
7576
@functools.partial(
7677
pl.pallas_call,
7778
out_shape=jax.ShapeDtypeStruct([256], jnp.float32),
@@ -80,7 +81,7 @@ def kernel(x_ref, o_ref):
8081
o_ref[...] = unary(x_ref[...])
8182

8283
x = jnp.arange(256).astype(jnp.float32)
83-
np.testing.assert_array_equal(kernel(x), unary(x))
84+
np.testing.assert_allclose(kernel(x), unary(x), rtol=rtol)
8485

8586
def test_add_first(self):
8687
@functools.partial(

0 commit comments

Comments (0)