Skip to content

Commit f3dea81

Browse files
committed
added single gpu train test
1 parent ab49957 commit f3dea81

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

pytorch_lightning/models/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -854,7 +854,7 @@ def __run_tng_batch(self, data_batch, batch_nb):
854854
for i, x in enumerate(data_batch):
855855
if isinstance(x, torch.Tensor):
856856
data_batch[i] = x.cuda(gpu_id)
857-
output = self.model.training_step(data_batch.cuda(self.data_parallel_device_ids[0]), batch_nb)
857+
output = self.model.training_step(data_batch, batch_nb)
858858

859859
else:
860860
output = self.model.training_step(data_batch, batch_nb)

0 commit comments

Comments
 (0)