Skip to content

chore: replace legacy jax.random.PRNGKey with modern jax.random.key#2134

Merged
fehiepsi merged 10 commits intopyro-ppl:masterfrom
Qazalbash:refactor-prng-key
Feb 15, 2026
Merged

chore: replace legacy jax.random.PRNGKey with modern jax.random.key#2134
fehiepsi merged 10 commits intopyro-ppl:masterfrom
Qazalbash:refactor-prng-key

Conversation

@Qazalbash
Copy link
Collaborator

As per the official JAX docs, jax.random.PRNGKey is a legacy API and should be replaced with jax.random.key wherever possible 1. This PR replaces calls to jax.random.PRNGKey with jax.random.key, accompanied by necessary documentation.

I have replaced "PRNGKey" with "PRNG key" in the documentation to distinguish PRNG key as a concept from the function jax.random.PRNGKey.

Footnotes

  1. See note in https://docs.jax.dev/en/latest/jax.random.html#prng-keys

@review-notebook-app
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome!!

@Qazalbash
Copy link
Collaborator Author

The following test case is failing on optax-v0.2.7 but passing on older versions.

FAILED test/test_optimizers.py::test_numpyrooptim_no_double_jit[chain-args14-kwargs14-True] - assert 1 == 2

I am unable to reproduce the following test cases on my local machine. Maybe they pass in the next CI.

FAILED test/test_pickle.py::test_pickle_hmc[BarkerMH] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc[HMC] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc[NUTS] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc[SA] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc_enumeration[BarkerMH] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc_enumeration[HMC] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc_enumeration[NUTS] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmc_enumeration[SA] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_discrete_hmc[DiscreteHMCGibbs] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_discrete_hmc[MixedHMC] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_hmcecs - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_autoguide[AutoDelta] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_autoguide[AutoDiagonalNormal] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_pickle_autoguide[AutoNormal] - TypeError: cannot pickle 'PRNGKeyArray' object
FAILED test/test_pickle.py::test_mcmc_pickle_post_warmup - TypeError: cannot pickle 'PRNGKeyArray' object

@Qazalbash
Copy link
Collaborator Author

@fehiepsi, pickling failed again. Can you rerun the CI?

@Qazalbash
Copy link
Collaborator Author

Need #2136 and #2137 for CI to pass.

@fehiepsi
Copy link
Member

Could you add a simple test in test_pickle to check if we can pickle a PRNGKey? if it happens on CI and is unrelated to numpyro, we can report the issue to jax devs.

@fehiepsi
Copy link
Member

It seems the tests are failing consistently. How about using PRNGKey like before in this test?

@Qazalbash
Copy link
Collaborator Author

It seems the tests are failing consistently. How about using PRNGKey like before in this test?

And it is only in py3.11. I will use the legacy keys in the tests and report this issue to JAX team.

@Qazalbash
Copy link
Collaborator Author

Issue reported: jax-ml/jax#35065

@Qazalbash
Copy link
Collaborator Author

Qazalbash commented Feb 13, 2026

@fehiepsi, we fixed the issue related to optax==0.2.7 in #2137. Now flax has pinned the optax version (google/flax#5225), and it is effective in the latest flax==0.12.4. Should we revert the changes?

@fehiepsi
Copy link
Member

Sounds reasonable to me. We can revise the test to be flexible when uses_value_arg=True and update the failing ones with uses_value_arg=True.

@fehiepsi fehiepsi merged commit d5598e7 into pyro-ppl:master Feb 15, 2026
9 checks passed
@Qazalbash Qazalbash deleted the refactor-prng-key branch February 15, 2026 18:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants