11import functools
22import logging
3- import math
43import time
54import warnings
65import weakref
@@ -747,14 +746,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
747746
748747 @staticmethod
749748 def _is_done (state : State ) -> bool :
750- is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
751749 is_done_count = (
752750 state .epoch_length is not None
753751 and state .max_epochs is not None
754752 and state .iteration >= state .epoch_length * state .max_epochs
755753 )
756754 is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
757- return is_done_iters or is_done_count or is_done_epochs
755+ return is_done_count or is_done_epochs
758756
759757 def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
760758 """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -796,14 +794,13 @@ def run(
796794 self ,
797795 data : Optional [Iterable ] = None ,
798796 max_epochs : Optional [int ] = None ,
799- max_iters : Optional [int ] = None ,
800797 epoch_length : Optional [int ] = None ,
801798 ) -> State :
802799 """Runs the ``process_function`` over the passed data.
803800
804801 Engine has a state and the following logic is applied in this function:
805802
806- - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, if provided.
803+ - At the first call, new state is defined by `max_epochs`, `epoch_length`, if provided.
807804 A timer for total and per-epoch time is initialized when Events.STARTED is handled.
808805 - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
809806 provided, state is kept and used in the function.
@@ -821,9 +818,6 @@ def run(
821818 `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
822819 determined as the iteration on which data iterator raises `StopIteration`.
823820 This argument should not change if run is resuming from a state.
824- max_iters: Number of iterations to run for.
825- `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
826-
827821 Returns:
828822 State: output state.
829823
@@ -874,6 +868,8 @@ def switch_batch(engine):
874868
875869 if self .state .max_epochs is None or (self ._is_done (self .state ) and self ._internal_run_generator is None ):
876870 # Create new state
871+ if max_epochs is None :
872+ max_epochs = 1
877873 if epoch_length is None :
878874 if data is None :
879875 raise ValueError ("epoch_length should be provided if data is None" )
@@ -882,22 +878,9 @@ def switch_batch(engine):
882878 if epoch_length is not None and epoch_length < 1 :
883879 raise ValueError ("Input data has zero size. Please provide non-empty data" )
884880
885- if max_iters is None :
886- if max_epochs is None :
887- max_epochs = 1
888- else :
889- if max_epochs is not None :
890- raise ValueError (
891- "Arguments max_iters and max_epochs are mutually exclusive."
892- "Please provide only max_epochs or max_iters."
893- )
894- if epoch_length is not None :
895- max_epochs = math .ceil (max_iters / epoch_length )
896-
897881 self .state .iteration = 0
898882 self .state .epoch = 0
899883 self .state .max_epochs = max_epochs
900- self .state .max_iters = max_iters
901884 self .state .epoch_length = epoch_length
902885 # Reset generator if previously used
903886 self ._internal_run_generator = None
@@ -1095,18 +1078,12 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10951078 if self .state .epoch_length is None :
10961079 # Define epoch length and stop the epoch
10971080 self .state .epoch_length = iter_counter
1098- if self .state .max_iters is not None :
1099- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
11001081 break
11011082
11021083 # Should exit while loop if we can not iterate
11031084 if should_exit :
1104- if not self ._is_done (self .state ):
1105- total_iters = (
1106- self .state .epoch_length * self .state .max_epochs
1107- if self .state .max_epochs is not None
1108- else self .state .max_iters
1109- )
1085+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1086+ total_iters = self .state .epoch_length * self .state .max_epochs
11101087
11111088 warnings .warn (
11121089 "Data iterator can not provide data anymore but required total number of "
@@ -1137,10 +1114,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
11371114 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
11381115 break
11391116
1140- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1141- self .should_terminate = True
1142- raise _EngineTerminateException ()
1143-
11441117 except _EngineTerminateSingleEpochException :
11451118 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
11461119 self ._setup_dataloader_iter ()
@@ -1278,18 +1251,12 @@ def _run_once_on_dataset_legacy(self) -> float:
12781251 if self .state .epoch_length is None :
12791252 # Define epoch length and stop the epoch
12801253 self .state .epoch_length = iter_counter
1281- if self .state .max_iters is not None :
1282- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
12831254 break
12841255
12851256 # Should exit while loop if we can not iterate
12861257 if should_exit :
1287- if not self ._is_done (self .state ):
1288- total_iters = (
1289- self .state .epoch_length * self .state .max_epochs
1290- if self .state .max_epochs is not None
1291- else self .state .max_iters
1292- )
1258+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1259+ total_iters = self .state .epoch_length * self .state .max_epochs
12931260
12941261 warnings .warn (
12951262 "Data iterator can not provide data anymore but required total number of "
@@ -1320,10 +1287,6 @@ def _run_once_on_dataset_legacy(self) -> float:
13201287 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
13211288 break
13221289
1323- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1324- self .should_terminate = True
1325- raise _EngineTerminateException ()
1326-
13271290 except _EngineTerminateSingleEpochException :
13281291 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
13291292 self ._setup_dataloader_iter ()
0 commit comments