
Commit 7bafab7

Hummer12007 and vfdev-5 authored
Added score_sign to add_early_stopping_by_val_score (#2929)
* Added score_sign to add_early_stopping_by_val_score
* Updated docstring
* Updated tests for add_early_stopping_by_val_score
* Added score_sign to save_best_model_by_val_score
* Add tests for score_sign with save_best_model_by_val_score
* Fmt fix
* Update ignite/contrib/engines/common.py

  Co-authored-by: vfdev <[email protected]>

---------

Co-authored-by: vfdev <[email protected]>
1 parent 83e487f commit 7bafab7

File tree

2 files changed: +119 -44 lines changed


ignite/contrib/engines/common.py

Lines changed: 19 additions & 3 deletions
@@ -580,6 +580,7 @@ def gen_save_best_models_by_val_score(
     n_saved: int = 3,
     trainer: Optional[Engine] = None,
     tag: str = "val",
+    score_sign: float = 1.0,
     **kwargs: Any,
 ) -> Checkpoint:
     """Method adds a handler to ``evaluator`` to save ``n_saved`` of best models based on the metric
@@ -602,6 +603,8 @@ def gen_save_best_models_by_val_score(
         n_saved: number of best models to store
         trainer: trainer engine to fetch the epoch when saving the best model.
         tag: score name prefix: `{tag}_{metric_name}`. By default, tag is "val".
+        score_sign: sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better,
+            a negative score sign should be used (objects with larger score are retained). Default, 1.0.
         kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.

     Returns:
@@ -623,7 +626,7 @@ def gen_save_best_models_by_val_score(
         n_saved=n_saved,
         global_step_transform=global_step_transform,
         score_name=f"{tag}_{metric_name.lower()}",
-        score_function=Checkpoint.get_default_score_fn(metric_name),
+        score_function=get_default_score_fn(metric_name, score_sign=score_sign),
         **kwargs,
     )
     evaluator.add_event_handler(Events.COMPLETED, best_model_handler)
@@ -639,6 +642,7 @@ def save_best_model_by_val_score(
     n_saved: int = 3,
     trainer: Optional[Engine] = None,
     tag: str = "val",
+    score_sign: float = 1.0,
     **kwargs: Any,
 ) -> Checkpoint:
     """Method adds a handler to ``evaluator`` to save on a disk ``n_saved`` of best models based on the metric
@@ -654,6 +658,9 @@ def save_best_model_by_val_score(
         n_saved: number of best models to store
         trainer: trainer engine to fetch the epoch when saving the best model.
         tag: score name prefix: `{tag}_{metric_name}`. By default, tag is "val".
+        score_sign: sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better,
+            a negative score sign should be used (objects with larger score are retained). Default, 1.0.
+
         kwargs: optional keyword args to be passed to construct :class:`~ignite.handlers.checkpoint.Checkpoint`.

     Returns:
@@ -667,12 +674,17 @@ def save_best_model_by_val_score(
         n_saved=n_saved,
         trainer=trainer,
         tag=tag,
+        score_sign=score_sign,
         **kwargs,
     )


 def add_early_stopping_by_val_score(
-    patience: int, evaluator: Engine, trainer: Engine, metric_name: str
+    patience: int,
+    evaluator: Engine,
+    trainer: Engine,
+    metric_name: str,
+    score_sign: float = 1.0,
 ) -> EarlyStopping:
     """Method setups early stopping handler based on the score (named by `metric_name`) provided by `evaluator`.
     Metric value should increase in order to keep training and not early stop.
@@ -683,11 +695,15 @@ def add_early_stopping_by_val_score(
         trainer: trainer engine to stop the run if no improvement.
         metric_name: metric name to use for score evaluation. This metric should be present in
             `evaluator.state.metrics`.
+        score_sign: sign of the score: 1.0 or -1.0. For error-like metrics, e.g. smaller is better,
+            a negative score sign should be used (objects with larger score are retained). Default, 1.0.

     Returns:
         A :class:`~ignite.handlers.early_stopping.EarlyStopping` handler.
     """
-    es_handler = EarlyStopping(patience=patience, score_function=get_default_score_fn(metric_name), trainer=trainer)
+    es_handler = EarlyStopping(
+        patience=patience, score_function=get_default_score_fn(metric_name, score_sign=score_sign), trainer=trainer
+    )
     evaluator.add_event_handler(Events.COMPLETED, es_handler)

     return es_handler
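
For reference, here is a minimal usage sketch of the extended API (not part of the commit itself). It assumes a dummy trainer/evaluator pair that reports a "loss" metric, an `nn.Linear` stand-in for a real model, and a placeholder checkpoint directory; with `score_sign=-1.0`, both checkpointing and early stopping treat a smaller loss as better.

```python
# Minimal wiring sketch (illustrative, not the library's own example).
# Assumptions: dummy engines, an nn.Linear stand-in model, and a placeholder
# output directory "/tmp/checkpoints" (should be new or empty).
import torch.nn as nn

from ignite.contrib.engines.common import add_early_stopping_by_val_score, save_best_model_by_val_score
from ignite.engine import Engine, Events

loss_scores = [0.9, 0.8, 0.7, 0.6, 0.7, 0.5, 0.4, 0.39, 0.3, 0.5]

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)
model = nn.Linear(2, 2)  # stand-in for a real model


@trainer.on(Events.EPOCH_COMPLETED)
def validate(engine):
    evaluator.run([0, 1])


@evaluator.on(Events.EPOCH_COMPLETED)
def set_eval_metric(engine):
    engine.state.metrics = {"loss": loss_scores[trainer.state.epoch - 1]}


# Keep the 2 checkpoints with the *smallest* validation loss; filenames carry the
# signed score, e.g. "best_model_9_val_loss=-0.3000.pt".
save_best_model_by_val_score(
    "/tmp/checkpoints", evaluator, model, metric_name="loss", n_saved=2, trainer=trainer, score_sign=-1.0
)

# Stop training once the loss has not improved for 3 consecutive validations.
add_early_stopping_by_val_score(
    patience=3, evaluator=evaluator, trainer=trainer, metric_name="loss", score_sign=-1.0
)

trainer.run([0, 1], max_epochs=len(loss_scores))
```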

tests/ignite/contrib/engines/test_common.py

Lines changed: 100 additions & 41 deletions
@@ -48,7 +48,6 @@ def _test_setup_common_training_handlers(
     save_handler=None,
     output_transform=lambda loss: loss,
 ):
-
     lr = 0.01
     step_size = 100
     gamma = 0.5
@@ -218,7 +217,6 @@ def test_setup_common_training_handlers(dirname, capsys):


 def test_setup_common_training_handlers_using_save_handler(dirname, capsys):
-
     save_handler = DiskSaver(dirname=dirname, require_empty=False)
     _test_setup_common_training_handlers(dirname=None, device="cpu", save_handler=save_handler)

@@ -231,43 +229,68 @@ def test_setup_common_training_handlers_using_save_handler(dirname, capsys):


 def test_save_best_model_by_val_score(dirname):
+    acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]

-    trainer = Engine(lambda e, b: None)
-    evaluator = Engine(lambda e, b: None)
-    model = DummyModel()
+    def setup_trainer():
+        trainer = Engine(lambda e, b: None)
+        evaluator = Engine(lambda e, b: None)
+        model = DummyModel()

-    acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]
+        @trainer.on(Events.EPOCH_COMPLETED)
+        def validate(engine):
+            evaluator.run([0, 1])

-    @trainer.on(Events.EPOCH_COMPLETED)
-    def validate(engine):
-        evaluator.run([0, 1])
+        @evaluator.on(Events.EPOCH_COMPLETED)
+        def set_eval_metric(engine):
+            acc = acc_scores[trainer.state.epoch - 1]
+            engine.state.metrics = {"acc": acc, "loss": 1 - acc}
+
+        return trainer, evaluator, model

-    @evaluator.on(Events.EPOCH_COMPLETED)
-    def set_eval_metric(engine):
-        engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}
+    trainer, evaluator, model = setup_trainer()

     save_best_model_by_val_score(dirname, evaluator, model, metric_name="acc", n_saved=2, trainer=trainer)

     trainer.run([0, 1], max_epochs=len(acc_scores))

     assert set(os.listdir(dirname)) == {"best_model_8_val_acc=0.6100.pt", "best_model_9_val_acc=0.7000.pt"}

+    for fname in os.listdir(dirname):
+        os.unlink(f"{dirname}/{fname}")

-def test_gen_save_best_models_by_val_score():
+    trainer, evaluator, model = setup_trainer()
+
+    save_best_model_by_val_score(
+        dirname, evaluator, model, metric_name="loss", n_saved=2, trainer=trainer, score_sign=-1.0
+    )
+
+    trainer.run([0, 1], max_epochs=len(acc_scores))
+
+    assert set(os.listdir(dirname)) == {"best_model_8_val_loss=-0.3900.pt", "best_model_9_val_loss=-0.3000.pt"}

-    trainer = Engine(lambda e, b: None)
-    evaluator = Engine(lambda e, b: None)
-    model = DummyModel()

+def test_gen_save_best_models_by_val_score():
     acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]
+    loss_scores = [0.9, 0.8, 0.7, 0.6, 0.7, 0.5, 0.4, 0.39, 0.3, 0.5]
+
+    def setup_trainer():
+        trainer = Engine(lambda e, b: None)
+        evaluator = Engine(lambda e, b: None)
+        model = DummyModel()
+
+        @trainer.on(Events.EPOCH_COMPLETED)
+        def validate(engine):
+            evaluator.run([0, 1])

-    @trainer.on(Events.EPOCH_COMPLETED)
-    def validate(engine):
-        evaluator.run([0, 1])
+        @evaluator.on(Events.EPOCH_COMPLETED)
+        def set_eval_metric(engine):
+            acc = acc_scores[trainer.state.epoch - 1]
+            loss = loss_scores[trainer.state.epoch - 1]
+            engine.state.metrics = {"acc": acc, "loss": loss}

-    @evaluator.on(Events.EPOCH_COMPLETED)
-    def set_eval_metric(engine):
-        engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}
+        return trainer, evaluator, model
+
+    trainer, evaluator, model = setup_trainer()

     save_handler = MagicMock()

@@ -291,36 +314,80 @@ def set_eval_metric(engine):
         any_order=True,
     )

+    trainer, evaluator, model = setup_trainer()

-def test_add_early_stopping_by_val_score():
-    trainer = Engine(lambda e, b: None)
-    evaluator = Engine(lambda e, b: None)
+    save_handler = MagicMock()
+
+    gen_save_best_models_by_val_score(
+        save_handler,
+        evaluator,
+        {"a": model, "b": model},
+        metric_name="loss",
+        n_saved=2,
+        trainer=trainer,
+        score_sign=-1.0,
+    )
+
+    trainer.run([0, 1], max_epochs=len(acc_scores))

+    assert save_handler.call_count == len(acc_scores) - 2  # 2 score values (-0.7 and -0.5) are not the best
+    obj_to_save = {"a": model.state_dict(), "b": model.state_dict()}
+    save_handler.assert_has_calls(
+        [
+            call(
+                obj_to_save,
+                f"best_checkpoint_{e}_val_loss={p:.4f}.pt",
+                dict([("basename", "best_checkpoint"), ("score_name", "val_loss"), ("priority", p)]),
+            )
+            for e, p in zip([1, 2, 3, 4, 6, 7, 8, 9], [-0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.39, -0.3])
+        ],
+        any_order=True,
+    )
+
+
+def test_add_early_stopping_by_val_score():
     acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.3, 0.2, 0.1, 0.1, 0.0]

-    @trainer.on(Events.EPOCH_COMPLETED)
-    def validate(engine):
-        evaluator.run([0, 1])
+    def setup_trainer():
+        trainer = Engine(lambda e, b: None)
+        evaluator = Engine(lambda e, b: None)
+
+        @trainer.on(Events.EPOCH_COMPLETED)
+        def validate(engine):
+            evaluator.run([0, 1])

-    @evaluator.on(Events.EPOCH_COMPLETED)
-    def set_eval_metric(engine):
-        engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}
+        @evaluator.on(Events.EPOCH_COMPLETED)
+        def set_eval_metric(engine):
+            acc = acc_scores[trainer.state.epoch - 1]
+            engine.state.metrics = {"acc": acc, "loss": 1 - acc}
+
+        return trainer, evaluator
+
+    trainer, evaluator = setup_trainer()

     add_early_stopping_by_val_score(patience=3, evaluator=evaluator, trainer=trainer, metric_name="acc")

     state = trainer.run([0, 1], max_epochs=len(acc_scores))

     assert state.epoch == 7

+    trainer, evaluator = setup_trainer()

-def test_deprecated_setup_any_logging():
+    add_early_stopping_by_val_score(
+        patience=3, evaluator=evaluator, trainer=trainer, metric_name="loss", score_sign=-1.0
+    )
+
+    state = trainer.run([0, 1], max_epochs=len(acc_scores))
+
+    assert state.epoch == 7

+
+def test_deprecated_setup_any_logging():
     with pytest.raises(DeprecationWarning, match=r"deprecated since version 0.4.0"):
         setup_any_logging(None, None, None, None, None, None)


 def test__setup_logging_wrong_args():
-
     with pytest.raises(TypeError, match=r"Argument optimizers should be either a single optimizer or"):
         _setup_logging(MagicMock(), MagicMock(), "abc", MagicMock(), 1)

@@ -406,7 +473,6 @@ def set_eval_metric(engine):


 def test_setup_tb_logging(dirname):
-
     tb_logger = _test_setup_logging(
         setup_logging_fn=setup_tb_logging,
         kwargs_dict={"output_path": dirname / "t1"},
@@ -462,7 +528,6 @@ def test_setup_visdom_logging(visdom_offline_logfile):


 def test_setup_plx_logging():
-
     os.environ["POLYAXON_NO_OP"] = "1"

     _test_setup_logging(
@@ -506,15 +571,13 @@ def test_setup_mlflow_logging(dirname):


 def test_setup_wandb_logging(dirname):
-
     from unittest.mock import patch

     with patch("ignite.contrib.engines.common.WandBLogger") as _:
         setup_wandb_logging(MagicMock())


 def test_setup_clearml_logging():
-
     handlers.clearml_logger.ClearMLLogger.set_bypass_mode(True)

     with pytest.warns(UserWarning, match=r"running in bypass mode"):
@@ -583,7 +646,6 @@ def test_setup_neptune_logging(dirname):
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
 def test_distrib_nccl_gpu(dirname, distributed_context_single_node_nccl):
-
     local_rank = distributed_context_single_node_nccl["local_rank"]
     device = idist.device()
     _test_setup_common_training_handlers(dirname, device, rank=local_rank, local_rank=local_rank, distributed=True)
@@ -593,7 +655,6 @@ def test_distrib_nccl_gpu(dirname, distributed_context_single_node_nccl):
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 def test_distrib_gloo_cpu_or_gpu(dirname, distributed_context_single_node_gloo):
-
     device = idist.device()
     local_rank = distributed_context_single_node_gloo["local_rank"]
     _test_setup_common_training_handlers(dirname, device, rank=local_rank, local_rank=local_rank, distributed=True)
@@ -610,7 +671,6 @@ def test_distrib_gloo_cpu_or_gpu(dirname, distributed_context_single_node_gloo):
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
 def test_multinode_distrib_gloo_cpu_or_gpu(dirname, distributed_context_multi_node_gloo):
-
     device = idist.device()
     rank = distributed_context_multi_node_gloo["rank"]
     _test_setup_common_training_handlers(dirname, device, rank=rank)
@@ -621,7 +681,6 @@ def test_multinode_distrib_gloo_cpu_or_gpu(dirname, distributed_context_multi_node_gloo):
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
 def test_multinode_distrib_nccl_gpu(dirname, distributed_context_multi_node_nccl):
-
     local_rank = distributed_context_multi_node_nccl["local_rank"]
     rank = distributed_context_multi_node_nccl["rank"]
     device = idist.device()
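
For completeness, a similar hedged sketch for the dictionary-of-models variant exercised in the tests, `gen_save_best_models_by_val_score`, wired to a real `DiskSaver` instead of a mock (again with dummy engines, an `nn.Linear` stand-in, and a placeholder output path):

```python
# Sketch only: dummy engines, an nn.Linear stand-in, and placeholder path "/tmp/best_models".
import torch.nn as nn

from ignite.contrib.engines.common import gen_save_best_models_by_val_score
from ignite.engine import Engine, Events
from ignite.handlers import DiskSaver

loss_scores = [0.9, 0.8, 0.7, 0.6, 0.7, 0.5, 0.4, 0.39, 0.3, 0.5]

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)
model = nn.Linear(2, 2)  # stand-in for a real model

trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: evaluator.run([0, 1]))
evaluator.add_event_handler(
    Events.EPOCH_COMPLETED,
    lambda engine: engine.state.metrics.update(loss=loss_scores[trainer.state.epoch - 1]),
)

# Persist the two best model dicts by smallest "loss"; files are named with the signed
# score, e.g. "best_checkpoint_9_val_loss=-0.3000.pt".
save_handler = DiskSaver(dirname="/tmp/best_models", require_empty=False)
gen_save_best_models_by_val_score(
    save_handler,
    evaluator,
    {"model": model},
    metric_name="loss",
    n_saved=2,
    trainer=trainer,
    score_sign=-1.0,
)

trainer.run([0, 1], max_epochs=len(loss_scores))
```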
