11import  functools 
22import  logging 
3- import  math 
43import  time 
54import  warnings 
65import  weakref 
@@ -731,14 +730,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
731730
732731    @staticmethod  
733732    def  _is_done (state : State ) ->  bool :
734-         is_done_iters  =  state .max_iters  is  not   None  and  state .iteration  >=  state .max_iters 
735733        is_done_count  =  (
736734            state .epoch_length  is  not   None 
737735            and  state .max_epochs  is  not   None 
738736            and  state .iteration  >=  state .epoch_length  *  state .max_epochs 
739737        )
740738        is_done_epochs  =  state .max_epochs  is  not   None  and  state .epoch  >=  state .max_epochs 
741-         return  is_done_iters   or   is_done_count  or  is_done_epochs 
739+         return  is_done_count  or  is_done_epochs 
742740
743741    def  set_data (self , data : Union [Iterable , DataLoader ]) ->  None :
744742        """Method to set data. After calling the method the next batch passed to `processing_function` is 
@@ -780,14 +778,13 @@ def run(
780778        self ,
781779        data : Optional [Iterable ] =  None ,
782780        max_epochs : Optional [int ] =  None ,
783-         max_iters : Optional [int ] =  None ,
784781        epoch_length : Optional [int ] =  None ,
785782    ) ->  State :
786783        """Runs the ``process_function`` over the passed data. 
787784
788785        Engine has a state and the following logic is applied in this function: 
789786
790-         - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, if provided. 
787+         - At the first call, new state is defined by `max_epochs`, `epoch_length`, if provided. 
791788          A timer for total and per-epoch time is initialized when Events.STARTED is handled. 
792789        - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments 
793790          provided, state is kept and used in the function. 
@@ -805,9 +802,6 @@ def run(
805802                `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically 
806803                determined as the iteration on which data iterator raises `StopIteration`. 
807804                This argument should not change if run is resuming from a state. 
808-             max_iters: Number of iterations to run for. 
809-                 `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided. 
810- 
811805        Returns: 
812806            State: output state. 
813807
@@ -858,6 +852,8 @@ def switch_batch(engine):
858852
859853        if  self .state .max_epochs  is  None  or  (self ._is_done (self .state ) and  self ._internal_run_generator  is  None ):
860854            # Create new state 
855+             if  max_epochs  is  None :
856+                 max_epochs  =  1 
861857            if  epoch_length  is  None :
862858                if  data  is  None :
863859                    raise  ValueError ("epoch_length should be provided if data is None" )
@@ -866,22 +862,9 @@ def switch_batch(engine):
866862                if  epoch_length  is  not   None  and  epoch_length  <  1 :
867863                    raise  ValueError ("Input data has zero size. Please provide non-empty data" )
868864
869-             if  max_iters  is  None :
870-                 if  max_epochs  is  None :
871-                     max_epochs  =  1 
872-             else :
873-                 if  max_epochs  is  not   None :
874-                     raise  ValueError (
875-                         "Arguments max_iters and max_epochs are mutually exclusive." 
876-                         "Please provide only max_epochs or max_iters." 
877-                     )
878-                 if  epoch_length  is  not   None :
879-                     max_epochs  =  math .ceil (max_iters  /  epoch_length )
880- 
881865            self .state .iteration  =  0 
882866            self .state .epoch  =  0 
883867            self .state .max_epochs  =  max_epochs 
884-             self .state .max_iters  =  max_iters 
885868            self .state .epoch_length  =  epoch_length 
886869            # Reset generator if previously used 
887870            self ._internal_run_generator  =  None 
@@ -1062,18 +1045,12 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10621045                    if  self .state .epoch_length  is  None :
10631046                        # Define epoch length and stop the epoch 
10641047                        self .state .epoch_length  =  iter_counter 
1065-                         if  self .state .max_iters  is  not   None :
1066-                             self .state .max_epochs  =  math .ceil (self .state .max_iters  /  self .state .epoch_length )
10671048                        break 
10681049
10691050                    # Should exit while loop if we can not iterate 
10701051                    if  should_exit :
1071-                         if  not  self ._is_done (self .state ):
1072-                             total_iters  =  (
1073-                                 self .state .epoch_length  *  self .state .max_epochs 
1074-                                 if  self .state .max_epochs  is  not   None 
1075-                                 else  self .state .max_iters 
1076-                             )
1052+                         if  not  self ._is_done (self .state ) and  self .state .max_epochs  is  not   None :
1053+                             total_iters  =  self .state .epoch_length  *  self .state .max_epochs 
10771054
10781055                            warnings .warn (
10791056                                "Data iterator can not provide data anymore but required total number of " 
@@ -1104,10 +1081,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
11041081                if  self .state .epoch_length  is  not   None  and  iter_counter  ==  self .state .epoch_length :
11051082                    break 
11061083
1107-                 if  self .state .max_iters  is  not   None  and  self .state .iteration  ==  self .state .max_iters :
1108-                     self .should_terminate  =  True 
1109-                     raise  _EngineTerminateException ()
1110- 
11111084        except  _EngineTerminateSingleEpochException :
11121085            self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
11131086            self .should_terminate_single_epoch  =  False 
@@ -1229,18 +1202,12 @@ def _run_once_on_dataset_legacy(self) -> float:
12291202                    if  self .state .epoch_length  is  None :
12301203                        # Define epoch length and stop the epoch 
12311204                        self .state .epoch_length  =  iter_counter 
1232-                         if  self .state .max_iters  is  not   None :
1233-                             self .state .max_epochs  =  math .ceil (self .state .max_iters  /  self .state .epoch_length )
12341205                        break 
12351206
12361207                    # Should exit while loop if we can not iterate 
12371208                    if  should_exit :
1238-                         if  not  self ._is_done (self .state ):
1239-                             total_iters  =  (
1240-                                 self .state .epoch_length  *  self .state .max_epochs 
1241-                                 if  self .state .max_epochs  is  not   None 
1242-                                 else  self .state .max_iters 
1243-                             )
1209+                         if  not  self ._is_done (self .state ) and  self .state .max_epochs  is  not   None :
1210+                             total_iters  =  self .state .epoch_length  *  self .state .max_epochs 
12441211
12451212                            warnings .warn (
12461213                                "Data iterator can not provide data anymore but required total number of " 
@@ -1271,10 +1238,6 @@ def _run_once_on_dataset_legacy(self) -> float:
12711238                if  self .state .epoch_length  is  not   None  and  iter_counter  ==  self .state .epoch_length :
12721239                    break 
12731240
1274-                 if  self .state .max_iters  is  not   None  and  self .state .iteration  ==  self .state .max_iters :
1275-                     self .should_terminate  =  True 
1276-                     raise  _EngineTerminateException ()
1277- 
12781241        except  _EngineTerminateSingleEpochException :
12791242            self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
12801243            self .should_terminate_single_epoch  =  False 
0 commit comments