Skip to content

Commit 56f1669

Browse files
committed
added single gpu train test
1 parent 2f7a9ad commit 56f1669

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytorch_lightning/models/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def validate(self, model, dataloader, max_batches):
395395
elif self.single_gpu:
396396
gpu_id = self.data_parallel_device_ids[0]
397397
data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
398-
output = model(data_batch, batch_i)
398+
output = model.validation_step(data_batch, batch_i)
399399

400400
else:
401401
output = model.validation_step(data_batch, batch_i)
@@ -850,7 +850,7 @@ def __run_tng_batch(self, data_batch, batch_nb):
850850
elif self.single_gpu:
851851
gpu_id = self.data_parallel_device_ids[0]
852852
data_batch = [x.cuda(gpu_id) for x in data_batch if isinstance(x, torch.Tensor)]
853-
output = self.model(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
853+
output = self.model.training_step(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
854854

855855
else:
856856
output = self.model.training_step(data_batch, batch_nb)

0 commit comments

Comments
 (0)