
import pandas as pd
import pytest
+import pytorch_lightning as pl
import torch
from pytest import FixtureRequest
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
+from torch import nn

from rectools import Columns
from rectools.dataset import Dataset
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask


+def assert_torch_models_equal(model_a: nn.Module, model_b: nn.Module) -> None:
+    """Assert torch modules are equal in terms of type and parameters."""
+    assert type(model_a) is type(model_b), "different types"
+
+    with torch.no_grad():
+        for (apn, apv), (bpn, bpv) in zip(model_a.named_parameters(), model_b.named_parameters()):
+            assert apn == bpn, "different parameter name"
+            assert torch.isclose(apv, bpv).all(), "different parameter value"
+
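+# Illustrative usage sketch for the helper above, kept as a comment.
+# The nn.Linear layer and its shapes are hypothetical, not used by the tests:
+# identically seeded modules have identical parameters, so the assertion passes.
+#
+#     seed_everything(32, workers=True)
+#     module_a = nn.Linear(4, 2)
+#     seed_everything(32, workers=True)
+#     module_b = nn.Linear(4, 2)
+#     assert_torch_models_equal(module_a, module_b)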
+
+def assert_pl_models_equal(model_a: pl.LightningModule, model_b: pl.LightningModule) -> None:
+    """Assert pl modules are equal in terms of weights and trainer."""
+    assert_torch_models_equal(model_a, model_b)
+
+    trainer_a = model_a.trainer
+    trainer_b = model_b.trainer
+
+    assert_pl_trainers_equal(trainer_a, trainer_b)
+
+
+def assert_pl_trainers_equal(trainer_a: Trainer, trainer_b: Trainer) -> None:
+    """Assert pl trainers are equal in terms of optimizer state."""
+    assert len(trainer_a.optimizers) == len(trainer_b.optimizers), "Different number of optimizers"
+
+    for opt_a, opt_b in zip(trainer_a.optimizers, trainer_b.optimizers):
+        # Check optimizer class and full optimizer state
+        assert type(opt_a) is type(opt_b), f"Optimizer types differ: {type(opt_a)} vs {type(opt_b)}"
+        assert opt_a.state_dict() == opt_b.state_dict(), "optimizers state dict differs"
+
+
class TestTransformerModelBase:
    def setup_method(self) -> None:
        torch.use_deterministic_algorithms(True)
@@ -209,28 +240,6 @@ def test_load_from_checkpoint(

        self._assert_same_reco(model, recovered_model, dataset)

-    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
-    def test_raises_when_save_model_loaded_from_checkpoint(
-        self,
-        model_cls: tp.Type[TransformerModelBase],
-        dataset: Dataset,
-    ) -> None:
-        model = model_cls.from_config(
-            {
-                "deterministic": True,
-                "get_trainer_func": custom_trainer_ckpt,
-            }
-        )
-        model.fit(dataset)
-        assert model.fit_trainer is not None
-        if model.fit_trainer.log_dir is None:
-            raise ValueError("No log dir")
-        ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
-        recovered_model = model_cls.load_from_checkpoint(ckpt_path)
-        with pytest.raises(RuntimeError):
-            with NamedTemporaryFile() as f:
-                recovered_model.save(f.name)
-
    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
    def test_load_weights_from_checkpoint(
        self,
@@ -391,8 +400,6 @@ def test_fit_partial_from_checkpoint(
        recovered_fit_partial_model = model_cls.load_from_checkpoint(ckpt_path)

        seed_everything(32, workers=True)
-        fit_partial_model.fit_trainer = deepcopy(fit_partial_model._trainer)  # pylint: disable=protected-access
-        fit_partial_model.lightning_model.optimizer = None
        fit_partial_model.fit_partial(dataset, min_epochs=1, max_epochs=1)

        seed_everything(32, workers=True)
@@ -410,3 +417,108 @@ def test_raises_when_incorrect_similarity_dist(
        with pytest.raises(ValueError):
            model = model_cls.from_config(model_config)
            model.fit(dataset=dataset)
+
+    @pytest.mark.parametrize("fit", (True, False))
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    @pytest.mark.parametrize("default_trainer", (True, False))
+    def test_resaving(
+        self,
+        model_cls: tp.Type[TransformerModelBase],
+        dataset: Dataset,
+        default_trainer: bool,
+        fit: bool,
+    ) -> None:
+        config: tp.Dict[str, tp.Any] = {"deterministic": True}
+        if not default_trainer:
+            config["get_trainer_func"] = custom_trainer
+        model = model_cls.from_config(config)
+
+        seed_everything(32, workers=True)
+        if fit:
+            model.fit(dataset)
+
+        with NamedTemporaryFile() as f:
+            model.save(f.name)
+            recovered_model = model_cls.load(f.name)
+
+        with NamedTemporaryFile() as f:
+            recovered_model.save(f.name)
+            second_recovered_model = model_cls.load(f.name)
+
+        assert isinstance(recovered_model, model_cls)
+
+        original_model_config = model.get_config()
+        second_recovered_model_config = second_recovered_model.get_config()
+        assert second_recovered_model_config == original_model_config
+
+        if fit:
+            assert_pl_models_equal(model.lightning_model, second_recovered_model.lightning_model)
+
+    # Check that the trainer keeps its state across multiple fit_partial calls
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    def test_fit_partial_multiple_times(
+        self,
+        dataset: Dataset,
+        model_cls: tp.Type[TransformerModelBase],
+    ) -> None:
+        class FixSeedLightningModule(TransformerLightningModule):
+            def on_train_epoch_start(self) -> None:
+                seed_everything(32, workers=True)
+
+        seed_everything(32, workers=True)
+        model = model_cls.from_config(
+            {
+                "epochs": 3,
+                "data_preparator_kwargs": {"shuffle_train": False},
+                "get_trainer_func": custom_trainer,
+                "lightning_module_type": FixSeedLightningModule,
+            }
+        )
+        model.fit_partial(dataset, min_epochs=1, max_epochs=1)
+        t1 = deepcopy(model.fit_trainer)
+        model.fit_partial(
+            Dataset.construct(pd.DataFrame(columns=Columns.Interactions)),
+            min_epochs=1,
+            max_epochs=1,
+        )
+        t2 = deepcopy(model.fit_trainer)
+
+        # Since the second fit is on an empty dataset, the trainer state should
+        # stay exactly the same as after the first fit, which proves that
+        # fit_partial does not change trainer state before proceeding to training.
+        assert t1 is not None
+        assert t2 is not None
+        assert_pl_trainers_equal(t1, t2)
+
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    def test_raises_when_fit_trainer_is_none_on_save_trained_model(
+        self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset
+    ) -> None:
+        config: tp.Dict[str, tp.Any] = {"deterministic": True}
+        model = model_cls.from_config(config)
+
+        seed_everything(32, workers=True)
+        model.fit(dataset)
+        model.fit_trainer = None
+
+        with NamedTemporaryFile() as f:
+            with pytest.raises(RuntimeError):
+                model.save(f.name)
+
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    def test_raises_when_fit_trainer_is_none_on_fit_partial_trained_model(
+        self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset
+    ) -> None:
+        config: tp.Dict[str, tp.Any] = {"deterministic": True}
+        model = model_cls.from_config(config)
+
+        seed_everything(32, workers=True)
+        model.fit(dataset)
+        model.fit_trainer = None
+
+        with pytest.raises(RuntimeError):
+            model.fit_partial(
+                dataset,
+                min_epochs=1,
+                max_epochs=1,
+            )