
Question about loading checkpoints #11

@Congwen-x


Hi,

Thank you for your open source work!

I have a question about loading checkpoints. I want to reproduce the code in "downstream_finetune_example". When I tried to load the checkpoint from Visualized_m3.pth, I got the RuntimeError shown below.

The model in downstream_finetune_example/run_ds_cirr.py is "BGE_EVAToken", but the guide loads the model with "Visualized_BGE". Is that the cause of the error? What should I do to load the checkpoint?

Here is my code:
run_ds_cirr.py: model.load_state_dict(torch.load(training_args.resume_path, map_location='cuda'))
node1.bash: RESUME_PATH="/home/picture/models/Visualized_m3.pth" # pre-trained visualized bge weights
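Independent of which wrapper class is correct, it can help to inspect what the checkpoint actually contains before calling load_state_dict. Below is a minimal sketch of that idea, using toy nn.Linear modules in place of the real models (the actual checkpoint and architecture are not reproduced here):

```python
import torch
import torch.nn as nn

# Toy stand-ins: a "checkpoint" whose shapes differ from the current model,
# mimicking the large-vs-base vision-tower mismatch in the traceback.
ckpt_model = nn.Linear(1024, 4)  # checkpoint-side shape
cur_model = nn.Linear(768, 4)    # shape of the model being constructed

state_dict = ckpt_model.state_dict()

# Inspect what the checkpoint actually contains before loading it.
for name, tensor in state_dict.items():
    print(name, tuple(tensor.shape))

# A strict load against mismatched shapes raises the same class of RuntimeError.
err = None
try:
    cur_model.load_state_dict(state_dict)
except RuntimeError as e:
    err = e
print("load failed:", type(err).__name__)
```

Printing the key names and shapes this way shows immediately whether the checkpoint matches the model you constructed.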

Here is the error:
11/15/2024 18:29:46 - INFO - main - Traing from checkpoint: /home/picture/models/Visualized_m3.pth
Traceback (most recent call last):
File "./research/visual_bge/visual_bge/run_ds_cirr.py", line 169, in <module>
main()
File "./research/visual_bge/visual_bge/run_ds_cirr.py", line 105, in main
model.load_state_dict(torch.load(training_args.resume_path, map_location='cuda'))
File "/home/conda/envs/v-bge/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BGE_EVAToken:
Unexpected key(s) in state_dict: "model_visual.visual.blocks.12.norm1.weight", "model_visual.visual.blocks.12.norm1.bias", "model_visual.visual.blocks.12.attn.q_bias", "model_visual.visual.blocks.12.attn.v_bias", "model_visual.visual.blocks.12.attn.q_proj.weight", "model_visual.visual.blocks.12.attn.k_proj.weight", ……
size mismatch for model_visual.visual.cls_token: copying a param with shape torch.Size([1, 1, 1024]) from checkpoint, the shape in current model is torch.Size([1, 1, 768]).
size mismatch for model_visual.visual.pos_embed: copying a param with shape torch.Size([1, 257, 1024]) from checkpoint, the shape in current model is torch.Size([1, 197, 768]).
size mismatch for model_visual.visual.patch_embed.proj.weight: copying a param with shape torch.Size([1024, 3, 14, 14]) from checkpoint, the shape in current model is torch.Size([768, 3, 16, 16]).
size mismatch for model_visual.visual.patch_embed.proj.bias: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([768]).
……
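The error pattern above (extra transformer blocks, 1024 vs. 768 hidden size, 14x14 vs. 16x16 patch embedding) suggests the checkpoint's vision tower is larger than the model being constructed. A small diagnostic sketch, again with toy modules standing in for the two architectures, separates "unexpected key" problems from "size mismatch" problems before attempting a load:

```python
import torch
import torch.nn as nn

# Toy models standing in for the checkpoint's architecture and the one
# constructed in run_ds_cirr.py (names and sizes are illustrative only).
ckpt_model = nn.Sequential(nn.Linear(8, 1024), nn.Linear(1024, 4))
cur_model = nn.Sequential(nn.Linear(8, 768))

ckpt_sd = ckpt_model.state_dict()
cur_sd = cur_model.state_dict()

# Keys present in the checkpoint but absent from the model ("Unexpected key(s)").
unexpected = sorted(set(ckpt_sd) - set(cur_sd))

# Keys present in both but with different shapes ("size mismatch").
mismatched = sorted(
    k for k in set(ckpt_sd) & set(cur_sd)
    if ckpt_sd[k].shape != cur_sd[k].shape
)

print("unexpected:", unexpected)
print("mismatched:", mismatched)
```

If both lists are non-empty, as in the traceback above, the constructed model's architecture (depth, hidden size, patch size) does not match the checkpoint, so no loading trick alone will fix it.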
