@@ -105,7 +105,7 @@ def __init__(self,
         :param log_save_interval:
         :param add_log_row_interval:
         :param distributed_backend:
-            'np' to use DistributedParallel, 'dp' to use DistributedDataParallel
+            'do' to use DistributedParallel, 'dp' to use DistributedDataParallel, 'n' to use none
         :param use_amp:
         :param print_nan_grads:
         :param print_weights_summary:
@@ -147,6 +147,7 @@ def __init__(self,
         self.node_rank = 0
         self.use_ddp = False
         self.use_dp = False
+        self.single_gpu = False
 
         # training bookeeping
         self.total_batch_nb = 0
@@ -194,6 +195,12 @@ def __init__(self,
                 'To silence this warning set distributed_backend=ddp'
             warnings.warn(w)
 
+        # remove dp and ddp when requesting single gpu
+        if self.data_parallel_device_ids is not None and len(self.data_parallel_device_ids) == 1:
+            self.use_ddp = False
+            self.use_dp = False
+            self.single_gpu = True
+
         # extract SLURM flag vars
         # whenever we have the correct number of tasks, we let slurm manage processes
         # otherwise we launch the required number of processes
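A single requested device now overrides both parallel modes. The sketch below restates that override as a standalone function so the resulting flag combinations are easy to check; `resolve_flags` is an illustrative helper, not a Trainer method.

```python
# Hypothetical helper mirroring the override added above; not part of the Trainer API.
def resolve_flags(use_ddp, use_dp, device_ids):
    single_gpu = False
    # a single requested GPU disables both DataParallel and DistributedDataParallel
    if device_ids is not None and len(device_ids) == 1:
        use_ddp = False
        use_dp = False
        single_gpu = True
    return use_ddp, use_dp, single_gpu


print(resolve_flags(True, False, [0]))     # (False, False, True)  -> single-GPU path
print(resolve_flags(True, False, [0, 1]))  # (True, False, False)  -> ddp left as requested
print(resolve_flags(False, False, None))   # (False, False, False) -> CPU path
```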
@@ -385,6 +392,13 @@ def validate(self, model, dataloader, max_batches):
                 output = model(data_batch, batch_i)
                 output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
 
+            elif self.single_gpu:
+                gpu_id = self.data_parallel_device_ids[0]
+                for i, x in enumerate(data_batch):
+                    if isinstance(x, torch.Tensor):
+                        data_batch[i] = x.cuda(gpu_id)
+                output = model.validation_step(data_batch, batch_i)
+
             else:
                 output = model.validation_step(data_batch, batch_i)
 
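Before `validation_step` runs on a single GPU, each tensor in the batch is moved to the chosen device and non-tensor entries are left alone. A standalone sketch of that transfer idiom; `move_batch_to_gpu` and the sample batch are illustrative, not part of the codebase.

```python
import torch


def move_batch_to_gpu(data_batch, gpu_id):
    # only torch.Tensor entries are moved; everything else stays as-is
    for i, x in enumerate(data_batch):
        if isinstance(x, torch.Tensor):
            data_batch[i] = x.cuda(gpu_id)
    return data_batch


# hypothetical batch: features, labels and a non-tensor entry
batch = [torch.randn(4, 3), torch.tensor([0, 1, 1, 0]), "sample-ids"]
if torch.cuda.is_available():
    batch = move_batch_to_gpu(batch, gpu_id=0)
    print([x.device if isinstance(x, torch.Tensor) else x for x in batch])
```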
@@ -463,6 +477,9 @@ def fit(self, model):
         elif self.use_dp:
             self.__dp_train(model)
 
+        elif self.single_gpu:
+            self.__single_gpu_train(model)
+
         # ON CPU
         else:
             # run through amp wrapper
@@ -482,6 +499,24 @@ def fit(self, model):
         # used for testing or when we need to know that training succeeded
         return 1
 
+    def __single_gpu_train(self, model):
+        # CHOOSE OPTIMIZER
+        # allow for lr schedulers as well
+        self.optimizers = model.configure_optimizers()
+        if len(self.optimizers) == 2:
+            self.optimizers, self.lr_schedulers = self.optimizers
+
+        model.cuda(self.data_parallel_device_ids[0])
+
+        if self.use_amp:
+            # An example
+            model, optimizers = amp.initialize(
+                model, self.optimizers, opt_level=self.amp_level,
+            )
+            self.optimizers = optimizers
+
+        self.__run_pretrain_routine(model)
+
     def __dp_train(self, model):
 
         # CHOOSE OPTIMIZER
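`__single_gpu_train` configures the optimizers, moves the model onto the requested device, and only then optionally wraps model and optimizers with apex mixed precision. A hedged sketch of that ordering; `single_gpu_setup` is an illustrative name, and apex is assumed to be installed only when `use_amp` is True.

```python
def single_gpu_setup(model, optimizers, gpu_id, use_amp=False, amp_level='O2'):
    # move the model to the target device before any amp wrapping
    model.cuda(gpu_id)
    if use_amp:
        from apex import amp  # optional dependency, needed only for mixed precision
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
    return model, optimizers
```

Calling `.cuda()` before `amp.initialize` matters because apex expects the model parameters to already live on the target device when it sets up mixed precision.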
@@ -814,6 +849,13 @@ def __run_tng_batch(self, data_batch, batch_nb):
         elif self.use_dp:
             output = self.model(data_batch, batch_nb)
             output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
+        elif self.single_gpu:
+            gpu_id = self.data_parallel_device_ids[0]
+            for i, x in enumerate(data_batch):
+                if isinstance(x, torch.Tensor):
+                    data_batch[i] = x.cuda(gpu_id)
+            output = self.model.training_step(data_batch, batch_nb)
+
         else:
             output = self.model.training_step(data_batch, batch_nb)
 
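Taken together, the single-GPU branches reduce to ordinary single-device PyTorch: move the batch, run the step, and let the existing backward and optimizer code run unchanged. A minimal plain-PyTorch sketch of one such step; the linear model, loss, and optimizer here are illustrative stand-ins, not Lightning code.

```python
import torch
import torch.nn as nn

if torch.cuda.is_available():
    gpu_id = 0
    model = nn.Linear(3, 2).cuda(gpu_id)                 # stands in for the user's module
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    data_batch = [torch.randn(4, 3), torch.tensor([0, 1, 1, 0])]
    # same per-element transfer idiom as the branches above
    for i, x in enumerate(data_batch):
        if isinstance(x, torch.Tensor):
            data_batch[i] = x.cuda(gpu_id)

    features, labels = data_batch
    loss = nn.functional.cross_entropy(model(features), labels)  # stands in for training_step
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```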