2424 Phase .PREDICT : "predict_dataloader" ,
2525}
2626_TRAIN_DL_STATE_KEY = "train_dataloader"
27+
2728_TRAIN_PROGRESS_STATE_KEY = "train_progress"
2829_EVAL_PROGRESS_STATE_KEY = "eval_progress"
30+ _PREDICT_PROGRESS_STATE_KEY = "predict_progress"
2931
3032
3133def _get_step_phase_mapping (
@@ -56,6 +58,30 @@ def _prepare_app_state(unit: AppStateMixin) -> Dict[str, Any]:
5658 return app_state
5759
5860
61+ def _remove_app_state_keys (
62+ unit : AppStateMixin ,
63+ app_state : Dict [str , Any ],
64+ * ,
65+ remove_modules : bool = False ,
66+ remove_optimizers : bool = False ,
67+ remove_lr_schedulers : bool = False ,
68+ ) -> None :
69+ if remove_modules :
70+ # remove all module keys from app_state
71+ for module_keys in unit .tracked_modules ().keys ():
72+ app_state .pop (module_keys , None )
73+
74+ if remove_optimizers :
75+ # remove all optimizer keys from app_state
76+ for optim_keys in unit .tracked_optimizers ().keys ():
77+ app_state .pop (optim_keys , None )
78+
79+ if remove_lr_schedulers :
80+ # remove all lr scheduler keys from app_state
81+ for lr_scheduler_keys in unit .tracked_lr_schedulers ().keys ():
82+ app_state .pop (lr_scheduler_keys , None )
83+
84+
5985def _prepare_app_state_for_checkpoint (
6086 state : State , unit : AppStateMixin , intra_epoch : bool
6187) -> Dict [str , Stateful ]:
@@ -64,6 +90,16 @@ def _prepare_app_state_for_checkpoint(
6490 """
6591 app_state = _prepare_app_state (unit )
6692
93+ if state .entry_point in [EntryPoint .EVALUATE , EntryPoint .PREDICT ]:
94+ # Since model parameters are fixed, remove them from checkpoint.
95+ _remove_app_state_keys (
96+ unit ,
97+ app_state ,
98+ remove_modules = True ,
99+ remove_optimizers = True ,
100+ remove_lr_schedulers = True ,
101+ )
102+
67103 # for intra-epoch checkpointing, include dataloader state of the current phase
68104 phase_dl = state .active_phase_state ().dataloader
69105 if intra_epoch and isinstance (phase_dl , Stateful ):
@@ -85,24 +121,21 @@ def _prepare_app_state_for_restore(
85121
86122 restore_options = restore_options or RestoreOptions ()
87123
88- if not restore_options .restore_modules :
89- for module_keys in unit .tracked_modules ().keys ():
90- app_state .pop (module_keys , None )
91-
92124 if not restore_options .restore_train_progress :
93125 app_state .pop (_TRAIN_PROGRESS_STATE_KEY , None )
94126
95127 if not restore_options .restore_eval_progress :
96128 app_state .pop (_EVAL_PROGRESS_STATE_KEY , None )
97129
98- if not restore_options .restore_optimizers :
99- # remove all optimizer keys from app_state
100- for optim_keys in unit .tracked_optimizers ().keys ():
101- app_state .pop (optim_keys , None )
130+ if not restore_options .restore_predict_progress :
131+ app_state .pop (_PREDICT_PROGRESS_STATE_KEY , None )
102132
103- if not restore_options .restore_lr_schedulers :
104- # remove all lr scheduler keys from app_state
105- for lr_scheduler_keys in unit .tracked_lr_schedulers ().keys ():
106- app_state .pop (lr_scheduler_keys , None )
133+ _remove_app_state_keys (
134+ unit ,
135+ app_state ,
136+ remove_modules = not restore_options .restore_modules ,
137+ remove_optimizers = not restore_options .restore_optimizers ,
138+ remove_lr_schedulers = not restore_options .restore_lr_schedulers ,
139+ )
107140
108141 return app_state
0 commit comments