
Commit b60c5d1

Fix a possible crash when using checkpoints
1 parent 9fea5b9 commit b60c5d1

3 files changed: 15 additions, 3 deletions

README.md

Lines changed: 1 addition & 0 deletions
@@ -120,6 +120,7 @@ The following parameters can be set (taken from `./tibert/run_train.py` config f
 | `segment_size`  | `128`               |
 | `encoder`       | `"bert-base-cased"` |
 | `out_model_dir` | `"~/tibert/model"`  |
+| `checkpoint`    | `None`              |


 One can monitor training metrics by adding run observers using command line flags - see `sacred` documentation for more details.
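
The newly documented `checkpoint` parameter defaults to `None`, in which case training starts from scratch. Since the configuration is handled by `sacred`, resuming from a saved checkpoint should look roughly like the command below; the checkpoint path is a hypothetical example, not a path guaranteed by the repository:

    python ./tibert/run_train.py with checkpoint='~/tibert/model/checkpoint.pth'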

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tibert"
-version = "0.2.1"
+version = "0.2.2"
 description = "BERT for Coreference Resolution"
 authors = ["Arthur Amalvy <[email protected]>"]
 license = "GPL-3.0-only"

tibert/train.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def load_train_checkpoint(
5151

5252
model_config = config_class(**checkpoint["model_config"])
5353
model = model_class(model_config)
54-
model.load_state_dict(checkpoint["model"])
54+
model.load_state_dict(checkpoint["model"], strict=False)
5555

5656
optimizer = torch.optim.AdamW(
5757
[
@@ -68,6 +68,17 @@ def load_train_checkpoint(
6868
return model, optimizer
6969

7070

71+
def _optimizer_to_(
72+
optimizer: torch.optim.AdamW, device: torch.device
73+
) -> torch.optim.AdamW:
74+
"""From https://github.com/pytorch/pytorch/issues/2830"""
75+
for state in optimizer.state.values():
76+
for k, v in state.items():
77+
if isinstance(v, torch.Tensor):
78+
state[k] = v.cuda()
79+
return optimizer
80+
81+
7182
def train_coref_model(
7283
model: Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
7384
dataset: CoreferenceDataset,
@@ -161,8 +172,8 @@ def train_coref_model(
161172
],
162173
lr=task_lr,
163174
)
175+
optimizer = _optimizer_to_(optimizer, device)
164176

165-
# Best model saving
166177
# -----------------
167178
best_f1 = 0
168179
best_model = model
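
For context on the crash being fixed: when a checkpoint is restored, `torch.load` typically maps the optimizer's state tensors to the CPU, while the model's parameters may already have been moved to the GPU; the next `optimizer.step()` then fails with a device-mismatch `RuntimeError` because AdamW's `exp_avg`/`exp_avg_sq` buffers and the gradients live on different devices. The new `_optimizer_to_` helper moves those state tensors onto the training device up front, and `strict=False` additionally makes `load_state_dict` tolerate missing or unexpected keys in an older checkpoint instead of raising. Below is a minimal, self-contained sketch of the failure mode and the fix; the toy model is illustrative, and the extra `k != "step"` guard is an assumption for recent PyTorch versions, which expect AdamW's scalar step counters to stay on the CPU unless `capturable=True` is set:

import torch

# Toy model: one CPU training step populates AdamW's per-parameter
# state (exp_avg / exp_avg_sq), which afterwards lives on the CPU.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
model(torch.randn(1, 4)).sum().backward()
optimizer.step()
optimizer.zero_grad()

if torch.cuda.is_available():
    device = torch.device("cuda")
    model.to(device)  # parameters move, but optimizer state tensors do not

    # Without this loop (the commit's _optimizer_to_ pattern), the next
    # optimizer.step() raises a device-mismatch RuntimeError. The "step"
    # counter is deliberately left on the CPU (see lead-in above).
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor) and k != "step":
                state[k] = v.to(device)

    model(torch.randn(1, 4, device=device)).sum().backward()
    optimizer.step()  # succeeds: gradients and state share one device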
