11import functools
22import logging
3- import math
43import time
54import warnings
65import weakref
@@ -686,7 +685,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
686685 `seed`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
687686 Iteration and epoch values are 0-based: the first iteration or epoch is zero.
688687
689- This method does not remove any custom attributes added by user.
688+ This method does not remove any custom attributes added by user.
690689
691690 Args:
692691 state_dict: a dict with parameters
@@ -731,14 +730,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
731730
732731 @staticmethod
733732 def _is_done (state : State ) -> bool :
734- is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
735- is_done_count = (
736- state .epoch_length is not None
737- and state .max_epochs is not None
738- and state .iteration >= state .epoch_length * state .max_epochs
739- )
740- is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
741- return is_done_iters or is_done_count or is_done_epochs
733+ return state .iteration == state .epoch_length * state .max_epochs # type: ignore[operator]
742734
743735 def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
744736 """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -780,15 +772,14 @@ def run(
780772 self ,
781773 data : Optional [Iterable ] = None ,
782774 max_epochs : Optional [int ] = None ,
783- max_iters : Optional [int ] = None ,
784775 epoch_length : Optional [int ] = None ,
785776 seed : Optional [int ] = None ,
786777 ) -> State :
787778 """Runs the ``process_function`` over the passed data.
788779
789780 Engine has a state and the following logic is applied in this function:
790781
791- - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, `seed`, if provided.
782+ - At the first call, new state is defined by `max_epochs`, `epoch_length`, `seed`, if provided.
792783 A timer for total and per-epoch time is initialized when Events.STARTED is handled.
793784 - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
794785 provided, state is kept and used in the function.
@@ -806,8 +797,6 @@ def run(
806797 `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
807798 determined as the iteration on which data iterator raises `StopIteration`.
808799 This argument should not change if run is resuming from a state.
809- max_iters: Number of iterations to run for.
810- `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
811800 seed: Deprecated argument. Please, use `torch.manual_seed` or :meth:`~ignite.utils.manual_seed`.
812801
813802 Returns:
@@ -866,6 +855,8 @@ def switch_batch(engine):
866855
867856 if self .state .max_epochs is None or (self ._is_done (self .state ) and self ._internal_run_generator is None ):
868857 # Create new state
858+ if max_epochs is None :
859+ max_epochs = 1
869860 if epoch_length is None :
870861 if data is None :
871862 raise ValueError ("epoch_length should be provided if data is None" )
@@ -874,22 +865,9 @@ def switch_batch(engine):
874865 if epoch_length is not None and epoch_length < 1 :
875866 raise ValueError ("Input data has zero size. Please provide non-empty data" )
876867
877- if max_iters is None :
878- if max_epochs is None :
879- max_epochs = 1
880- else :
881- if max_epochs is not None :
882- raise ValueError (
883- "Arguments max_iters and max_epochs are mutually exclusive."
884- "Please provide only max_epochs or max_iters."
885- )
886- if epoch_length is not None :
887- max_epochs = math .ceil (max_iters / epoch_length )
888-
889868 self .state .iteration = 0
890869 self .state .epoch = 0
891870 self .state .max_epochs = max_epochs
892- self .state .max_iters = max_iters
893871 self .state .epoch_length = epoch_length
894872 # Reset generator if previously used
895873 self ._internal_run_generator = None
@@ -984,6 +962,7 @@ def _internal_run_as_gen(self) -> Generator:
984962 self .state .times [Events .EPOCH_COMPLETED .name ] = epoch_time_taken
985963
986964 handlers_start_time = time .time ()
965+
987966 self ._fire_event (Events .EPOCH_COMPLETED )
988967 epoch_time_taken += time .time () - handlers_start_time
989968 # update time wrt handlers
@@ -1062,8 +1041,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10621041 if self .state .epoch_length is None :
10631042 # Define epoch length and stop the epoch
10641043 self .state .epoch_length = iter_counter
1065- if self .state .max_iters is not None :
1066- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
10671044 break
10681045
10691046 # Should exit while loop if we can not iterate
0 commit comments