Skip to content

Commit c0f3b6b

Browse files
ananyahjha93williamFalcon
authored andcommitted
added set_epoch for distributed sampler, fix for #224 (#225)
1 parent e339799 commit c0f3b6b

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,10 @@ def __run_pretrain_routine(self, model):
929929
def __train(self):
930930
# run all epochs
931931
for epoch_nb in range(self.current_epoch, self.max_nb_epochs):
932+
# set seed for distributed sampler (enables shuffling for each epoch)
933+
if self.use_ddp:
934+
self.tng_dataloader.sampler.set_epoch(epoch_nb)
935+
932936
# get model
933937
model = self.__get_model()
934938

0 commit comments

Comments
 (0)