Skip to content

Commit c9c043c

Browse files
Merge pull request jax-ml#24964 from emilyfertig:emilyaf-deactivate-cluster-detection
PiperOrigin-RevId: 702152342
2 parents 908865f + 6a8bbcb commit c9c043c

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

jax/_src/distributed.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def initialize(self,
6262
if local_device_ids is None and (env_ids := os.environ.get('JAX_LOCAL_DEVICE_IDS')):
6363
local_device_ids = list(map(int, env_ids.split(",")))
6464

65-
if None in (coordinator_address, num_processes, process_id, local_device_ids):
65+
if (cluster_detection_method != 'deactivate' and
66+
None in (coordinator_address, num_processes, process_id, local_device_ids)):
6667
(coordinator_address, num_processes, process_id, local_device_ids) = (
6768
clusters.ClusterEnv.auto_detect_unset_distributed_params(
6869
coordinator_address,
@@ -217,7 +218,8 @@ def initialize(coordinator_address: str | None = None,
217218
cluster_detection_method: An optional string to attempt to autodetect the configuration of the distributed
218219
run. Note that the "mpi4py" method requires you to have a working ``mpi4py`` install in your environment,
219220
and launch the application with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``.
220-
Legacy auto-detect options (OMPI, Slurm) remain enabled.
221+
Legacy auto-detect options "ompi" (OMPI) and "slurm" (Slurm) remain enabled. "deactivate" bypasses
222+
automatic cluster detection.
221223
initialization_timeout: Time period (in seconds) for which connection will
222224
be retried. If the initialization takes more than the timeout specified,
223225
the initialization will error. Defaults to 300 secs i.e. 5 mins.

0 commit comments

Comments
 (0)