@@ -48,7 +48,6 @@ def _test_setup_common_training_handlers(
     save_handler=None,
     output_transform=lambda loss: loss,
 ):
-
     lr = 0.01
     step_size = 100
     gamma = 0.5
@@ -218,7 +217,6 @@ def test_setup_common_training_handlers(dirname, capsys):
 
 
 def test_setup_common_training_handlers_using_save_handler(dirname, capsys):
-
     save_handler = DiskSaver(dirname=dirname, require_empty=False)
     _test_setup_common_training_handlers(dirname=None, device="cpu", save_handler=save_handler)
 
@@ -231,43 +229,68 @@ def test_setup_common_training_handlers_using_save_handler(dirname, capsys):
 
 
 def test_save_best_model_by_val_score(dirname):
+    acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]
 
-    trainer = Engine(lambda e, b: None)
-    evaluator = Engine(lambda e, b: None)
-    model = DummyModel()
+    def setup_trainer():
+        trainer = Engine(lambda e, b: None)
+        evaluator = Engine(lambda e, b: None)
+        model = DummyModel()
 
-    acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]
+        @trainer.on(Events.EPOCH_COMPLETED)
+        def validate(engine):
+            evaluator.run([0, 1])
 
-    @trainer.on(Events.EPOCH_COMPLETED)
-    def validate(engine):
-        evaluator.run([0, 1])
+        @evaluator.on(Events.EPOCH_COMPLETED)
+        def set_eval_metric(engine):
+            acc = acc_scores[trainer.state.epoch - 1]
+            engine.state.metrics = {"acc": acc, "loss": 1 - acc}
+
+        return trainer, evaluator, model
 
-    @evaluator.on(Events.EPOCH_COMPLETED)
-    def set_eval_metric(engine):
-        engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}
+    trainer, evaluator, model = setup_trainer()
 
     save_best_model_by_val_score(dirname, evaluator, model, metric_name="acc", n_saved=2, trainer=trainer)
 
     trainer.run([0, 1], max_epochs=len(acc_scores))
 
     assert set(os.listdir(dirname)) == {"best_model_8_val_acc=0.6100.pt", "best_model_9_val_acc=0.7000.pt"}
 
+    for fname in os.listdir(dirname):
+        os.unlink(f"{dirname}/{fname}")
 
-def test_gen_save_best_models_by_val_score():
+    trainer, evaluator, model = setup_trainer()
+
+    save_best_model_by_val_score(
+        dirname, evaluator, model, metric_name="loss", n_saved=2, trainer=trainer, score_sign=-1.0
+    )
+
+    trainer.run([0, 1], max_epochs=len(acc_scores))
+
+    assert set(os.listdir(dirname)) == {"best_model_8_val_loss=-0.3900.pt", "best_model_9_val_loss=-0.3000.pt"}
 
-    trainer = Engine(lambda e, b: None)
-    evaluator = Engine(lambda e, b: None)
-    model = DummyModel()
 
+def test_gen_save_best_models_by_val_score():
     acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.5, 0.6, 0.61, 0.7, 0.5]
+    loss_scores = [0.9, 0.8, 0.7, 0.6, 0.7, 0.5, 0.4, 0.39, 0.3, 0.5]
+
+    def setup_trainer():
+        trainer = Engine(lambda e, b: None)
+        evaluator = Engine(lambda e, b: None)
+        model = DummyModel()
+
+        @trainer.on(Events.EPOCH_COMPLETED)
+        def validate(engine):
+            evaluator.run([0, 1])
 
-    @trainer.on(Events.EPOCH_COMPLETED)
-    def validate(engine):
-        evaluator.run([0, 1])
+        @evaluator.on(Events.EPOCH_COMPLETED)
+        def set_eval_metric(engine):
+            acc = acc_scores[trainer.state.epoch - 1]
+            loss = loss_scores[trainer.state.epoch - 1]
+            engine.state.metrics = {"acc": acc, "loss": loss}
 
-    @evaluator.on(Events.EPOCH_COMPLETED)
-    def set_eval_metric(engine):
-        engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}
+        return trainer, evaluator, model
+
+    trainer, evaluator, model = setup_trainer()
 
     save_handler = MagicMock()
 
@@ -291,36 +314,80 @@ def set_eval_metric(engine):
         any_order=True,
     )
 
+    trainer, evaluator, model = setup_trainer()
 
-def test_add_early_stopping_by_val_score():
-    trainer = Engine(lambda e, b: None)
-    evaluator = Engine(lambda e, b: None)
+    save_handler = MagicMock()
+
+    gen_save_best_models_by_val_score(
+        save_handler,
+        evaluator,
+        {"a": model, "b": model},
+        metric_name="loss",
+        n_saved=2,
+        trainer=trainer,
+        score_sign=-1.0,
+    )
+
+    trainer.run([0, 1], max_epochs=len(acc_scores))
 
+    assert save_handler.call_count == len(acc_scores) - 2  # 2 score values (-0.7 and -0.5) are not the best
+    obj_to_save = {"a": model.state_dict(), "b": model.state_dict()}
+    save_handler.assert_has_calls(
+        [
+            call(
+                obj_to_save,
+                f"best_checkpoint_{e}_val_loss={p:.4f}.pt",
+                dict([("basename", "best_checkpoint"), ("score_name", "val_loss"), ("priority", p)]),
+            )
+            for e, p in zip([1, 2, 3, 4, 6, 7, 8, 9], [-0.9, -0.8, -0.7, -0.6, -0.5, -0.4, -0.39, -0.3])
+        ],
+        any_order=True,
+    )
+
+
+def test_add_early_stopping_by_val_score():
     acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.3, 0.2, 0.1, 0.1, 0.0]
 
-    @trainer.on(Events.EPOCH_COMPLETED)
-    def validate(engine):
-        evaluator.run([0, 1])
+    def setup_trainer():
+        trainer = Engine(lambda e, b: None)
+        evaluator = Engine(lambda e, b: None)
+
+        @trainer.on(Events.EPOCH_COMPLETED)
+        def validate(engine):
+            evaluator.run([0, 1])
 
-    @evaluator.on(Events.EPOCH_COMPLETED)
-    def set_eval_metric(engine):
-        engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}
+        @evaluator.on(Events.EPOCH_COMPLETED)
+        def set_eval_metric(engine):
+            acc = acc_scores[trainer.state.epoch - 1]
+            engine.state.metrics = {"acc": acc, "loss": 1 - acc}
+
+        return trainer, evaluator
+
+    trainer, evaluator = setup_trainer()
 
     add_early_stopping_by_val_score(patience=3, evaluator=evaluator, trainer=trainer, metric_name="acc")
 
     state = trainer.run([0, 1], max_epochs=len(acc_scores))
 
     assert state.epoch == 7
 
+    trainer, evaluator = setup_trainer()
 
-def test_deprecated_setup_any_logging():
+    add_early_stopping_by_val_score(
+        patience=3, evaluator=evaluator, trainer=trainer, metric_name="loss", score_sign=-1.0
+    )
+
+    state = trainer.run([0, 1], max_epochs=len(acc_scores))
+
+    assert state.epoch == 7
 
+
+def test_deprecated_setup_any_logging():
     with pytest.raises(DeprecationWarning, match=r"deprecated since version 0.4.0"):
         setup_any_logging(None, None, None, None, None, None)
 
 
 def test__setup_logging_wrong_args():
-
     with pytest.raises(TypeError, match=r"Argument optimizers should be either a single optimizer or"):
         _setup_logging(MagicMock(), MagicMock(), "abc", MagicMock(), 1)
 
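For reference, the `score_sign` argument exercised in the hunks above lets these common-handler helpers track metrics that should be minimized (such as validation loss) by negating the score before comparison. A minimal usage sketch follows; it is illustrative only, assumes `trainer`, `evaluator`, and `model` are wired up as in ignite's usual workflow (the evaluator exposing `engine.state.metrics["loss"]`), and the output path is hypothetical:

# Illustrative sketch, not part of the diff: keep the two best checkpoints and stop
# early based on a metric that should be minimized, by passing score_sign=-1.0
# (the default score_sign favors larger metric values, e.g. accuracy).
from ignite.contrib.engines.common import (
    add_early_stopping_by_val_score,
    save_best_model_by_val_score,
)

save_best_model_by_val_score(
    "/tmp/checkpoints",  # hypothetical output directory
    evaluator,
    model,
    metric_name="loss",
    n_saved=2,
    trainer=trainer,
    score_sign=-1.0,
)
add_early_stopping_by_val_score(
    patience=3, evaluator=evaluator, trainer=trainer, metric_name="loss", score_sign=-1.0
)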
@@ -406,7 +473,6 @@ def set_eval_metric(engine):
 
 
 def test_setup_tb_logging(dirname):
-
     tb_logger = _test_setup_logging(
         setup_logging_fn=setup_tb_logging,
         kwargs_dict={"output_path": dirname / "t1"},
@@ -462,7 +528,6 @@ def test_setup_visdom_logging(visdom_offline_logfile):
 
 
 def test_setup_plx_logging():
-
     os.environ["POLYAXON_NO_OP"] = "1"
 
     _test_setup_logging(
@@ -506,15 +571,13 @@ def test_setup_mlflow_logging(dirname):
 
 
 def test_setup_wandb_logging(dirname):
-
     from unittest.mock import patch
 
     with patch("ignite.contrib.engines.common.WandBLogger") as _:
         setup_wandb_logging(MagicMock())
 
 
 def test_setup_clearml_logging():
-
     handlers.clearml_logger.ClearMLLogger.set_bypass_mode(True)
 
     with pytest.warns(UserWarning, match=r"running in bypass mode"):
@@ -583,7 +646,6 @@ def test_setup_neptune_logging(dirname):
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
 def test_distrib_nccl_gpu(dirname, distributed_context_single_node_nccl):
-
     local_rank = distributed_context_single_node_nccl["local_rank"]
     device = idist.device()
     _test_setup_common_training_handlers(dirname, device, rank=local_rank, local_rank=local_rank, distributed=True)
@@ -593,7 +655,6 @@ def test_distrib_nccl_gpu(dirname, distributed_context_single_node_nccl):
 @pytest.mark.distributed
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 def test_distrib_gloo_cpu_or_gpu(dirname, distributed_context_single_node_gloo):
-
     device = idist.device()
     local_rank = distributed_context_single_node_gloo["local_rank"]
     _test_setup_common_training_handlers(dirname, device, rank=local_rank, local_rank=local_rank, distributed=True)
@@ -610,7 +671,6 @@ def test_distrib_gloo_cpu_or_gpu(dirname, distributed_context_single_node_gloo):
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif("MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
 def test_multinode_distrib_gloo_cpu_or_gpu(dirname, distributed_context_multi_node_gloo):
-
     device = idist.device()
     rank = distributed_context_multi_node_gloo["rank"]
     _test_setup_common_training_handlers(dirname, device, rank=rank)
@@ -621,7 +681,6 @@ def test_multinode_distrib_gloo_cpu_or_gpu(dirname, distributed_context_multi_no
 @pytest.mark.skipif(not idist.has_native_dist_support, reason="Skip if no native dist support")
 @pytest.mark.skipif("GPU_MULTINODE_DISTRIB" not in os.environ, reason="Skip if not multi-node distributed")
 def test_multinode_distrib_nccl_gpu(dirname, distributed_context_multi_node_nccl):
-
     local_rank = distributed_context_multi_node_nccl["local_rank"]
     rank = distributed_context_multi_node_nccl["rank"]
     device = idist.device()