Skip to content

Commit fc0452f

Browse files
rlouftwiecki
authored andcommitted
Add StudentTRV JAX implementation
1 parent 383d4ef commit fc0452f

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

pytensor/link/jax/dispatch/random.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,24 @@ def sample_fn(rng, size, dtype, *parameters):
208208
return sample_fn
209209

210210

211+
@jax_sample_fn.register(aer.StudentTRV)
212+
def jax_sample_fn_t(op):
213+
"""JAX implementation of `StudentTRV`."""
214+
215+
def sample_fn(rng, size, dtype, *parameters):
216+
rng_key = rng["jax_state"]
217+
(
218+
df,
219+
loc,
220+
scale,
221+
) = parameters
222+
sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
223+
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
224+
return (rng, sample)
225+
226+
return sample_fn
227+
228+
211229
@jax_sample_fn.register(aer.ChoiceRV)
212230
def jax_funcify_choice(op):
213231
"""JAX implementation of `ChoiceRV`."""

tests/link/jax/test_random.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,26 @@ def test_random_updates(rng_ctor):
205205
"randint",
206206
lambda *args: args,
207207
),
208+
(
209+
aer.t,
210+
[
211+
set_test_value(
212+
at.dscalar(),
213+
np.array(2.0, dtype=np.float64),
214+
),
215+
set_test_value(
216+
at.dvector(),
217+
np.array([1.0, 2.0], dtype=np.float64),
218+
),
219+
set_test_value(
220+
at.dscalar(),
221+
np.array(1.0, dtype=np.float64),
222+
),
223+
],
224+
(2,),
225+
"t",
226+
lambda *args: args,
227+
),
208228
(
209229
aer.uniform,
210230
[

0 commit comments

Comments
 (0)