Skip to content

Commit 2f7a9ad

Browse files
committed
added single gpu train test
1 parent 9ecb1f2 commit 2f7a9ad

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

pytorch_lightning/models/trainer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -393,7 +393,9 @@ def validate(self, model, dataloader, max_batches):
393393
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
394394

395395
elif self.single_gpu:
396-
output = model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_i)
396+
gpu_id = self.data_parallel_device_ids[0]
397+
data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
398+
output = model(data_batch, batch_i)
397399

398400
else:
399401
output = model.validation_step(data_batch, batch_i)
@@ -474,7 +476,7 @@ def fit(self, model):
474476
self.__dp_train(model)
475477

476478
elif self.single_gpu:
477-
self.__single_gpu_train(model)\
479+
self.__single_gpu_train(model)
478480

479481
# ON CPU
480482
else:
@@ -846,7 +848,10 @@ def __run_tng_batch(self, data_batch, batch_nb):
846848
output = self.model(data_batch, batch_nb)
847849
output = reduce_distributed_output(output, len(self.data_parallel_device_ids))
848850
elif self.single_gpu:
851+
gpu_id = self.data_parallel_device_ids[0]
852+
data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
849853
output = self.model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
854+
850855
else:
851856
output = self.model.training_step(data_batch, batch_nb)
852857

0 commit comments

Comments
 (0)