@@ -399,13 +399,14 @@ def validate(self, model, dataloader, max_batches):
399399 output = model (data_batch , batch_i )
400400 elif self .use_dp :
401401 output = model (data_batch , batch_i )
402- output = reduce_distributed_output (output , len (self .data_parallel_device_ids ))
403-
404402 elif self .single_gpu :
403+ # put inputs on gpu manually
405404 gpu_id = self .data_parallel_device_ids [0 ]
406405 for i , x in enumerate (data_batch ):
407406 if isinstance (x , torch .Tensor ):
408407 data_batch [i ] = x .cuda (gpu_id )
408+
409+ # do non dp, ddp step
409410 output = model .validation_step (data_batch , batch_i )
410411
411412 else :
@@ -862,7 +863,6 @@ def __run_tng_batch(self, data_batch, batch_nb):
862863 output = self .model (data_batch , batch_nb )
863864 elif self .use_dp :
864865 output = self .model (data_batch , batch_nb )
865- output = reduce_distributed_output (output , len (self .data_parallel_device_ids ))
866866 elif self .single_gpu :
867867 gpu_id = self .data_parallel_device_ids [0 ]
868868 for i , x in enumerate (data_batch ):
@@ -874,7 +874,14 @@ def __run_tng_batch(self, data_batch, batch_nb):
874874 output = self .model .training_step (data_batch , batch_nb )
875875
876876 try :
877- model_specific_tqdm_metrics_dic = output ['prog' ]
877+ prog_output = output ['prog' ]
878+
879+ # reduce prog metrics for tqdm when using dp
880+ if self .use_dp :
881+ nb_gpus = len (self .data_parallel_device_ids )
882+ prog_output = reduce_distributed_output (prog_output , nb_gpus )
883+
884+ model_specific_tqdm_metrics_dic = prog_output
878885 except Exception :
879886 model_specific_tqdm_metrics_dic = {}
880887
@@ -886,6 +893,10 @@ def __run_tng_batch(self, data_batch, batch_nb):
886893 if type (output ) is torch .Tensor :
887894 loss = output
888895
896+ # when using dp need to reduce the loss
897+ if self .use_dp :
898+ loss = reduce_distributed_output (loss , len (self .data_parallel_device_ids ))
899+
889900 self .__add_tqdm_metrics (model_specific_tqdm_metrics_dic )
890901
891902 # backward pass
@@ -968,12 +979,12 @@ def __run_validation(self):
968979 # use full val set on end of epoch
969980 # use a small portion otherwise
970981 max_batches = None if not self .fast_dev_run else 1
971- model_specific_tqdm_metrics_dic = self .validate (
982+ validation_results = self .validate (
972983 self .model ,
973984 self .val_dataloader ,
974985 max_batches
975986 )
976- self .__add_tqdm_metrics (model_specific_tqdm_metrics_dic )
987+ self .__add_tqdm_metrics (validation_results )
977988
978989 # hook
979990 if self .__is_function_implemented ('on_post_performance_check' ):
0 commit comments