@@ -724,7 +724,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
724724
725725 @staticmethod
726726 def _is_done (state : State ) -> bool :
727- return state .iteration == state .epoch_length * state .max_epochs # type: ignore[operator]
727+ is_done_count = (
728+ state .epoch_length is not None
729+ and state .max_epochs is not None
730+ and state .iteration >= state .epoch_length * state .max_epochs
731+ )
732+ is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
733+ return is_done_count or is_done_epochs
728734
729735 def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
730736 """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -956,7 +962,6 @@ def _internal_run_as_gen(self) -> Generator:
956962 self .state .times [Events .EPOCH_COMPLETED .name ] = epoch_time_taken
957963
958964 handlers_start_time = time .time ()
959-
960965 self ._fire_event (Events .EPOCH_COMPLETED )
961966 epoch_time_taken += time .time () - handlers_start_time
962967 # update time wrt handlers
@@ -1039,13 +1044,8 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10391044
10401045 # Should exit while loop if we can not iterate
10411046 if should_exit :
1042- if not self ._is_done (self .state ):
1043- total_iters = (
1044- self .state .epoch_length * self .state .max_epochs
1045- if self .state .max_epochs is not None
1046- else self .state .max_iters
1047- )
1048-
1047+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1048+ total_iters = self .state .epoch_length * self .state .max_epochs
10491049 warnings .warn (
10501050 "Data iterator can not provide data anymore but required total number of "
10511051 "iterations to run is not reached. "
@@ -1072,10 +1072,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10721072 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
10731073 break
10741074
1075- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1076- self .should_terminate = True
1077- raise _EngineTerminateException ()
1078-
10791075 except _EngineTerminateSingleEpochException :
10801076 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
10811077 self .should_terminate_single_epoch = False
@@ -1191,19 +1187,12 @@ def _run_once_on_dataset_legacy(self) -> float:
11911187 if self .state .epoch_length is None :
11921188 # Define epoch length and stop the epoch
11931189 self .state .epoch_length = iter_counter
1194- if self .state .max_iters is not None :
1195- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
11961190 break
11971191
11981192 # Should exit while loop if we can not iterate
11991193 if should_exit :
1200- if not self ._is_done (self .state ):
1201- total_iters = (
1202- self .state .epoch_length * self .state .max_epochs
1203- if self .state .max_epochs is not None
1204- else self .state .max_iters
1205- )
1206-
1194+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1195+ total_iters = self .state .epoch_length * self .state .max_epochs
12071196 warnings .warn (
12081197 "Data iterator can not provide data anymore but required total number of "
12091198 "iterations to run is not reached. "
@@ -1230,10 +1219,6 @@ def _run_once_on_dataset_legacy(self) -> float:
12301219 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
12311220 break
12321221
1233- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1234- self .should_terminate = True
1235- raise _EngineTerminateException ()
1236-
12371222 except _EngineTerminateSingleEpochException :
12381223 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
12391224 self .should_terminate_single_epoch = False
0 commit comments