Skip to content

Commit 8d1ed78

Browse files
authored
remove jax_spmd_mode (#1211)
1 parent 8bb4421 commit 8d1ed78

File tree

1 file changed

+0
-2
lines changed

1 file changed

+0
-2
lines changed

axlearn/common/utils_spmd.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ def setup(
4343
"""
4444
# Use a GSPMD-friendly PRNG implementation.
4545
jax.config.update("jax_default_prng_impl", "rbg")
46-
# This allows replicated jax.Arrays to be used for computation on the host.
47-
jax.config.update("jax_spmd_mode", "allow_all")
4846

4947
global _jax_distributed_initialized # pylint: disable=global-statement
5048
if not _jax_distributed_initialized:

0 commit comments

Comments
 (0)