@@ -730,7 +730,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
730730
731731 @staticmethod
732732 def _is_done (state : State ) -> bool :
733- return state .iteration == state .epoch_length * state .max_epochs # type: ignore[operator]
733+ is_done_count = (
734+ state .epoch_length is not None
735+ and state .max_epochs is not None
736+ and state .iteration >= state .epoch_length * state .max_epochs
737+ )
738+ is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
739+ return is_done_count or is_done_epochs
734740
735741 def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
736742 """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -962,7 +968,6 @@ def _internal_run_as_gen(self) -> Generator:
962968 self .state .times [Events .EPOCH_COMPLETED .name ] = epoch_time_taken
963969
964970 handlers_start_time = time .time ()
965-
966971 self ._fire_event (Events .EPOCH_COMPLETED )
967972 epoch_time_taken += time .time () - handlers_start_time
968973 # update time wrt handlers
@@ -1045,13 +1050,8 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10451050
10461051 # Should exit while loop if we can not iterate
10471052 if should_exit :
1048- if not self ._is_done (self .state ):
1049- total_iters = (
1050- self .state .epoch_length * self .state .max_epochs
1051- if self .state .max_epochs is not None
1052- else self .state .max_iters
1053- )
1054-
1053+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1054+ total_iters = self .state .epoch_length * self .state .max_epochs
10551055 warnings .warn (
10561056 "Data iterator can not provide data anymore but required total number of "
10571057 "iterations to run is not reached. "
@@ -1078,10 +1078,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10781078 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
10791079 break
10801080
1081- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1082- self .should_terminate = True
1083- raise _EngineTerminateException ()
1084-
10851081 except _EngineTerminateSingleEpochException :
10861082 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
10871083 self .should_terminate_single_epoch = False
@@ -1197,19 +1193,12 @@ def _run_once_on_dataset_legacy(self) -> float:
11971193 if self .state .epoch_length is None :
11981194 # Define epoch length and stop the epoch
11991195 self .state .epoch_length = iter_counter
1200- if self .state .max_iters is not None :
1201- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
12021196 break
12031197
12041198 # Should exit while loop if we can not iterate
12051199 if should_exit :
1206- if not self ._is_done (self .state ):
1207- total_iters = (
1208- self .state .epoch_length * self .state .max_epochs
1209- if self .state .max_epochs is not None
1210- else self .state .max_iters
1211- )
1212-
1200+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1201+ total_iters = self .state .epoch_length * self .state .max_epochs
12131202 warnings .warn (
12141203 "Data iterator can not provide data anymore but required total number of "
12151204 "iterations to run is not reached. "
@@ -1236,10 +1225,6 @@ def _run_once_on_dataset_legacy(self) -> float:
12361225 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
12371226 break
12381227
1239- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1240- self .should_terminate = True
1241- raise _EngineTerminateException ()
1242-
12431228 except _EngineTerminateSingleEpochException :
12441229 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
12451230 self .should_terminate_single_epoch = False
0 commit comments