66
77from datetime import datetime
88from pathlib import Path
9+ from typing import Dict
910
1011from src .misc import dist
1112from src .core import BaseConfig
@@ -28,6 +29,11 @@ def setup(self, ):
2829 self .criterion = cfg .criterion .to (device )
2930 self .postprocessor = cfg .postprocessor
3031
32+ # NOTE (lvwenyu): should load_tuning_state before ema instance building
33+ if self .cfg .tuning :
34+ print (f'Tuning checkpoint from { self .cfg .tuning } ' )
35+ self .load_tuning_state (self .cfg .tuning )
36+
3137 self .scaler = cfg .scaler
3238 self .ema = cfg .ema .to (device ) if cfg .ema is not None else None
3339
@@ -133,10 +139,44 @@ def resume(self, path):
133139 state = torch .load (path , map_location = 'cpu' )
134140 self .load_state_dict (state )
135141
142+ def load_tuning_state (self , path ,):
143+ """only load model for tuning and skip missed/dismatched keys
144+ """
145+ if 'http' in path :
146+ state = torch .hub .load_state_dict_from_url (path , map_location = 'cpu' )
147+ else :
148+ state = torch .load (path , map_location = 'cpu' )
149+
150+ module = dist .de_parallel (self .model )
151+
152+ # TODO hard code
153+ if 'ema' in state :
154+ stat , infos = self ._matched_state (module .state_dict (), state ['ema' ]['module' ])
155+ else :
156+ stat , infos = self ._matched_state (module .state_dict (), state ['model' ])
157+
158+ module .load_state_dict (stat , strict = False )
159+ print (f'Load model.state_dict, { infos } ' )
160+
161+ @staticmethod
162+ def _matched_state (state : Dict [str , torch .Tensor ], params : Dict [str , torch .Tensor ]):
163+ missed_list = []
164+ unmatched_list = []
165+ matched_state = {}
166+ for k , v in state .items ():
167+ if k in params :
168+ if v .shape == params [k ].shape :
169+ matched_state [k ] = params [k ]
170+ else :
171+ unmatched_list .append (k )
172+ else :
173+ missed_list .append (k )
174+
175+ return matched_state , {'missed' : missed_list , 'unmatched' : unmatched_list }
176+
136177
137178 def fit (self , ):
138179 raise NotImplementedError ('' )
139180
140-
141181 def val (self , ):
142182 raise NotImplementedError ('' )
0 commit comments