diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py index 8a33dfac13..9af17b0a68 100644 --- a/pytensor/link/jax/dispatch/random.py +++ b/pytensor/link/jax/dispatch/random.py @@ -105,14 +105,24 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs): assert_size_argument_jax_compatible(node) def sample_fn(rng, size, *parameters): - return jax_sample_fn(op, node=node)(rng, size, out_dtype, *parameters) + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + rng["jax_state"] = rng_key + sample = jax_sample_fn(op, node=node)( + sampling_key, size, out_dtype, *parameters + ) + return (rng, sample) else: def sample_fn(rng, size, *parameters): - return jax_sample_fn(op, node=node)( - rng, static_size, out_dtype, *parameters + rng_key = rng["jax_state"] + rng_key, sampling_key = jax.random.split(rng_key, 2) + rng["jax_state"] = rng_key + sample = jax_sample_fn(op, node=node)( + sampling_key, static_size, out_dtype, *parameters ) + return (rng, sample) return sample_fn @@ -133,12 +143,9 @@ def jax_sample_fn_generic(op, node): name = op.name jax_op = getattr(jax.random, name) - def sample_fn(rng, size, dtype, *parameters): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) - sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype) - rng["jax_state"] = rng_key - return (rng, sample) + def sample_fn(rng_key, size, dtype, *parameters): + sample = jax_op(rng_key, *parameters, shape=size, dtype=dtype) + return sample return sample_fn @@ -159,29 +166,23 @@ def jax_sample_fn_loc_scale(op, node): name = op.name jax_op = getattr(jax.random, name) - def sample_fn(rng, size, dtype, *parameters): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) + def sample_fn(rng_key, size, dtype, *parameters): loc, scale = parameters if size is None: size = jax.numpy.broadcast_arrays(loc, scale)[0].shape - sample = loc + jax_op(sampling_key, size, dtype) * scale - rng["jax_state"] = rng_key - return (rng, sample) + sample = loc + jax_op(rng_key, size, dtype) * scale + return sample return sample_fn @jax_sample_fn.register(ptr.MvNormalRV) def jax_sample_mvnormal(op, node): - def sample_fn(rng, size, dtype, mean, cov): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) + def sample_fn(rng_key, size, dtype, mean, cov): sample = jax.random.multivariate_normal( - sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method + rng_key, mean, cov, shape=size, dtype=dtype, method=op.method ) - rng["jax_state"] = rng_key - return (rng, sample) + return sample return sample_fn @@ -191,12 +192,9 @@ def jax_sample_fn_bernoulli(op, node): """JAX implementation of `BernoulliRV`.""" # We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX - def sample_fn(rng, size, dtype, p): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) - sample = jax.random.bernoulli(sampling_key, p, shape=size) - rng["jax_state"] = rng_key - return (rng, sample) + def sample_fn(rng_key, size, dtype, p): + sample = jax.random.bernoulli(rng_key, p, shape=size) + return sample return sample_fn @@ -206,14 +204,10 @@ def jax_sample_fn_categorical(op, node): """JAX implementation of `CategoricalRV`.""" # We need a separate dispatch because Categorical expects logits in JAX - def sample_fn(rng, size, dtype, p): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) - + def sample_fn(rng_key, size, dtype, p): logits = jax.scipy.special.logit(p) - sample = jax.random.categorical(sampling_key, logits=logits, shape=size) - rng["jax_state"] = rng_key - return (rng, sample) + sample = jax.random.categorical(rng_key, logits=logits, shape=size) + return sample return sample_fn @@ -233,15 +227,10 @@ def jax_sample_fn_uniform(op, node): name = "randint" jax_op = getattr(jax.random, name) - def sample_fn(rng, size, dtype, *parameters): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) + def sample_fn(rng_key, size, dtype, *parameters): minval, maxval = parameters - sample = jax_op( - sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval - ) - rng["jax_state"] = rng_key - return (rng, sample) + sample = jax_op(rng_key, shape=size, dtype=dtype, minval=minval, maxval=maxval) + return sample return sample_fn @@ -258,14 +247,11 @@ def jax_sample_fn_shape_scale(op, node): name = op.name jax_op = getattr(jax.random, name) - def sample_fn(rng, size, dtype, shape, scale): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) + def sample_fn(rng_key, size, dtype, shape, scale): if size is None: size = jax.numpy.broadcast_arrays(shape, scale)[0].shape - sample = jax_op(sampling_key, shape, size, dtype) * scale - rng["jax_state"] = rng_key - return (rng, sample) + sample = jax_op(rng_key, shape, size, dtype) * scale + return sample return sample_fn @@ -274,14 +260,11 @@ def sample_fn(rng, size, dtype, shape, scale): def jax_sample_fn_exponential(op, node): """JAX implementation of `ExponentialRV`.""" - def sample_fn(rng, size, dtype, scale): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) + def sample_fn(rng_key, size, dtype, scale): if size is None: size = jax.numpy.asarray(scale).shape - sample = jax.random.exponential(sampling_key, size, dtype) * scale - rng["jax_state"] = rng_key - return (rng, sample) + sample = jax.random.exponential(rng_key, size, dtype) * scale + return sample return sample_fn @@ -290,14 +273,11 @@ def sample_fn(rng, size, dtype, scale): def jax_sample_fn_t(op, node): """JAX implementation of `StudentTRV`.""" - def sample_fn(rng, size, dtype, df, loc, scale): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) + def sample_fn(rng_key, size, dtype, df, loc, scale): if size is None: size = jax.numpy.broadcast_arrays(df, loc, scale)[0].shape - sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale - rng["jax_state"] = rng_key - return (rng, sample) + sample = loc + jax.random.t(rng_key, df, size, dtype) * scale + return sample return sample_fn @@ -315,10 +295,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node): "A default JAX rewrite should have materialized the implicit arange" ) - def sample_fn(rng, size, dtype, *parameters): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) - + def sample_fn(rng_key, size, dtype, *parameters): if op.has_p_param: a, p, core_shape = parameters else: @@ -327,9 +304,7 @@ def sample_fn(rng, size, dtype, *parameters): core_shape = tuple(np.asarray(core_shape)[(0,) * batch_ndim]) if batch_ndim == 0: - sample = jax.random.choice( - sampling_key, a, shape=core_shape, replace=False, p=p - ) + sample = jax.random.choice(rng_key, a, shape=core_shape, replace=False, p=p) else: if size is None: @@ -345,7 +320,7 @@ def sample_fn(rng, size, dtype, *parameters): if p is not None: p = jax.numpy.broadcast_to(p, size + p.shape[batch_ndim:]) - batch_sampling_keys = jax.random.split(sampling_key, np.prod(size)) + batch_sampling_keys = jax.random.split(rng_key, np.prod(size)) # Ravel the batch dimensions because vmap only works along a single axis raveled_batch_a = a.reshape((-1,) + a.shape[batch_ndim:]) @@ -366,8 +341,7 @@ def sample_fn(rng, size, dtype, *parameters): # Reshape the batch dimensions sample = raveled_sample.reshape(size + raveled_sample.shape[1:]) - rng["jax_state"] = rng_key - return (rng, sample) + return sample return sample_fn @@ -378,9 +352,7 @@ def jax_sample_fn_permutation(op, node): batch_ndim = op.batch_ndim(node) - def sample_fn(rng, size, dtype, *parameters): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) + def sample_fn(rng_key, size, dtype, *parameters): (x,) = parameters if batch_ndim: # jax.random.permutation has no concept of batch dims @@ -389,17 +361,16 @@ def sample_fn(rng, size, dtype, *parameters): else: x = jax.numpy.broadcast_to(x, size + x.shape[batch_ndim:]) - batch_sampling_keys = jax.random.split(sampling_key, np.prod(size)) + batch_sampling_keys = jax.random.split(rng_key, np.prod(size)) raveled_batch_x = x.reshape((-1,) + x.shape[batch_ndim:]) raveled_sample = jax.vmap(lambda key, x: jax.random.permutation(key, x))( batch_sampling_keys, raveled_batch_x ) sample = raveled_sample.reshape(size + raveled_sample.shape[1:]) else: - sample = jax.random.permutation(sampling_key, x) + sample = jax.random.permutation(rng_key, x) - rng["jax_state"] = rng_key - return (rng, sample) + return sample return sample_fn @@ -414,15 +385,9 @@ def jax_sample_fn_binomial(op, node): from numpyro.distributions.util import binomial - def sample_fn(rng, size, dtype, n, p): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) - - sample = binomial(key=sampling_key, n=n, p=p, shape=size) - - rng["jax_state"] = rng_key - - return (rng, sample) + def sample_fn(rng_key, size, dtype, n, p): + sample = binomial(key=rng_key, n=n, p=p, shape=size) + return sample return sample_fn @@ -437,15 +402,9 @@ def jax_sample_fn_multinomial(op, node): from numpyro.distributions.util import multinomial - def sample_fn(rng, size, dtype, n, p): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) - - sample = multinomial(key=sampling_key, n=n, p=p, shape=size) - - rng["jax_state"] = rng_key - - return (rng, sample) + def sample_fn(rng_key, size, dtype, n, p): + sample = multinomial(key=rng_key, n=n, p=p, shape=size) + return sample return sample_fn @@ -460,17 +419,12 @@ def jax_sample_fn_vonmises(op, node): from numpyro.distributions.util import von_mises_centered - def sample_fn(rng, size, dtype, mu, kappa): - rng_key = rng["jax_state"] - rng_key, sampling_key = jax.random.split(rng_key, 2) - + def sample_fn(rng_key, size, dtype, mu, kappa): sample = von_mises_centered( - key=sampling_key, concentration=kappa, shape=size, dtype=dtype + key=rng_key, concentration=kappa, shape=size, dtype=dtype ) sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi - rng["jax_state"] = rng_key - - return (rng, sample) + return sample return sample_fn diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py index 2a6ebca0af..fb2f6d9bb9 100644 --- a/tests/link/jax/test_random.py +++ b/tests/link/jax/test_random.py @@ -796,7 +796,7 @@ def rng_fn(cls, rng, size): @jax_sample_fn.register(CustomRV) def jax_sample_fn_custom(op, node): def sample_fn(rng, size, dtype, *parameters): - return (rng, 0) + return 0 return sample_fn