Commit 27afc39

fix(train): missing ckpt var
1 parent: 75b6ab6

File tree

1 file changed: +3 -128 lines changed


infer/lib/train/utils.py

Lines changed: 3 additions & 128 deletions
@@ -16,62 +16,12 @@
 logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
 logger = logging

-"""
-def load_checkpoint_d(checkpoint_path, combd, sbd, optimizer=None, load_opt=1):
-    assert os.path.isfile(checkpoint_path)
-    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
-
-    ##################
-    def go(model, bkey):
-        saved_state_dict = checkpoint_dict[bkey]
-        if hasattr(model, "module"):
-            state_dict = model.module.state_dict()
-        else:
-            state_dict = model.state_dict()
-        new_state_dict = {}
-        for k, v in state_dict.items():  # the shapes the model expects
-            try:
-                new_state_dict[k] = saved_state_dict[k]
-                if saved_state_dict[k].shape != state_dict[k].shape:
-                    logger.warning(
-                        "shape-%s-mismatch. need: %s, get: %s",
-                        k,
-                        state_dict[k].shape,
-                        saved_state_dict[k].shape,
-                    )  #
-                    raise KeyError
-            except:
-                # logger.info(traceback.format_exc())
-                logger.info("%s is not in the checkpoint", k)  # missing from the pretrained model
-                new_state_dict[k] = v  # keep the model's own randomly initialized value
-        if hasattr(model, "module"):
-            model.module.load_state_dict(new_state_dict, strict=False)
-        else:
-            model.load_state_dict(new_state_dict, strict=False)
-        return model
-
-    go(combd, "combd")
-    model = go(sbd, "sbd")
-    #############
-    logger.info("Loaded model weights")
-
-    iteration = checkpoint_dict["iteration"]
-    learning_rate = checkpoint_dict["learning_rate"]
-    if (
-        optimizer is not None and load_opt == 1
-    ):  ### if loading fails (e.g. the saved state is empty), reinitialize; this can also affect LR-schedule updates, so it is caught at the outermost level of the train script
-        # try:
-        optimizer.load_state_dict(checkpoint_dict["optimizer"])
-        # except:
-        #     traceback.print_exc()
-    logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, iteration))
-    return model, optimizer, learning_rate, iteration
-"""
-

 def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
     assert os.path.isfile(checkpoint_path)
-    saved_state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)["model"]
+    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
+
+    saved_state_dict = checkpoint_dict["model"]
     if hasattr(model, "module"):
         state_dict = model.module.state_dict()
     else:
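
Why the fix is needed: an earlier refactor collapsed the load in load_checkpoint into a single torch.load(...)["model"] expression, while later lines of the function (outside this hunk) still read other keys from the full checkpoint dict, so the name checkpoint_dict no longer existed, hence "missing ckpt var". Below is a minimal sketch of the restored pattern; the file path and the exact later reads are assumptions inferred from the removed load_checkpoint_d twin, not shown in this diff:

import torch

# Hypothetical path; any checkpoint saved with the matching keys works.
checkpoint_path = "checkpoint.pth"

# Before the fix, only the "model" sub-dict survived:
#     saved_state_dict = torch.load(...)["model"]
# so any later reference to checkpoint_dict raised a NameError.
# After the fix, the whole dict is kept and indexed:
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
saved_state_dict = checkpoint_dict["model"]
iteration = checkpoint_dict["iteration"]          # assumed key, as in load_checkpoint_d
learning_rate = checkpoint_dict["learning_rate"]  # assumed key, as in load_checkpoint_d
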
@@ -132,34 +82,6 @@ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path)
     )


-"""
-def save_checkpoint_d(combd, sbd, optimizer, learning_rate, iteration, checkpoint_path):
-    logger.info(
-        "Saving model and optimizer state at epoch {} to {}".format(
-            iteration, checkpoint_path
-        )
-    )
-    if hasattr(combd, "module"):
-        state_dict_combd = combd.module.state_dict()
-    else:
-        state_dict_combd = combd.state_dict()
-    if hasattr(sbd, "module"):
-        state_dict_sbd = sbd.module.state_dict()
-    else:
-        state_dict_sbd = sbd.state_dict()
-    torch.save(
-        {
-            "combd": state_dict_combd,
-            "sbd": state_dict_sbd,
-            "iteration": iteration,
-            "optimizer": optimizer.state_dict(),
-            "learning_rate": learning_rate,
-        },
-        checkpoint_path,
-    )
-"""
-

 def summarize(
     writer,
     global_step,
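
The shape-tolerant merge from the removed load_checkpoint_d is the same pattern the surviving load_checkpoint continues with (see the context lines at the end of the first hunk). A self-contained sketch of that merge follows; the function name is hypothetical, and it uses dict.get() in place of the original's try/except:

import logging
import torch

def merge_partial_state_dict(model, saved_state_dict):
    # Copy a checkpoint tensor only when its shape matches what the
    # model expects; otherwise keep the model's own randomly
    # initialized tensor, as the removed go() helper did.
    target = model.module if hasattr(model, "module") else model
    state_dict = target.state_dict()
    new_state_dict = {}
    for k, v in state_dict.items():
        saved = saved_state_dict.get(k)
        if saved is not None and saved.shape == v.shape:
            new_state_dict[k] = saved
        else:
            logging.info("%s is not in the checkpoint (or shape-mismatched)", k)
            new_state_dict[k] = v
    target.load_state_dict(new_state_dict, strict=False)
    return model
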
@@ -366,53 +288,6 @@ def get_hparams(init=True):
     return hparams


-"""
-def get_hparams_from_dir(model_dir):
-    config_save_path = os.path.join(model_dir, "config.json")
-    with open(config_save_path, "r") as f:
-        data = f.read()
-    config = json.loads(data)
-
-    hparams = HParams(**config)
-    hparams.model_dir = model_dir
-    return hparams
-
-
-def get_hparams_from_file(config_path):
-    with open(config_path, "r") as f:
-        data = f.read()
-    config = json.loads(data)
-
-    hparams = HParams(**config)
-    return hparams
-
-
-def check_git_hash(model_dir):
-    source_dir = os.path.dirname(os.path.realpath(__file__))
-    if not os.path.exists(os.path.join(source_dir, ".git")):
-        logger.warning(
-            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
-                source_dir
-            )
-        )
-        return
-
-    cur_hash = subprocess.getoutput("git rev-parse HEAD")
-
-    path = os.path.join(model_dir, "githash")
-    if os.path.exists(path):
-        saved_hash = open(path).read()
-        if saved_hash != cur_hash:
-            logger.warning(
-                "git hash values are different. {}(saved) != {}(current)".format(
-                    saved_hash[:8], cur_hash[:8]
-                )
-            )
-    else:
-        open(path, "w").write(cur_hash)
-"""
-
-
 def get_logger(model_dir, filename="train.log"):
     global logger
     logger = logging.getLogger(os.path.basename(model_dir))
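
The removed config helpers reduced to one pattern: parse config.json and wrap it in HParams. If an external caller still needs that, a one-function equivalent can be sketched as below; it assumes the HParams class defined elsewhere in this same utils.py (its definition is outside the diff) is in scope:

import json

def get_hparams_from_file(config_path):
    # Parse the JSON config and wrap it, as the removed helper did.
    # Assumes HParams from infer/lib/train/utils.py is in scope.
    with open(config_path, "r") as f:
        config = json.load(f)
    return HParams(**config)
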
