11import functools
22import logging
3- import math
43import time
54import warnings
65import weakref
@@ -508,7 +507,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
508507 `seed`. If `engine.state_dict_user_keys` contains keys, they should be also present in the state dictionary.
509508 Iteration and epoch values are 0-based: the first iteration or epoch is zero.
510509
511- This method does not remove any custom attributes added by user.
510+ This method does not remove any custom attributs added by user.
512511
513512 Args:
514513 state_dict: a dict with parameters
@@ -553,14 +552,7 @@ def load_state_dict(self, state_dict: Mapping) -> None:
553552
554553 @staticmethod
555554 def _is_done (state : State ) -> bool :
556- is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
557- is_done_count = (
558- state .epoch_length is not None
559- and state .max_epochs is not None
560- and state .iteration >= state .epoch_length * state .max_epochs
561- )
562- is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
563- return is_done_iters or is_done_count or is_done_epochs
555+ return state .iteration == state .epoch_length * state .max_epochs # type: ignore[operator]
564556
565557 def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
566558 """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -602,15 +594,14 @@ def run(
602594 self ,
603595 data : Optional [Iterable ] = None ,
604596 max_epochs : Optional [int ] = None ,
605- max_iters : Optional [int ] = None ,
606597 epoch_length : Optional [int ] = None ,
607598 seed : Optional [int ] = None ,
608599 ) -> State :
609600 """Runs the ``process_function`` over the passed data.
610601
611602 Engine has a state and the following logic is applied in this function:
612603
613- - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, `seed`, if provided.
604+ - At the first call, new state is defined by `max_epochs`, `epoch_length`, `seed`, if provided.
614605 A timer for total and per-epoch time is initialized when Events.STARTED is handled.
615606 - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
616607 provided, state is kept and used in the function.
@@ -628,8 +619,6 @@ def run(
628619 `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
629620 determined as the iteration on which data iterator raises `StopIteration`.
630621 This argument should not change if run is resuming from a state.
631- max_iters: Number of iterations to run for.
632- `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
633622 seed: Deprecated argument. Please, use `torch.manual_seed` or :meth:`~ignite.utils.manual_seed`.
634623
635624 Returns:
@@ -688,6 +677,8 @@ def switch_batch(engine):
688677
689678 if self .state .max_epochs is None or self ._is_done (self .state ):
690679 # Create new state
680+ if max_epochs is None :
681+ max_epochs = 1
691682 if epoch_length is None :
692683 if data is None :
693684 raise ValueError ("epoch_length should be provided if data is None" )
@@ -696,22 +687,9 @@ def switch_batch(engine):
696687 if epoch_length is not None and epoch_length < 1 :
697688 raise ValueError ("Input data has zero size. Please provide non-empty data" )
698689
699- if max_iters is None :
700- if max_epochs is None :
701- max_epochs = 1
702- else :
703- if max_epochs is not None :
704- raise ValueError (
705- "Arguments max_iters and max_epochs are mutually exclusive."
706- "Please provide only max_epochs or max_iters."
707- )
708- if epoch_length is not None :
709- max_epochs = math .ceil (max_iters / epoch_length )
710-
711690 self .state .iteration = 0
712691 self .state .epoch = 0
713692 self .state .max_epochs = max_epochs
714- self .state .max_iters = max_iters
715693 self .state .epoch_length = epoch_length
716694 self .logger .info (f"Engine run starting with max_epochs={ max_epochs } ." )
717695 else :
@@ -765,7 +743,7 @@ def _internal_run(self) -> State:
765743 try :
766744 start_time = time .time ()
767745 self ._fire_event (Events .STARTED )
768- while not self ._is_done ( self .state ) and not self .should_terminate :
746+ while self .state . epoch < self .state . max_epochs and not self .should_terminate : # type: ignore[operator]
769747 self .state .epoch += 1
770748 self ._fire_event (Events .EPOCH_STARTED )
771749
@@ -835,8 +813,6 @@ def _run_once_on_dataset(self) -> float:
835813 if self .state .epoch_length is None :
836814 # Define epoch length and stop the epoch
837815 self .state .epoch_length = iter_counter
838- if self .state .max_iters is not None :
839- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
840816 break
841817
842818 # Should exit while loop if we can not iterate
@@ -876,10 +852,6 @@ def _run_once_on_dataset(self) -> float:
876852 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
877853 break
878854
879- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
880- self .should_terminate = True
881- break
882-
883855 except Exception as e :
884856 self .logger .error (f"Current run is terminating due to exception: { e } " )
885857 self ._handle_exception (e )
0 commit comments