16 | 16 | logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) |
17 | 17 | logger = logging |
18 | 18 |
19 | | -""" |
20 | | -def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1): |
21 | | - assert os.path.isfile(checkpoint_path) |
22 | | - checkpoint_dict = torch.load(checkpoint_path, map_location="cpu") |
23 | | -
24 | | - ################## |
25 | | - def go(model, bkey): |
26 | | - saved_state_dict = checkpoint_dict[bkey] |
27 | | - if hasattr(model, "module"): |
28 | | - state_dict = model.module.state_dict() |
29 | | - else: |
30 | | - state_dict = model.state_dict() |
31 | | - new_state_dict = {} |
32 | | - for k, v in state_dict.items(): # shapes the model expects |
33 | | - try: |
34 | | - new_state_dict[k] = saved_state_dict[k] |
35 | | - if saved_state_dict[k].shape != state_dict[k].shape: |
36 | | - logger.warning( |
37 | | - "shape-%s-mismatch. need: %s, get: %s", |
38 | | - k, |
39 | | - state_dict[k].shape, |
40 | | - saved_state_dict[k].shape, |
41 | | - ) # |
42 | | - raise KeyError |
43 | | - except: |
44 | | - # logger.info(traceback.format_exc()) |
45 | | - logger.info("%s is not in the checkpoint", k) # missing from the pretrained checkpoint |
46 | | - new_state_dict[k] = v # keep the model's own random initialization |
47 | | - if hasattr(model, "module"): |
48 | | - model.module.load_state_dict(new_state_dict, strict=False) |
49 | | - else: |
50 | | - model.load_state_dict(new_state_dict, strict=False) |
51 | | - return model |
52 | | -
53 | | - go(combd, "combd") |
54 | | - model = go(sbd, "sbd") |
55 | | - ############# |
56 | | - logger.info("Loaded model weights") |
57 | | -
58 | | - iteration = checkpoint_dict["iteration"] |
59 | | - learning_rate = checkpoint_dict["learning_rate"] |
60 | | - if ( |
61 | | - optimizer is not None and load_opt == 1 |
62 | | - ): ### if loading fails (e.g. the saved state is empty), reinitialize; this may also affect the lr scheduler update, so catch it at the outermost level of the train script |
63 | | - # try: |
64 | | - optimizer.load_state_dict(checkpoint_dict["optimizer"]) |
65 | | - # except: |
66 | | - # traceback.print_exc() |
67 | | - logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration)) |
68 | | - return model, optimizer, learning_rate, iteration |
69 | | -""" |
70 | | - |
71 | 19 |
72 | 20 | def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1): |
73 | 21 | assert os.path.isfile(checkpoint_path) |
74 | | - saved_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["model"] |
| 22 | + checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True) |
| 23 | + |
| 24 | + saved_state_dict = checkpoint_dict["model"] |
75 | 25 | if hasattr(model, "module"): |
76 | 26 | state_dict = model.module.state_dict() |
77 | 27 | else: |
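For context on the `weights_only=True` change in the hunk above: from PyTorch 1.13 onward, `torch.load` can restrict unpickling to tensors and plain containers, which guards against malicious checkpoint files but fails if the checkpoint stores arbitrary Python objects. A minimal, simplified sketch of the resulting load path (the `load_model_state` helper is illustrative only; the real `load_checkpoint` additionally walks the model's own `state_dict`, which this sketch skips):

```python
import torch

def load_model_state(checkpoint_path, model):
    # Illustrative helper, not part of this repository.
    # weights_only=True (PyTorch >= 1.13) restricts unpickling to tensors
    # and plain containers, guarding against malicious checkpoint files.
    checkpoint_dict = torch.load(
        checkpoint_path, map_location="cpu", weights_only=True
    )
    # save_checkpoint stores the weights under the "model" key.
    saved_state_dict = checkpoint_dict["model"]
    # strict=False tolerates keys that are missing from or extra in the file.
    model.load_state_dict(saved_state_dict, strict=False)
    return model
```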
@@ -132,34 +82,6 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path) |
132 | 82 | ) |
133 | 83 |
134 | 84 |
135 | | -""" |
136 | | -def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path): |
137 | | - logger.info( |
138 | | - "Saving model and optimizer state at epoch {} to {}".format( |
139 | | - iteration, checkpoint_path |
140 | | - ) |
141 | | - ) |
142 | | - if hasattr(combd, "module"): |
143 | | - state_dict_combd = combd.module.state_dict() |
144 | | - else: |
145 | | - state_dict_combd = combd.state_dict() |
146 | | - if hasattr(sbd, "module"): |
147 | | - state_dict_sbd = sbd.module.state_dict() |
148 | | - else: |
149 | | - state_dict_sbd = sbd.state_dict() |
150 | | - torch.save( |
151 | | - { |
152 | | - "combd": state_dict_combd, |
153 | | - "sbd": state_dict_sbd, |
154 | | - "iteration": iteration, |
155 | | - "optimizer": optimizer.state_dict(), |
156 | | - "learning_rate": learning_rate, |
157 | | - }, |
158 | | - checkpoint_path, |
159 | | - ) |
160 | | -""" |
161 | | - |
162 | | - |
163 | 85 | def summarize( |
164 | 86 | writer, |
165 | 87 | global_step, |
@@ -366,53 +288,6 @@ def get_hparams(init=True): |
366 | 288 | return hparams |
367 | 289 |
368 | 290 |
369 | | -""" |
370 | | -def get_hparams_from_dir(model_dir): |
371 | | - config_save_path = os.path.join(model_dir, "config.json") |
372 | | - with open(config_save_path, "r") as f: |
373 | | - data = f.read() |
374 | | - config = json.loads(data) |
375 | | -
376 | | - hparams = HParams(**config) |
377 | | - hparams.model_dir = model_dir |
378 | | - return hparams |
379 | | -
380 | | -
381 | | -def get_hparams_from_file(config_path): |
382 | | - with open(config_path, "r") as f: |
383 | | - data = f.read() |
384 | | - config = json.loads(data) |
385 | | -
386 | | - hparams = HParams(**config) |
387 | | - return hparams |
388 | | -
389 | | -
390 | | -def check_git_hash(model_dir): |
391 | | - source_dir = os.path.dirname(os.path.realpath(__file__)) |
392 | | - if not os.path.exists(os.path.join(source_dir, ".git")): |
393 | | - logger.warning( |
394 | | - "{} is not a git repository, therefore hash value comparison will be ignored.".format( |
395 | | - source_dir |
396 | | - ) |
397 | | - ) |
398 | | - return |
399 | | -
400 | | - cur_hash = subprocess.getoutput("git rev-parse HEAD") |
401 | | -
402 | | - path = os.path.join(model_dir, "githash") |
403 | | - if os.path.exists(path): |
404 | | - saved_hash = open(path).read() |
405 | | - if saved_hash != cur_hash: |
406 | | - logger.warning( |
407 | | - "git hash values are different. {}(saved) != {}(current)".format( |
408 | | - saved_hash[:8], cur_hash[:8] |
409 | | - ) |
410 | | - ) |
411 | | - else: |
412 | | - open(path, "w").write(cur_hash) |
413 | | -""" |
414 | | - |
415 | | - |
416 | 291 | def get_logger(model_dir, filename="train.log"): |
417 | 292 | global logger |
418 | 293 | logger = logging.getLogger(os.path.basename(model_dir)) |