Skip to content

[pmap] In-line definitions of jax.device_put_sharded and jax.device_put_replicated.#5374

Open
copybara-service[bot] wants to merge 1 commit intomainfrom
test_888582438
Open

[pmap] In-line definitions of jax.device_put_sharded and jax.device_put_replicated.#5374
copybara-service[bot] wants to merge 1 commit intomainfrom
test_888582438

Conversation

@copybara-service
Copy link

[pmap] In-line definitions of jax.device_put_sharded and jax.device_put_replicated.

Both jax.device_put_sharded and jax.device_put_replicated were deprecated in JAX v0.8.1 in November 2025. We in-line their definitions using public JAX APIs taking the jax_pmap_shmap_merge=True branch, which was made the default in JAX v0.8.0 in October 2025.

Please see the below for more information:

…e_put_replicated`.

Both `jax.device_put_sharded` and `jax.device_put_replicated` were deprecated in JAX v0.8.1 in November 2025. We in-line their definitions using public JAX APIs taking the `jax_pmap_shmap_merge=True` branch, which was made the default in JAX v0.8.0 in October 2025.

Please see the below for more information:
- JAX CHANGELOG: https://docs.jax.dev/en/latest/changelog.html
- Migrating from `jax.pmap`: https://docs.jax.dev/en/latest/migrate_pmap.html

PiperOrigin-RevId: 888582438
@vfdev-5
Copy link
Collaborator

vfdev-5 commented Mar 25, 2026

@danielsuo I tried to tackle this problem some time ago and there were some internal failures that I could not fix in the flax code: #5101
Can you take the tests from my PR here to ensure we get the same behaviour with the rewritten methods. Thanks

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants