Skip to content

Commit 9ecb1f2

Browse files
committed
added single gpu train test
1 parent e305149 commit 9ecb1f2

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

pytorch_lightning/models/trainer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,9 @@ def validate(self, model, dataloader, max_batches):
392392
output = model(data_batch, batch_i)
393393
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
394394

395+
elif self.single_gpu:
396+
output = model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_i)
397+
395398
else:
396399
output = model.validation_step(data_batch, batch_i)
397400

@@ -842,6 +845,8 @@ def __run_tng_batch(self, data_batch, batch_nb):
842845
elif self.use_dp:
843846
output = self.model(data_batch, batch_nb)
844847
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
848+
elif self.single_gpu:
849+
output = self.model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
845850
else:
846851
output = self.model.training_step(data_batch, batch_nb)
847852

0 commit comments

Comments
 (0)