@@ -153,6 +153,16 @@ def train_fx(trial_hparams, cluster_manager, _):
     HYDRA_AVAILABLE = True


+try:
+    import torch_xla
+    import torch_xla.core.xla_model as xm
+    import torch_xla.distributed.xla_multiprocessing as xmp
+except ImportError:
+    XLA_AVAILABLE = False
+else:
+    XLA_AVAILABLE = True
+
+
 class TrainerDDPMixin(ABC):

     # this is just a summary on variables used in this abstract class,
@@ -172,6 +182,7 @@ class TrainerDDPMixin(ABC):
     num_processes: int
     num_nodes: int
     node_rank: int
+    tpu_cores: int

     @property
     def is_global_zero(self) -> int:
@@ -277,6 +288,8 @@ def set_distributed_mode(self, distributed_backend):
         )

         rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')
+        num_cores = self.tpu_cores if self.tpu_cores is not None else 0
+        rank_zero_info(f'TPU available: {XLA_AVAILABLE}, using: {num_cores} TPU cores')

     def configure_slurm_ddp(self, num_gpu_nodes):
         self.is_slurm_managing_tasks = False
@@ -329,7 +342,6 @@ def determine_ddp_node_rank(self):
         node_ids = [(k, os.environ.get(k, None)) for k in env_vars]
         node_ids = [(k, v) for k, v in node_ids if v is not None]
         if len(node_ids) == 0:
-            log.warning("No environment variable for node rank defined. Set as 0.")
             return 0
         if len(node_ids) > 1:
             log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. "
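
Outside the diff, a minimal sketch of how the XLA_AVAILABLE flag introduced above can be consumed. The resolve_device helper is hypothetical and not part of this change; it only assumes the same torch_xla import names the diff uses.

# Guard the optional torch_xla dependency, mirroring the try/except/else pattern added above.
try:
    import torch_xla.core.xla_model as xm
except ImportError:
    XLA_AVAILABLE = False
else:
    XLA_AVAILABLE = True


def resolve_device():
    """Return the XLA device for this process if torch_xla is installed, else 'cpu'."""
    if XLA_AVAILABLE:
        return xm.xla_device()  # the TPU core assigned to the current process
    return 'cpu'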