Skip to content

Commit 9653ade

Browse files
committed
Implement JAX dispatch for TriGamma
1 parent a1fcb77 commit 9653ade

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

pytensor/link/jax/dispatch/scalar.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,17 @@
2020
Second,
2121
Sub,
2222
)
23-
from pytensor.scalar.math import Erf, Erfc, Erfcinv, Erfcx, Erfinv, Iv, Log1mexp, Psi
23+
from pytensor.scalar.math import (
24+
Erf,
25+
Erfc,
26+
Erfcinv,
27+
Erfcx,
28+
Erfinv,
29+
Iv,
30+
Log1mexp,
31+
Psi,
32+
TriGamma,
33+
)
2434

2535

2636
def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable:
@@ -275,6 +285,14 @@ def psi(x):
275285
return psi
276286

277287

288+
@jax_funcify.register(TriGamma)
289+
def jax_funcify_TriGamma(op, node, **kwargs):
290+
def tri_gamma(x):
291+
return jax.scipy.special.polygamma(1, x)
292+
293+
return tri_gamma
294+
295+
278296
@jax_funcify.register(Softplus)
279297
def jax_funcify_Softplus(op, **kwargs):
280298
def softplus(x):

tests/link/jax/test_scalar.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ def test_psi():
171171
compare_jax_and_py(fg, [3.0])
172172

173173

174+
def test_tri_gamma():
175+
x = vector("x", dtype="float64")
176+
out = tri_gamma(x)
177+
fg = FunctionGraph([x], [out])
178+
compare_jax_and_py(fg, [np.array([3.0, 5.0])])
179+
180+
174181
def test_log1mexp():
175182
x = vector("x")
176183
out = log1mexp(x)

0 commit comments

Comments
 (0)