import copy
import logging
import re
+ from typing import Dict, List
import torch
- from fvcore.common.checkpoint import (
-     get_missing_parameters_message,
-     get_unexpected_parameters_message,
- )
+ from tabulate import tabulate


def convert_basic_c2_names(original_keys):
@@ -77,7 +75,7 @@ def convert_c2_detectron_names(weights):
        dict: detectron2 names -> C2 names
    """
    logger = logging.getLogger(__name__)
-     logger.info("Remapping C2 weights ......")
+     logger.info("Renaming Caffe2 weights ......")
    original_keys = sorted(weights.keys())
    layer_keys = copy.deepcopy(original_keys)

@@ -210,8 +208,9 @@ def fpn_map(name):
# it assumes model_state_dict will have longer names.
def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
    """
-     Match names between the two state-dict, and update the values of model_state_dict in-place with
-     copies of the matched tensor in ckpt_state_dict.
+     Match names between the two state dicts, and return a new ckpt_state_dict whose names
+     are converted to match model_state_dict using heuristics. The returned dict can later
+     be loaded by the fvcore checkpointer.
    If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
    model and will be renamed at first.

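Editorial note, not part of this diff: a minimal sketch of how the new return value could be consumed, assuming a plain `model: torch.nn.Module` and an already-loaded `ckpt_state_dict`; in practice the fvcore checkpointer performs the actual load, as the docstring above says.

    # hypothetical caller; `model` and `ckpt_state_dict` are assumed to exist
    renamed_ckpt = align_and_update_state_dicts(model.state_dict(), ckpt_state_dict, c2_conversion=True)
    incompatible = model.load_state_dict(renamed_ckpt, strict=False)
    # incompatible.missing_keys / incompatible.unexpected_keys list what did not match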
@@ -251,13 +250,10 @@ def match(a, b):
    # remove indices that correspond to no-match
    idxs[max_match_size == 0] = -1

-     # used for logging
-     max_len_model = max(len(key) for key in model_keys) if model_keys else 1
-     max_len_ckpt = max(len(key) for key in ckpt_keys) if ckpt_keys else 1
-     log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
    logger = logging.getLogger(__name__)
    # matched_pairs (matched checkpoint key --> matched model key)
    matched_keys = {}
+     result_state_dict = {}
    for idx_model, idx_ckpt in enumerate(idxs.tolist()):
        if idx_ckpt == -1:
            continue
@@ -279,7 +275,8 @@ def match(a, b):
            )
            continue

-         model_state_dict[key_model] = value_ckpt.clone()
+         assert key_model not in result_state_dict
+         result_state_dict[key_model] = value_ckpt
        if key_ckpt in matched_keys:  # already added to matched_keys
            logger.error(
                "Ambiguity found for {} in checkpoint!"
@@ -290,24 +287,118 @@ def match(a, b):
            raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")

        matched_keys[key_ckpt] = key_model
-         logger.info(
-             log_str_template.format(
-                 key_model,
-                 max_len_model,
-                 original_keys[key_ckpt],
-                 max_len_ckpt,
-                 tuple(shape_in_model),
+
+     # logging:
+     matched_model_keys = sorted(matched_keys.values())
+     common_prefix = _longest_common_prefix(matched_model_keys)
+     rev_matched_keys = {v: k for k, v in matched_keys.items()}
+     original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}
+
+     model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
+     table = []
+     memo = set()
+     for key_model in matched_model_keys:
+         if key_model in memo:
+             continue
+         if key_model in model_key_groups:
+             group = model_key_groups[key_model]
+             memo |= set(group)
+             shapes = [tuple(model_state_dict[k].shape) for k in group]
+             table.append(
+                 (
+                     _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
+                     _group_str([original_keys[k] for k in group]),
+                     " ".join([str(x).replace(" ", "") for x in shapes]),
+                 )
            )
-         )
-     matched_model_keys = matched_keys.values()
-     matched_ckpt_keys = matched_keys.keys()
-     # print warnings about unmatched keys on both side
-     unmatched_model_keys = [k for k in model_keys if k not in matched_model_keys]
-     if len(unmatched_model_keys):
-         logger.info(get_missing_parameters_message(unmatched_model_keys))
-
-     unmatched_ckpt_keys = [k for k in ckpt_keys if k not in matched_ckpt_keys]
-     if len(unmatched_ckpt_keys):
-         logger.info(
-             get_unexpected_parameters_message(original_keys[x] for x in unmatched_ckpt_keys)
-         )
+         else:
+             key_checkpoint = original_keys[key_model]
+             shape = str(tuple(model_state_dict[key_model].shape))
+             table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
+     table_str = tabulate(
+         table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"]
+     )
+     logger.info(
+         "Following weights matched with "
+         + (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
+         + ":\n"
+         + table_str
+     )
+
+     unmatched_ckpt_keys = [k for k in ckpt_keys if k not in set(matched_keys.keys())]
+     for k in unmatched_ckpt_keys:
+         result_state_dict[k] = ckpt_state_dict[k]
+     return result_state_dict
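# Editorial aside, not part of this diff: with illustrative (made-up) entries, the
# tabulate(..., tablefmt="pipe") call above produces a log message roughly like:
#
#   Following weights matched with model:
#   | Names in Model | Names in Checkpoint | Shapes           |
#   |:---------------|:--------------------|:-----------------|
#   | stem.conv1.*   | conv1_{w,b}         | (64,3,7,7) (64,) |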
+
+
+ def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
+     """
+     Params in the same submodule are grouped together.
+
+     Args:
+         keys: names of all parameters
+         original_names: mapping from a parameter name to its name in the checkpoint
+
+     Returns:
+         dict[name -> all names in the same group (including the name itself)]
+     """
+
+     def _submodule_name(key):
+         pos = key.rfind(".")
+         if pos < 0:
+             return None
+         prefix = key[: pos + 1]
+         return prefix
+
+     all_submodules = [_submodule_name(k) for k in keys]
+     all_submodules = [x for x in all_submodules if x]
+     all_submodules = sorted(all_submodules, key=len)
+
+     ret = {}
+     for prefix in all_submodules:
+         group = [k for k in keys if k.startswith(prefix)]
+         if len(group) <= 1:
+             continue
+         original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
+         if len(original_name_lcp) == 0:
+             # don't group weights if their original names don't share a prefix
+             continue
+
+         for k in group:
+             if k in ret:
+                 continue
+             ret[k] = group
+     return ret
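# Editorial example, not part of this diff, of the grouping behaviour with hypothetical names:
#
#   _group_keys_by_module(
#       ["linear.weight", "linear.bias", "scale"],
#       {"linear.weight": "fc_w", "linear.bias": "fc_b", "scale": "s"},
#   )
#   -> {"linear.weight": ["linear.weight", "linear.bias"],
#       "linear.bias": ["linear.weight", "linear.bias"]}
#
# "scale" has no "."-separated submodule prefix, while the two "linear.*" params share
# the checkpoint-name prefix "fc_", so only they form a group.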
+
+
+ def _longest_common_prefix(names: List[str]) -> str:
+     """
+     ["abc.zfg", "abc.zef"] -> "abc."
+     """
+     names = [n.split(".") for n in names]
+     m1, m2 = min(names), max(names)
+     ret = [a for a, b in zip(m1, m2) if a == b]
+     ret = ".".join(ret) + "." if len(ret) else ""
+     return ret
+
+
+ def _longest_common_prefix_str(names: List[str]) -> str:
+     m1, m2 = min(names), max(names)
+     lcp = [a for a, b in zip(m1, m2) if a == b]
+     lcp = "".join(lcp)
+     return lcp
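# Editorial note, not part of this diff: the two helpers differ in granularity.
# _longest_common_prefix matches whole "."-separated segments, while
# _longest_common_prefix_str matches character by character:
#
#   _longest_common_prefix(["abc.zfg", "abc.zef"])      -> "abc."
#   _longest_common_prefix_str(["abc.zfg", "abc.zef"])  -> "abc.z"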
+
+
+ def _group_str(names: List[str]) -> str:
+     """
+     Turn "common1", "common2", "common3" into "common{1,2,3}"
+     """
+     lcp = _longest_common_prefix_str(names)
+     rest = [x[len(lcp) :] for x in names]
+     rest = "{" + ",".join(rest) + "}"
+     ret = lcp + rest
+
+     # add some simplification for BN specifically
+     ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*")
+     ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*")
+     return ret
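# Editorial example, not part of this diff, with hypothetical Caffe2-style names:
#
#   _group_str(["conv1_w", "conv1_b"])
#   -> "conv1_{w,b}"
#
#   _group_str(["res2_0_branch2a_bn_beta", "res2_0_branch2a_bn_running_mean",
#               "res2_0_branch2a_bn_running_var", "res2_0_branch2a_bn_gamma"])
#   -> "res2_0_branch2a_bn_*"   (after the BN simplification above)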