Skip to content

Commit ab49957

Browse files
committed
added single gpu train test
1 parent 56f1669 commit ab49957

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

pytorch_lightning/models/trainer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,9 @@ def validate(self, model, dataloader, max_batches):
394394

395395
elif self.single_gpu:
396396
gpu_id = self.data_parallel_device_ids[0]
397-
data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
397+
for i, x in enumerate(data_batch):
398+
if isinstance(x, torch.Tensor):
399+
data_batch[i] = x.cuda(gpu_id)
398400
output = model.validation_step(data_batch, batch_i)
399401

400402
else:
@@ -849,7 +851,9 @@ def __run_tng_batch(self, data_batch, batch_nb):
849851
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
850852
elif self.single_gpu:
851853
gpu_id = self.data_parallel_device_ids[0]
852-
data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
854+
for i, x in enumerate(data_batch):
855+
if isinstance(x, torch.Tensor):
856+
data_batch[i] = x.cuda(gpu_id)
853857
output = self.model.training_step(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
854858

855859
else:

0 commit comments

Comments
 (0)