@@ -62,7 +62,8 @@ def initialize(self,
6262 if local_device_ids is None and (env_ids := os .environ .get ('JAX_LOCAL_DEVICE_IDS' )):
6363 local_device_ids = list (map (int , env_ids .split ("," )))
6464
65- if None in (coordinator_address , num_processes , process_id , local_device_ids ):
65+ if (cluster_detection_method != 'deactivate' and
66+ None in (coordinator_address , num_processes , process_id , local_device_ids )):
6667 (coordinator_address , num_processes , process_id , local_device_ids ) = (
6768 clusters .ClusterEnv .auto_detect_unset_distributed_params (
6869 coordinator_address ,
@@ -217,7 +218,8 @@ def initialize(coordinator_address: str | None = None,
217218 cluster_detection_method: An optional string to attempt to autodetect the configuration of the distributed
218219 run. Note that "mpi4py" method requires you to have a working ``mpi4py`` install in your environment,
219220 and launch the applicatoin with an MPI-compatible job launcher such as ``mpiexec`` or ``mpirun``.
220- Legacy auto-detect options (OMPI, Slurm) remain enabled.
221+ Legacy auto-detect options "ompi" (OMPI) and "slurm" (Slurm) remain enabled. "deactivate" bypasses
222+ automatic cluster detection.
221223 initialization_timeout: Time period (in seconds) for which connection will
222224 be retried. If the initialization takes more than the timeout specified,
223225 the initialization will error. Defaults to 300 secs i.e. 5 mins.
0 commit comments