@@ -399,13 +399,14 @@ def validate(self, model, dataloader, max_batches):
399
399
output = model (data_batch , batch_i )
400
400
elif self .use_dp :
401
401
output = model (data_batch , batch_i )
402
- output = reduce_distributed_output (output , len (self .data_parallel_device_ids ))
403
-
404
402
elif self .single_gpu :
403
+ # put inputs on gpu manually
405
404
gpu_id = self .data_parallel_device_ids [0 ]
406
405
for i , x in enumerate (data_batch ):
407
406
if isinstance (x , torch .Tensor ):
408
407
data_batch [i ] = x .cuda (gpu_id )
408
+
409
+ # run the validation step directly (neither dp nor ddp is in use)
409
410
output = model .validation_step (data_batch , batch_i )
410
411
411
412
else :
@@ -862,7 +863,6 @@ def __run_tng_batch(self, data_batch, batch_nb):
862
863
output = self .model (data_batch , batch_nb )
863
864
elif self .use_dp :
864
865
output = self .model (data_batch , batch_nb )
865
- output = reduce_distributed_output (output , len (self .data_parallel_device_ids ))
866
866
elif self .single_gpu :
867
867
gpu_id = self .data_parallel_device_ids [0 ]
868
868
for i , x in enumerate (data_batch ):
@@ -874,7 +874,14 @@ def __run_tng_batch(self, data_batch, batch_nb):
874
874
output = self .model .training_step (data_batch , batch_nb )
875
875
876
876
try :
877
- model_specific_tqdm_metrics_dic = output ['prog' ]
877
+ prog_output = output ['prog' ]
878
+
879
+ # reduce prog metrics for tqdm when using dp
880
+ if self .use_dp :
881
+ nb_gpus = len (self .data_parallel_device_ids )
882
+ prog_output = reduce_distributed_output (prog_output , nb_gpus )
883
+
884
+ model_specific_tqdm_metrics_dic = prog_output
878
885
except Exception :
879
886
model_specific_tqdm_metrics_dic = {}
880
887
@@ -886,6 +893,10 @@ def __run_tng_batch(self, data_batch, batch_nb):
886
893
if type (output ) is torch .Tensor :
887
894
loss = output
888
895
896
+ # when using dp, the loss needs to be reduced
897
+ if self .use_dp :
898
+ loss = reduce_distributed_output (loss , len (self .data_parallel_device_ids ))
899
+
889
900
self .__add_tqdm_metrics (model_specific_tqdm_metrics_dic )
890
901
891
902
# backward pass
@@ -968,12 +979,12 @@ def __run_validation(self):
968
979
# use full val set on end of epoch
969
980
# use a small portion otherwise
970
981
max_batches = None if not self .fast_dev_run else 1
971
- model_specific_tqdm_metrics_dic = self .validate (
982
+ validation_results = self .validate (
972
983
self .model ,
973
984
self .val_dataloader ,
974
985
max_batches
975
986
)
976
- self .__add_tqdm_metrics (model_specific_tqdm_metrics_dic )
987
+ self .__add_tqdm_metrics (validation_results )
977
988
978
989
# hook
979
990
if self .__is_function_implemented ('on_post_performance_check' ):
0 commit comments