Skip to content

Commit 6281238

Browse files
committed
Fixed remaining issues caused by reverted commits
1 parent 39f5030 commit 6281238

File tree

8 files changed

+17
-72
lines changed

8 files changed

+17
-72
lines changed

ignite/contrib/handlers/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from ignite.contrib.handlers.clearml_logger import ClearMLLogger
22
from ignite.contrib.handlers.custom_events import CustomPeriodicEvent
3-
from ignite.contrib.handlers.lr_finder import FastaiLRFinder
43
from ignite.contrib.handlers.mlflow_logger import MLflowLogger
54
from ignite.contrib.handlers.neptune_logger import NeptuneLogger
65
from ignite.contrib.handlers.polyaxon_logger import PolyaxonLogger

ignite/engine/engine.py

Lines changed: 11 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
730730

731731
@staticmethod
732732
def _is_done(state: State) -> bool:
733-
return state.iteration == state.epoch_length * state.max_epochs # type: ignore[operator]
733+
is_done_count = (
734+
state.epoch_length is not None
735+
and state.max_epochs is not None
736+
and state.iteration >= state.epoch_length * state.max_epochs
737+
)
738+
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
739+
return is_done_count or is_done_epochs
734740

735741
def set_data(self, data: Union[Iterable, DataLoader]) -> None:
736742
"""Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -962,7 +968,6 @@ def _internal_run_as_gen(self) -> Generator:
962968
self.state.times[Events.EPOCH_COMPLETED.name] = epoch_time_taken
963969

964970
handlers_start_time = time.time()
965-
966971
self._fire_event(Events.EPOCH_COMPLETED)
967972
epoch_time_taken += time.time() - handlers_start_time
968973
# update time wrt handlers
@@ -1045,13 +1050,8 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10451050

10461051
# Should exit while loop if we can not iterate
10471052
if should_exit:
1048-
if not self._is_done(self.state):
1049-
total_iters = (
1050-
self.state.epoch_length * self.state.max_epochs
1051-
if self.state.max_epochs is not None
1052-
else self.state.max_iters
1053-
)
1054-
1053+
if not self._is_done(self.state) and self.state.max_epochs is not None:
1054+
total_iters = self.state.epoch_length * self.state.max_epochs
10551055
warnings.warn(
10561056
"Data iterator can not provide data anymore but required total number of "
10571057
"iterations to run is not reached. "
@@ -1078,10 +1078,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
10781078
if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
10791079
break
10801080

1081-
if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
1082-
self.should_terminate = True
1083-
raise _EngineTerminateException()
1084-
10851081
except _EngineTerminateSingleEpochException:
10861082
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
10871083
self.should_terminate_single_epoch = False
@@ -1197,19 +1193,12 @@ def _run_once_on_dataset_legacy(self) -> float:
11971193
if self.state.epoch_length is None:
11981194
# Define epoch length and stop the epoch
11991195
self.state.epoch_length = iter_counter
1200-
if self.state.max_iters is not None:
1201-
self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
12021196
break
12031197

12041198
# Should exit while loop if we can not iterate
12051199
if should_exit:
1206-
if not self._is_done(self.state):
1207-
total_iters = (
1208-
self.state.epoch_length * self.state.max_epochs
1209-
if self.state.max_epochs is not None
1210-
else self.state.max_iters
1211-
)
1212-
1200+
if not self._is_done(self.state) and self.state.max_epochs is not None:
1201+
total_iters = self.state.epoch_length * self.state.max_epochs
12131202
warnings.warn(
12141203
"Data iterator can not provide data anymore but required total number of "
12151204
"iterations to run is not reached. "
@@ -1236,10 +1225,6 @@ def _run_once_on_dataset_legacy(self) -> float:
12361225
if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
12371226
break
12381227

1239-
if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
1240-
self.should_terminate = True
1241-
raise _EngineTerminateException()
1242-
12431228
except _EngineTerminateSingleEpochException:
12441229
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
12451230
self.should_terminate_single_epoch = False

ignite/engine/events.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
214214
)
215215

216216

217-
class EventEnum(CallableEventWithFilter, Enum): # type: ignore[misc]
217+
class EventEnum(CallableEventWithFilter, Enum):
218218
"""Base class for all :class:`~ignite.engine.events.Events`. User defined custom events should also inherit
219219
this class.
220220

ignite/handlers/checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -964,7 +964,7 @@ def __init__(
964964
self,
965965
dirname: Union[str, Path],
966966
filename_prefix: str = "",
967-
save_interval: Optional[Callable] = None,
967+
save_interval: Optional[int] = None,
968968
score_function: Optional[Callable] = None,
969969
score_name: Optional[str] = None,
970970
n_saved: Union[int, None] = 1,

ignite/handlers/lr_finder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ def _run(
105105
max_iter = trainer.state.epoch_length * trainer.state.max_epochs # type: ignore[operator]
106106
if max_iter < num_iter:
107107
max_iter = num_iter
108-
trainer.state.max_iters = num_iter
109108
trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length) # type: ignore[operator]
110109

111110
if not trainer.has_event_handler(self._reached_num_iterations):

mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,6 @@ ignore_missing_imports = True
7777

7878
[mypy-torchvision.*]
7979
ignore_missing_imports = True
80+
81+
[mypy-ignite.contrib.handlers.custom_events]
82+
ignore_errors = True

tests/ignite/engine/test_engine.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,47 +1027,6 @@ def switch_dataloader():
10271027

10281028
trainer.run(data1, max_epochs=10)
10291029

1030-
def test_run_with_max_iters(self):
1031-
max_iters = 8
1032-
engine = Engine(lambda e, b: 1)
1033-
engine.run([0] * 20, max_iters=max_iters)
1034-
assert engine.state.iteration == max_iters
1035-
assert engine.state.max_iters == max_iters
1036-
1037-
def test_run_with_max_iters_greater_than_epoch_length(self):
1038-
max_iters = 73
1039-
engine = Engine(lambda e, b: 1)
1040-
engine.run([0] * 20, max_iters=max_iters)
1041-
assert engine.state.iteration == max_iters
1042-
1043-
def test_run_with_invalid_max_iters_and_max_epoch(self):
1044-
max_iters = 12
1045-
max_epochs = 2
1046-
engine = Engine(lambda e, b: 1)
1047-
with pytest.raises(
1048-
ValueError,
1049-
match=r"Arguments max_iters and max_epochs are mutually exclusive."
1050-
"Please provide only max_epochs or max_iters.",
1051-
):
1052-
engine.run([0] * 20, max_iters=max_iters, max_epochs=max_epochs)
1053-
1054-
def test_epoch_events_fired_max_iters(self):
1055-
max_iters = 32
1056-
engine = Engine(lambda e, b: 1)
1057-
1058-
@engine.on(Events.EPOCH_COMPLETED)
1059-
def fired_event(engine):
1060-
assert engine.state.iteration % engine.state.epoch_length == 0
1061-
1062-
engine.run([0] * 10, max_iters=max_iters)
1063-
1064-
def test_is_done_with_max_iters(self):
1065-
state = State(iteration=100, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
1066-
assert not Engine._is_done(state)
1067-
1068-
state = State(iteration=250, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
1069-
assert Engine._is_done(state)
1070-
10711030
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
10721031
def test_batch_is_released_before_new_one_is_loaded_on_cuda(self):
10731032
torch.cuda.empty_cache()

tests/ignite/handlers/test_lr_finder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def test_num_iter_is_not_enough(lr_finder, to_save, dummy_engine, dataloader):
344344
trainer_with_finder.run(dataloader)
345345
assert_output_sizes(lr_finder, dummy_engine)
346346
assert dummy_engine.state.iteration != len(dataloader)
347-
assert dummy_engine.state.iteration == 150
347+
assert dummy_engine.state.iteration == 150 + 1
348348

349349

350350
def test_detach_terminates(lr_finder, to_save, dummy_engine, dataloader):

0 commit comments

Comments (0)