diff --git a/utils/batcher.py b/utils/batcher.py index a192195..17e0f59 100755 --- a/utils/batcher.py +++ b/utils/batcher.py @@ -56,16 +56,16 @@ def get_training_batch(self, batch_size): if self.current_state == 0: random.shuffle(self.training_indices) - if (self.current_state + batch_size) > (len(self.training_indices) + 1): + next_state = self.current_state + batch_size + + if next_state > len(self.training_indices): self.current_state = 0 return self.get_training_batch(batch_size) - else: - self.current_state += batch_size - batch_indices = self.training_indices[self.current_state:(self.current_state + batch_size)] - if len(batch_indices) != batch_size: - self.current_state = 0 - return self.get_training_batch(batch_size) - return self.data_handler.slice_data(batch_indices) + + batch_indices = self.training_indices[self.current_state:next_state] + self.current_state = next_state + + return self.data_handler.slice_data(batch_indices) def get_validation_batch(self, batch_size): """