Skip to content

Conversation

@copybara-service
Copy link

Handle PRNG keys in reshard by using jax.device_put.

The reshard function now treats jax.Array instances that are PRNG keys as non-reshardable via the experimental reshard sidechannel API, directing them through jax.device_put instead of the custom device-to-device transfer logic. This ensures correct handling of PRNG keys when resharding PyTrees. A new test case is added to confirm proper resharding of PyTrees containing random keys.

The `reshard` function now treats `jax.Array` instances that are PRNG keys as non-reshardable via the experimental reshard sidechannel API, directing them through `jax.device_put` instead of the custom device-to-device transfer logic. This ensures correct handling of PRNG keys when resharding PyTrees. A new test case is added to confirm proper resharding of PyTrees containing random keys.

PiperOrigin-RevId: 814746994
@copybara-service copybara-service bot merged commit c9fb204 into main Oct 3, 2025
@copybara-service copybara-service bot deleted the test_814737436 branch October 3, 2025 17:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant