|
| 1 | +from collections import OrderedDict |
1 | 2 | from typing import Any, Dict, List, Optional, Tuple |
2 | 3 |
|
3 | 4 | import torch |
4 | 5 | from torch import Tensor, nn |
5 | 6 |
|
6 | 7 | from chebai.models import ChebaiBaseNet |
7 | | - |
8 | | -from .electra import filter_dict |
| 8 | +from chebai.models.electra import filter_dict |
9 | 9 |
|
10 | 10 |
|
11 | 11 | class FFN(ChebaiBaseNet): |
@@ -44,8 +44,24 @@ def __init__( |
44 | 44 | state_dict = filter_dict(ckpt_file["state_dict"], load_prefix) |
45 | 45 | else: |
46 | 46 | state_dict = ckpt_file["state_dict"] |
47 | | - self.model.load_state_dict(state_dict) |
48 | | - print(f"Loaded pretrained weights from {pretrained_checkpoint}") |
| 47 | + |
| 48 | + model_sd = self.model.state_dict() |
| 49 | + filtered = OrderedDict() |
| 50 | + skipped = set() |
| 51 | + for k, v in state_dict.items(): |
| 52 | + if model_sd[k].shape == v.shape: |
| 53 | + filtered[k] = v # only load params with matching shapes |
| 54 | + else: |
| 55 | + skipped.add(k) |
| 56 | + filtered[k] = model_sd[k] |
| 57 | + # Mismatched keys (e.g. "2.weight", "2.bias" of the final linear |
| 58 | + # layer mapping to the output dimension) keep their freshly initialized values |
| 59 | + |
| 60 | + self.model.load_state_dict(filtered) |
| 61 | + print( |
| 62 | + f"Loaded (shape-matched) weights from {pretrained_checkpoint}", |
| 63 | + f"Skipped the following weights: {skipped}", |
| 64 | + ) |
49 | 65 |
|
50 | 66 | def _get_prediction_and_labels(self, data, labels, model_output): |
51 | 67 | d = model_output["logits"] |
|
0 commit comments