
Commit 8cd764a

removed reduce on non-loss outputs from dp (#78)
* removed reduce on non-loss outputs from dp
* fixed val reduce
* fixed val reduce
* fixed val reduce
* fixed val reduce
1 parent fcea397 commit 8cd764a

File tree

examples/new_project_templates/lightning_module_template.py
pytorch_lightning/models/trainer.py

2 files changed: +31 -9 lines changed


examples/new_project_templates/lightning_module_template.py

Lines changed: 14 additions & 3 deletions
@@ -151,12 +151,23 @@ def validation_end(self, outputs):
         val_loss_mean = 0
         val_acc_mean = 0
         for output in outputs:
-            val_loss_mean += output['val_loss']
-            val_acc_mean += output['val_acc']
+            val_loss = output['val_loss']
+
+            # reduce manually when using dp
+            if self.trainer.use_dp:
+                val_loss = torch.mean(val_loss)
+            val_loss_mean += val_loss
+
+            # reduce manually when using dp
+            val_acc = output['val_acc']
+            if self.trainer.use_dp:
+                val_acc_mean = torch.mean(val_acc)
+
+            val_acc_mean += val_acc_mean
 
         val_loss_mean /= len(outputs)
         val_acc_mean /= len(outputs)
-        tqdm_dic = {'val_loss': val_loss_mean.item(), 'val_acc': val_acc_mean.item()}
+        tqdm_dic = {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}
         return tqdm_dic
 
     # ---------------------
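For reference, a minimal sketch of how this manual dp reduction in validation_end can be written out in full, as a method on the LightningModule. It assumes torch is imported in the template and that self.trainer.use_dp is set by the trainer; the per-GPU val_acc is averaged into a local variable before being accumulated, mirroring the val_loss handling. This is an illustrative sketch, not the committed template code.

    import torch

    def validation_end(self, outputs):
        # aggregate the dicts returned by each validation_step call
        val_loss_mean = 0
        val_acc_mean = 0
        for output in outputs:
            val_loss = output['val_loss']
            val_acc = output['val_acc']

            # dp returns one value per GPU, so reduce each tensor manually
            if self.trainer.use_dp:
                val_loss = torch.mean(val_loss)
                val_acc = torch.mean(val_acc)

            val_loss_mean += val_loss
            val_acc_mean += val_acc

        val_loss_mean /= len(outputs)
        val_acc_mean /= len(outputs)
        return {'val_loss': val_loss_mean, 'val_acc': val_acc_mean}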

pytorch_lightning/models/trainer.py

Lines changed: 17 additions & 6 deletions
@@ -399,13 +399,14 @@ def validate(self, model, dataloader, max_batches):
                 output = model(data_batch, batch_i)
             elif self.use_dp:
                 output = model(data_batch, batch_i)
-                output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
-
             elif self.single_gpu:
+                # put inputs on gpu manually
                 gpu_id = self.data_parallel_device_ids[0]
                 for i, x in enumerate(data_batch):
                     if isinstance(x, torch.Tensor):
                         data_batch[i] = x.cuda(gpu_id)
+
+                # do non dp, ddp step
                 output = model.validation_step(data_batch, batch_i)
 
             else:
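To make the effect of dropping this reduce concrete, here is a small self-contained illustration with made-up numbers: under dp each scalar that validation_step returns arrives with one entry per GPU, and it is now the LightningModule's validation_end that averages them.

    import torch

    # made-up per-batch outputs from a 2-GPU dp validation run
    outputs = [
        {'val_loss': torch.tensor([0.62, 0.64]), 'val_acc': torch.tensor([0.81, 0.79])},
        {'val_loss': torch.tensor([0.58, 0.60]), 'val_acc': torch.tensor([0.83, 0.85])},
    ]

    # the manual reduction validation_end now performs
    val_loss_mean = sum(o['val_loss'].mean() for o in outputs) / len(outputs)
    val_acc_mean = sum(o['val_acc'].mean() for o in outputs) / len(outputs)
    print(val_loss_mean, val_acc_mean)  # ≈ tensor(0.6100) tensor(0.8200)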
@@ -862,7 +863,6 @@ def __run_tng_batch(self, data_batch, batch_nb):
             output = self.model(data_batch, batch_nb)
         elif self.use_dp:
             output = self.model(data_batch, batch_nb)
-            output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
         elif self.single_gpu:
             gpu_id = self.data_parallel_device_ids[0]
             for i, x in enumerate(data_batch):
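reduce_distributed_output is the trainer helper whose call is removed here and reused below for the tqdm prog metrics and the loss; it averages dp outputs across GPUs. Its implementation is not part of this diff. Purely as an illustration of the idea, a helper like it could recurse into metric dicts and mean-reduce any tensor that carries one value per GPU; the names and behaviour below are assumptions, not the library's code.

    import torch

    def reduce_distributed_output(output, nb_gpus):
        # illustrative sketch only -- not the library's actual implementation
        if nb_gpus <= 1:
            return output

        # recurse into dicts of metrics (e.g. the 'prog' dict)
        if isinstance(output, dict):
            return {k: reduce_distributed_output(v, nb_gpus) for k, v in output.items()}

        # average tensors that carry one value per GPU
        if isinstance(output, torch.Tensor) and output.dim() > 0:
            return output.mean()

        return output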
@@ -874,7 +874,14 @@ def __run_tng_batch(self, data_batch, batch_nb):
             output = self.model.training_step(data_batch, batch_nb)
 
         try:
-            model_specific_tqdm_metrics_dic = output['prog']
+            prog_output = output['prog']
+
+            # reduce prog metrics for tqdm when using dp
+            if self.use_dp:
+                nb_gpus = len(self.data_parallel_device_ids)
+                prog_output = reduce_distributed_output(prog_output, nb_gpus)
+
+            model_specific_tqdm_metrics_dic = prog_output
         except Exception:
             model_specific_tqdm_metrics_dic = {}
 
@@ -886,6 +893,10 @@ def __run_tng_batch(self, data_batch, batch_nb):
         if type(output) is torch.Tensor:
             loss = output
 
+        # when using dp need to reduce the loss
+        if self.use_dp:
+            loss = reduce_distributed_output(loss, len(self.data_parallel_device_ids))
+
         self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
 
         # backward pass
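As a concrete illustration of why the loss still needs this reduce (values made up): with dp over two GPUs, the scalar loss returned by training_step comes back with one entry per replica and must be collapsed to a single scalar before the backward pass.

    import torch

    # made-up per-replica losses from a 2-GPU dp forward
    loss = torch.tensor([0.71, 0.69])

    # collapse to one scalar before calling loss.backward()
    loss = loss.mean()
    print(loss)  # ≈ tensor(0.7000)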
@@ -968,12 +979,12 @@ def __run_validation(self):
         # use full val set on end of epoch
         # use a small portion otherwise
         max_batches = None if not self.fast_dev_run else 1
-        model_specific_tqdm_metrics_dic = self.validate(
+        validation_results = self.validate(
             self.model,
             self.val_dataloader,
             max_batches
         )
-        self.__add_tqdm_metrics(model_specific_tqdm_metrics_dic)
+        self.__add_tqdm_metrics(validation_results)
 
         # hook
         if self.__is_function_implemented('on_post_performance_check'):

0 commit comments
