@@ -159,7 +159,7 @@ def test_running_test_pretrained_model_distrib_dp(tmpdir):
159159 max_epochs = 2 ,
160160 limit_train_batches = 0.4 ,
161161 limit_val_batches = 0.2 ,
162- checkpoint_callback = checkpoint ,
162+ callbacks = [ checkpoint ] ,
163163 logger = logger ,
164164 gpus = [0 , 1 ],
165165 accelerator = 'dp' ,
@@ -209,7 +209,7 @@ def test_running_test_pretrained_model_distrib_ddp_spawn(tmpdir):
209209 max_epochs = 2 ,
210210 limit_train_batches = 0.4 ,
211211 limit_val_batches = 0.2 ,
212- checkpoint_callback = checkpoint ,
212+ callbacks = [ checkpoint ] ,
213213 logger = logger ,
214214 gpus = [0 , 1 ],
215215 accelerator = 'ddp_spawn' ,
@@ -257,7 +257,7 @@ def test_running_test_pretrained_model_cpu(tmpdir):
257257 max_epochs = 3 ,
258258 limit_train_batches = 0.4 ,
259259 limit_val_batches = 0.2 ,
260- checkpoint_callback = checkpoint ,
260+ callbacks = [ checkpoint ] ,
261261 logger = logger ,
262262 default_root_dir = tmpdir ,
263263 )
@@ -288,7 +288,7 @@ def test_load_model_from_checkpoint(tmpdir, model_template):
288288 max_epochs = 2 ,
289289 limit_train_batches = 0.4 ,
290290 limit_val_batches = 0.2 ,
291- checkpoint_callback = ModelCheckpoint (dirpath = tmpdir , monitor = 'early_stop_on' , save_top_k = - 1 ),
291+ callbacks = [ ModelCheckpoint (dirpath = tmpdir , monitor = 'early_stop_on' , save_top_k = - 1 )] ,
292292 default_root_dir = tmpdir ,
293293 )
294294
@@ -404,8 +404,10 @@ def test_model_saving_loading(tmpdir):
404404
405405 # fit model
406406 trainer = Trainer (
407- max_epochs = 1 , logger = logger ,
408- checkpoint_callback = ModelCheckpoint (dirpath = tmpdir ), default_root_dir = tmpdir ,
407+ max_epochs = 1 ,
408+ logger = logger ,
409+ callbacks = [ModelCheckpoint (dirpath = tmpdir )],
410+ default_root_dir = tmpdir ,
409411 )
410412 result = trainer .fit (model )
411413
@@ -460,7 +462,7 @@ def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server, url_c
460462 # fit model
461463 trainer = Trainer (
462464 default_root_dir = tmpdir , max_epochs = 1 , logger = logger ,
463- checkpoint_callback = ModelCheckpoint (dirpath = tmpdir ),
465+ callbacks = [ ModelCheckpoint (dirpath = tmpdir )] ,
464466 )
465467 result = trainer .fit (model )
466468
@@ -500,7 +502,7 @@ def test_strict_model_load_less_params(monkeypatch, tmpdir, tmpdir_server, url_c
500502 # fit model
501503 trainer = Trainer (
502504 default_root_dir = tmpdir , max_epochs = 1 , logger = logger ,
503- checkpoint_callback = ModelCheckpoint (dirpath = tmpdir ),
505+ callbacks = [ ModelCheckpoint (dirpath = tmpdir )] ,
504506 )
505507 result = trainer .fit (model )
506508
0 commit comments