Skip to content

Commit 9e08511

Browse files
committed
for mismatch state use default init
1 parent 0add04b commit 9e08511

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

chebai/models/ffn.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
from collections import OrderedDict
12
from typing import Any, Dict, List, Optional, Tuple
23

34
import torch
45
from torch import Tensor, nn
56

67
from chebai.models import ChebaiBaseNet
7-
8-
from .electra import filter_dict
8+
from chebai.models.electra import filter_dict
99

1010

1111
class FFN(ChebaiBaseNet):
@@ -44,8 +44,24 @@ def __init__(
4444
state_dict = filter_dict(ckpt_file["state_dict"], load_prefix)
4545
else:
4646
state_dict = ckpt_file["state_dict"]
47-
self.model.load_state_dict(state_dict)
48-
print(f"Loaded pretrained weights from {pretrained_checkpoint}")
47+
48+
model_sd = self.model.state_dict()
49+
filtered = OrderedDict()
50+
skipped = set()
51+
for k, v in state_dict.items():
52+
if model_sd[k].shape == v.shape:
53+
filtered[k] = v # only load params with matching shapes
54+
else:
55+
skipped.add(k)
56+
filtered[k] = model_sd[k]
57+
# else: silently skip mismatched keys like "2.weight", "2.bias"
58+
# which is the last linear layers which maps to output dimension
59+
60+
self.model.load_state_dict(filtered)
61+
print(
62+
f"Loaded (shape-matched) weights from {pretrained_checkpoint}",
63+
f"Skipped the following weights: {skipped}",
64+
)
4965

5066
def _get_prediction_and_labels(self, data, labels, model_output):
5167
d = model_output["logits"]

0 commit comments

Comments
 (0)