
Commit c90ff5e

ppwwyyxx authored and facebook-github-bot committed
simplify logging during caffe2 model loading
Summary: Group similar weight names together. Produces ~5x fewer log lines.

Reviewed By: theschnitz

Differential Revision: D26421225

fbshipit-source-id: 0696af09fe18d0faa47d8c1d1e6f8d26081dee41
1 parent c470b67 commit c90ff5e
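
To see the shape of the grouped output, here is a rough sketch that feeds made-up weight names through the same tabulate call the patch adds. The row contents (names, shapes) are hypothetical; only the headers and tablefmt="pipe" come from the diff below.

# Sketch of the new grouped log format; the rows are made up, but the
# headers and tablefmt="pipe" mirror the tabulate call added in this commit.
from tabulate import tabulate

rows = [
    ("backbone.res2.conv1.*", "res2_conv1_bn_*", "(64,) (64,) (64,) (64,)"),
    ("backbone.stem.conv1.weight", "conv1_w", "(64,3,7,7)"),
]
print(tabulate(rows, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"]))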

File tree

3 files changed: +127 −39 lines


detectron2/checkpoint/c2_model_loading.py

Lines changed: 123 additions & 32 deletions
@@ -2,11 +2,9 @@
 import copy
 import logging
 import re
+from typing import Dict, List
 import torch
-from fvcore.common.checkpoint import (
-    get_missing_parameters_message,
-    get_unexpected_parameters_message,
-)
+from tabulate import tabulate
 
 
 def convert_basic_c2_names(original_keys):
@@ -77,7 +75,7 @@ def convert_c2_detectron_names(weights):
         dict: detectron2 names -> C2 names
     """
     logger = logging.getLogger(__name__)
-    logger.info("Remapping C2 weights ......")
+    logger.info("Renaming Caffe2 weights ......")
     original_keys = sorted(weights.keys())
     layer_keys = copy.deepcopy(original_keys)
 
@@ -210,8 +208,9 @@ def fpn_map(name):
 # it assumes model_state_dict will have longer names.
 def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True):
     """
-    Match names between the two state-dict, and update the values of model_state_dict in-place with
-    copies of the matched tensor in ckpt_state_dict.
+    Match names between the two state-dicts, and return a new ckpt_state_dict with names
+    converted to match model_state_dict with heuristics. The returned dict can later be
+    loaded with the fvcore checkpointer.
     If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2
     model and will be renamed at first.
 
@@ -251,13 +250,10 @@ def match(a, b):
     # remove indices that correspond to no-match
     idxs[max_match_size == 0] = -1
 
-    # used for logging
-    max_len_model = max(len(key) for key in model_keys) if model_keys else 1
-    max_len_ckpt = max(len(key) for key in ckpt_keys) if ckpt_keys else 1
-    log_str_template = "{: <{}} loaded from {: <{}} of shape {}"
     logger = logging.getLogger(__name__)
     # matched_pairs (matched checkpoint key --> matched model key)
     matched_keys = {}
+    result_state_dict = {}
     for idx_model, idx_ckpt in enumerate(idxs.tolist()):
         if idx_ckpt == -1:
             continue
@@ -279,7 +275,8 @@ def match(a, b):
             )
             continue
 
-        model_state_dict[key_model] = value_ckpt.clone()
+        assert key_model not in result_state_dict
+        result_state_dict[key_model] = value_ckpt
         if key_ckpt in matched_keys:  # already added to matched_keys
             logger.error(
                 "Ambiguity found for {} in checkpoint!"
@@ -290,24 +287,118 @@ def match(a, b):
             raise ValueError("Cannot match one checkpoint key to multiple keys in the model.")
 
         matched_keys[key_ckpt] = key_model
-        logger.info(
-            log_str_template.format(
-                key_model,
-                max_len_model,
-                original_keys[key_ckpt],
-                max_len_ckpt,
-                tuple(shape_in_model),
+
+    # logging:
+    matched_model_keys = sorted(matched_keys.values())
+    common_prefix = _longest_common_prefix(matched_model_keys)
+    rev_matched_keys = {v: k for k, v in matched_keys.items()}
+    original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys}
+
+    model_key_groups = _group_keys_by_module(matched_model_keys, original_keys)
+    table = []
+    memo = set()
+    for key_model in matched_model_keys:
+        if key_model in memo:
+            continue
+        if key_model in model_key_groups:
+            group = model_key_groups[key_model]
+            memo |= set(group)
+            shapes = [tuple(model_state_dict[k].shape) for k in group]
+            table.append(
+                (
+                    _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*",
+                    _group_str([original_keys[k] for k in group]),
+                    " ".join([str(x).replace(" ", "") for x in shapes]),
+                )
             )
-        )
-    matched_model_keys = matched_keys.values()
-    matched_ckpt_keys = matched_keys.keys()
-    # print warnings about unmatched keys on both side
-    unmatched_model_keys = [k for k in model_keys if k not in matched_model_keys]
-    if len(unmatched_model_keys):
-        logger.info(get_missing_parameters_message(unmatched_model_keys))
-
-    unmatched_ckpt_keys = [k for k in ckpt_keys if k not in matched_ckpt_keys]
-    if len(unmatched_ckpt_keys):
-        logger.info(
-            get_unexpected_parameters_message(original_keys[x] for x in unmatched_ckpt_keys)
-        )
+        else:
+            key_checkpoint = original_keys[key_model]
+            shape = str(tuple(model_state_dict[key_model].shape))
+            table.append((key_model[len(common_prefix) :], key_checkpoint, shape))
+    table_str = tabulate(
+        table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"]
+    )
+    logger.info(
+        "Following weights matched with "
+        + (f"submodule {common_prefix[:-1]}" if common_prefix else "model")
+        + ":\n"
+        + table_str
+    )
+
+    unmatched_ckpt_keys = [k for k in ckpt_keys if k not in set(matched_keys.keys())]
+    for k in unmatched_ckpt_keys:
+        result_state_dict[k] = ckpt_state_dict[k]
+    return result_state_dict
+
+
+def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]):
+    """
+    Params in the same submodule are grouped together.
+
+    Args:
+        keys: names of all parameters
+        original_names: mapping from parameter name to their name in the checkpoint
+
+    Returns:
+        dict[name -> all other names in the same group]
+    """
+
+    def _submodule_name(key):
+        pos = key.rfind(".")
+        if pos < 0:
+            return None
+        prefix = key[: pos + 1]
+        return prefix
+
+    all_submodules = [_submodule_name(k) for k in keys]
+    all_submodules = [x for x in all_submodules if x]
+    all_submodules = sorted(all_submodules, key=len)
+
+    ret = {}
+    for prefix in all_submodules:
+        group = [k for k in keys if k.startswith(prefix)]
+        if len(group) <= 1:
+            continue
+        original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group])
+        if len(original_name_lcp) == 0:
+            # don't group weights if original names don't share prefix
+            continue
+
+        for k in group:
+            if k in ret:
+                continue
+            ret[k] = group
+    return ret
+
+
+def _longest_common_prefix(names: List[str]) -> str:
+    """
+    ["abc.zfg", "abc.zef"] -> "abc."
+    """
+    names = [n.split(".") for n in names]
+    m1, m2 = min(names), max(names)
+    ret = [a for a, b in zip(m1, m2) if a == b]
+    ret = ".".join(ret) + "." if len(ret) else ""
+    return ret
+
+
+def _longest_common_prefix_str(names: List[str]) -> str:
+    m1, m2 = min(names), max(names)
+    lcp = [a for a, b in zip(m1, m2) if a == b]
+    lcp = "".join(lcp)
+    return lcp
+
+
+def _group_str(names: List[str]) -> str:
+    """
+    Turn "common1", "common2", "common3" into "common{1,2,3}"
+    """
+    lcp = _longest_common_prefix_str(names)
+    rest = [x[len(lcp) :] for x in names]
+    rest = "{" + ",".join(rest) + "}"
+    ret = lcp + rest
+
+    # add some simplification for BN specifically
+    ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*")
+    ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*")
+    return ret
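
For intuition, a small usage sketch of the name-grouping helpers added above (assumes a detectron2 checkout containing this commit; the weight names are hypothetical):

# Demonstrates the grouping helpers; sample names are hypothetical.
from detectron2.checkpoint.c2_model_loading import _group_str, _longest_common_prefix

# Dotted model keys are reduced to their longest common module prefix:
print(_longest_common_prefix(["backbone.res2.conv1.weight", "backbone.res2.conv1.norm.weight"]))
# -> "backbone.res2.conv1."

# Checkpoint names collapse into a brace group after their common prefix:
print(_group_str(["conv1_w", "conv1_b"]))
# -> "conv1_{w,b}"

# The BatchNorm quadruple is additionally abbreviated to "bn_*":
print(_group_str(["res2_conv1_bn_beta", "res2_conv1_bn_running_mean",
                  "res2_conv1_bn_running_var", "res2_conv1_bn_gamma"]))
# -> "res2_conv1_bn_*"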

detectron2/checkpoint/detection_checkpoint.py

Lines changed: 2 additions & 6 deletions
@@ -60,17 +60,13 @@ def _load_model(self, checkpoint):
         if checkpoint.get("matching_heuristics", False):
             self._convert_ndarray_to_tensor(checkpoint["model"])
             # convert weights by name-matching heuristics
-            model_state_dict = self.model.state_dict()
-            align_and_update_state_dicts(
-                model_state_dict,
+            checkpoint["model"] = align_and_update_state_dicts(
+                self.model.state_dict(),
                 checkpoint["model"],
                 c2_conversion=checkpoint.get("__author__", None) == "Caffe2",
             )
-            checkpoint["model"] = model_state_dict
         # for non-caffe2 models, use standard ways to load it
         incompatible = super()._load_model(checkpoint)
-        if incompatible is None:  # support older versions of fvcore
-            return None
 
         model_buffers = dict(self.model.named_buffers(recurse=False))
         for k in ["pixel_mean", "pixel_std"]:

tests/test_checkpoint.py

Lines changed: 2 additions & 1 deletion
@@ -36,7 +36,8 @@ def test_complex_model_loaded(self):
         model = nn.DataParallel(model)
         model_sd = model.state_dict()
 
-        align_and_update_state_dicts(model_sd, state_dict)
+        sd_to_load = align_and_update_state_dicts(model_sd, state_dict)
+        model.load_state_dict(sd_to_load)
         for loaded, stored in zip(model_sd.values(), state_dict.values()):
             # different tensor references
             self.assertFalse(id(loaded) == id(stored))
