We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8bb4421 commit 8d1ed78Copy full SHA for 8d1ed78
axlearn/common/utils_spmd.py
@@ -43,8 +43,6 @@ def setup(
43
"""
44
# Use a GSPMD-friendly PRNG implementation.
45
jax.config.update("jax_default_prng_impl", "rbg")
46
- # This allows replicated jax.Arrays to be used for computation on the host.
47
- jax.config.update("jax_spmd_mode", "allow_all")
48
49
global _jax_distributed_initialized # pylint: disable=global-statement
50
if not _jax_distributed_initialized:
0 commit comments