11import functools
22import logging
3- import math
43import time
54import warnings
65import weakref
@@ -731,14 +730,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
731730
732731 @staticmethod
733732 def _is_done (state : State ) -> bool :
734- is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
735733 is_done_count = (
736734 state .epoch_length is not None
737735 and state .max_epochs is not None
738736 and state .iteration >= state .epoch_length * state .max_epochs
739737 )
740738 is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
741- return is_done_iters or is_done_count or is_done_epochs
739+ return is_done_count or is_done_epochs
742740
743741 def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
744742 """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -780,14 +778,13 @@ def run(
780778 self ,
781779 data : Optional [Iterable ] = None ,
782780 max_epochs : Optional [int ] = None ,
783- max_iters : Optional [int ] = None ,
784781 epoch_length : Optional [int ] = None ,
785782 ) -> State :
786783 """Runs the ``process_function`` over the passed data.
787784
788785 Engine has a state and the following logic is applied in this function:
789786
790- - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, if provided.
787+ - At the first call, new state is defined by `max_epochs`, `epoch_length`, if provided.
791788 A timer for total and per-epoch time is initialized when Events.STARTED is handled.
792789 - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
793790 provided, state is kept and used in the function.
@@ -805,9 +802,6 @@ def run(
805802 `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
806803 determined as the iteration on which data iterator raises `StopIteration`.
807804 This argument should not change if run is resuming from a state.
808- max_iters: Number of iterations to run for.
809- `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
810-
811805 Returns:
812806 State: output state.
813807
@@ -858,6 +852,8 @@ def switch_batch(engine):
858852
859853 if self .state .max_epochs is None or (self ._is_done (self .state ) and self ._internal_run_generator is None ):
860854 # Create new state
855+ if max_epochs is None :
856+ max_epochs = 1
861857 if epoch_length is None :
862858 if data is None :
863859 raise ValueError ("epoch_length should be provided if data is None" )
@@ -866,22 +862,9 @@ def switch_batch(engine):
866862 if epoch_length is not None and epoch_length < 1 :
867863 raise ValueError ("Input data has zero size. Please provide non-empty data" )
868864
869- if max_iters is None :
870- if max_epochs is None :
871- max_epochs = 1
872- else :
873- if max_epochs is not None :
874- raise ValueError (
875- "Arguments max_iters and max_epochs are mutually exclusive."
876- "Please provide only max_epochs or max_iters."
877- )
878- if epoch_length is not None :
879- max_epochs = math .ceil (max_iters / epoch_length )
880-
881865 self .state .iteration = 0
882866 self .state .epoch = 0
883867 self .state .max_epochs = max_epochs
884- self .state .max_iters = max_iters
885868 self .state .epoch_length = epoch_length
886869 # Reset generator if previously used
887870 self ._internal_run_generator = None
@@ -1062,18 +1045,12 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10621045 if self .state .epoch_length is None :
10631046 # Define epoch length and stop the epoch
10641047 self .state .epoch_length = iter_counter
1065- if self .state .max_iters is not None :
1066- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
10671048 break
10681049
10691050 # Should exit while loop if we can not iterate
10701051 if should_exit :
1071- if not self ._is_done (self .state ):
1072- total_iters = (
1073- self .state .epoch_length * self .state .max_epochs
1074- if self .state .max_epochs is not None
1075- else self .state .max_iters
1076- )
1052+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1053+ total_iters = self .state .epoch_length * self .state .max_epochs
10771054
10781055 warnings .warn (
10791056 "Data iterator can not provide data anymore but required total number of "
@@ -1104,10 +1081,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
11041081 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
11051082 break
11061083
1107- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1108- self .should_terminate = True
1109- raise _EngineTerminateException ()
1110-
11111084 except _EngineTerminateSingleEpochException :
11121085 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
11131086 self .should_terminate_single_epoch = False
@@ -1229,18 +1202,12 @@ def _run_once_on_dataset_legacy(self) -> float:
12291202 if self .state .epoch_length is None :
12301203 # Define epoch length and stop the epoch
12311204 self .state .epoch_length = iter_counter
1232- if self .state .max_iters is not None :
1233- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
12341205 break
12351206
12361207 # Should exit while loop if we can not iterate
12371208 if should_exit :
1238- if not self ._is_done (self .state ):
1239- total_iters = (
1240- self .state .epoch_length * self .state .max_epochs
1241- if self .state .max_epochs is not None
1242- else self .state .max_iters
1243- )
1209+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1210+ total_iters = self .state .epoch_length * self .state .max_epochs
12441211
12451212 warnings .warn (
12461213 "Data iterator can not provide data anymore but required total number of "
@@ -1271,10 +1238,6 @@ def _run_once_on_dataset_legacy(self) -> float:
12711238 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
12721239 break
12731240
1274- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1275- self .should_terminate = True
1276- raise _EngineTerminateException ()
1277-
12781241 except _EngineTerminateSingleEpochException :
12791242 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
12801243 self .should_terminate_single_epoch = False
0 commit comments