Skip to content

Commit 39f5030

Browse files
committed
Revert "Adding max_iters as an optional arg in Engine run (#1381)"
This reverts commit 307ac11.
1 parent 56c3cf5 commit 39f5030

File tree

2 files changed

+6
-31
lines changed

2 files changed

+6
-31
lines changed

ignite/engine/engine.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22
import logging
3-
import math
43
import time
54
import warnings
65
import weakref
@@ -686,7 +685,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
686685
`seed`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
687686
Iteration and epoch values are 0-based: the first iteration or epoch is zero.
688687
689-
This method does not remove any custom attributes added by user.
688+
This method does not remove any custom attributs added by user.
690689
691690
Args:
692691
state_dict: a dict with parameters
@@ -731,14 +730,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
731730

732731
@staticmethod
733732
def _is_done(state: State) -> bool:
734-
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
735-
is_done_count = (
736-
state.epoch_length is not None
737-
and state.max_epochs is not None
738-
and state.iteration >= state.epoch_length * state.max_epochs
739-
)
740-
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
741-
return is_done_iters or is_done_count or is_done_epochs
733+
return state.iteration == state.epoch_length * state.max_epochs # type: ignore[operator]
742734

743735
def set_data(self, data: Union[Iterable, DataLoader]) -> None:
744736
"""Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -780,15 +772,14 @@ def run(
780772
self,
781773
data: Optional[Iterable] = None,
782774
max_epochs: Optional[int] = None,
783-
max_iters: Optional[int] = None,
784775
epoch_length: Optional[int] = None,
785776
seed: Optional[int] = None,
786777
) -> State:
787778
"""Runs the ``process_function`` over the passed data.
788779
789780
Engine has a state and the following logic is applied in this function:
790781
791-
- At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, `seed`, if provided.
782+
- At the first call, new state is defined by `max_epochs`, `epoch_length`, `seed`, if provided.
792783
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
793784
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
794785
provided, state is kept and used in the function.
@@ -806,8 +797,6 @@ def run(
806797
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
807798
determined as the iteration on which data iterator raises `StopIteration`.
808799
This argument should not change if run is resuming from a state.
809-
max_iters: Number of iterations to run for.
810-
`max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
811800
seed: Deprecated argument. Please, use `torch.manual_seed` or :meth:`~ignite.utils.manual_seed`.
812801
813802
Returns:
@@ -866,6 +855,8 @@ def switch_batch(engine):
866855

867856
if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None):
868857
# Create new state
858+
if max_epochs is None:
859+
max_epochs = 1
869860
if epoch_length is None:
870861
if data is None:
871862
raise ValueError("epoch_length should be provided if data is None")
@@ -874,22 +865,9 @@ def switch_batch(engine):
874865
if epoch_length is not None and epoch_length < 1:
875866
raise ValueError("Input data has zero size. Please provide non-empty data")
876867

877-
if max_iters is None:
878-
if max_epochs is None:
879-
max_epochs = 1
880-
else:
881-
if max_epochs is not None:
882-
raise ValueError(
883-
"Arguments max_iters and max_epochs are mutually exclusive."
884-
"Please provide only max_epochs or max_iters."
885-
)
886-
if epoch_length is not None:
887-
max_epochs = math.ceil(max_iters / epoch_length)
888-
889868
self.state.iteration = 0
890869
self.state.epoch = 0
891870
self.state.max_epochs = max_epochs
892-
self.state.max_iters = max_iters
893871
self.state.epoch_length = epoch_length
894872
# Reset generator if previously used
895873
self._internal_run_generator = None
@@ -984,6 +962,7 @@ def _internal_run_as_gen(self) -> Generator:
984962
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
985963

986964
handlers_start_time = time.time()
965+
987966
self._fire_event(Events.EPOCH_COMPLETED)
988967
epoch_time_taken += time.time() - handlers_start_time
989968
# update time wrt handlers
@@ -1062,8 +1041,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10621041
if self.state.epoch_length is None:
10631042
# Define epoch length and stop the epoch
10641043
self.state.epoch_length = iter_counter
1065-
if self.state.max_iters is not None:
1066-
self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
10671044
break
10681045

10691046
# Should exit while loop if we can not iterate

ignite/engine/events.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,6 @@ class State:
454454
state.dataloader # data passed to engine
455455
state.epoch_length # optional length of an epoch
456456
state.max_epochs # number of epochs to run
457-
state.max_iters # number of iterations to run
458457
state.batch # batch passed to `process_function`
459458
state.output # output of `process_function` after a single iteration
460459
state.metrics # dictionary with defined metrics if any
@@ -481,7 +480,6 @@ def __init__(self, **kwargs: Any) -> None:
481480
self.epoch = 0
482481
self.epoch_length: Optional[int] = None
483482
self.max_epochs: Optional[int] = None
484-
self.max_iters: Optional[int] = None
485483
self.output: Optional[int] = None
486484
self.batch: Optional[int] = None
487485
self.metrics: Dict[str, Any] = {}

0 commit comments

Comments
 (0)