
Commit 2d0fd07

Merge branch 'main' of github.com:CompNet/Tibert
2 parents 5c8c970 + 2a78a4d

5 files changed: +64 -4 lines changed

README.md
Lines changed: 1 addition & 0 deletions

@@ -120,6 +120,7 @@ The following parameters can be set (taken from `./tibert/run_train.py` config f
 | `segment_size` | `128` |
 | `encoder` | `"bert-base-cased"` |
 | `out_model_dir` | `"~/tibert/model"` |
+| `checkpoint` | `None` |


 One can monitor training metrics by adding run observers using command line flags - see `sacred` documentation for more details.
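For context, a minimal sketch of how a run observer and the new `checkpoint` entry might be wired up programmatically instead of via command line flags. It assumes `run_train.py` exposes a standard sacred `Experiment`; the object name `ex` and the checkpoint path are hypothetical, while `FileStorageObserver` and `config_updates` are standard sacred features:

    from sacred.observers import FileStorageObserver
    from tibert.run_train import ex  # hypothetical: assumes the Experiment is named `ex`

    # store run metrics and config under ./runs
    ex.observers.append(FileStorageObserver("./runs"))
    # resume training from a previously saved checkpoint (path is hypothetical)
    ex.run(config_updates={"checkpoint": "~/tibert/model/checkpoint.pth"})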

pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tibert"
-version = "0.2.0"
+version = "0.2.2"
 description = "BERT for Coreference Resolution"
 authors = ["Arthur Amalvy <[email protected]>"]
 license = "GPL-3.0-only"

tests/test_train.py
Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
+import tempfile
+import torch
+from torch.optim import optimizer
+from transformers import BertTokenizerFast
+from tibert.bertcoref import BertForCoreferenceResolutionConfig
+from tibert.train import _save_train_checkpoint, load_train_checkpoint
+from tibert import BertForCoreferenceResolution, predict_coref_simple
+
+
+def test_save_load_checkpoint():
+    model = BertForCoreferenceResolution(BertForCoreferenceResolutionConfig())
+    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
+
+    bert_lr = 1e-5
+    task_lr = 2e-4
+    optimizer = torch.optim.AdamW(
+        [
+            {"params": model.bert_parameters(), "lr": bert_lr},
+            {
+                "params": model.task_parameters(),
+                "lr": task_lr,
+            },
+        ],
+        lr=task_lr,
+    )
+
+    text = "Sli did not want the earpods. He didn't like them."
+    before_pred = predict_coref_simple(text, model, tokenizer)
+
+    with tempfile.TemporaryDirectory() as d:
+        checkpoint_f = f"{d}/checkpoint.pth"
+        _save_train_checkpoint(checkpoint_f, model, 1, optimizer, bert_lr, task_lr)
+        model, new_optimizer = load_train_checkpoint(
+            checkpoint_f, BertForCoreferenceResolution
+        )
+
+        assert new_optimizer.state_dict() == optimizer.state_dict()
+
+        after_pred = predict_coref_simple(text, model, tokenizer)
+        assert before_pred == after_pred

tibert/bertcoref.py
Lines changed: 9 additions & 1 deletion

@@ -19,6 +19,7 @@
 from transformers.models.camembert.modeling_camembert import CamembertModel
 from transformers.models.camembert.configuration_camembert import CamembertConfig
 from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
+from transformers.utils import logging as transformers_logging
 from tqdm import tqdm
 from tibert.utils import spans_indexs, batch_index_select, spans


@@ -131,7 +132,14 @@ def prepared_document(
         """
         # (silly) exemple for the tokens ["I", "am", "PG"]
         # a BertTokenizer would produce ["[CLS]", "I", "am", "P", "##G", "[SEP]"]
-        batch = tokenizer(self.tokens, is_split_into_words=True, truncation=True) # type: ignore
+        # NOTE: we disable tokenizer warning to avoid a length
+        # warning. Usually, sequences should be truncated to a max
+        # length (512 for BERT). However, in our case, the sequence is
+        # later cut into segments of configurable size, so this does
+        # not apply (see BertForCoreferenceResolutionConfig.segment_size)
+        transformers_logging.set_verbosity_error()
+        batch = tokenizer(self.tokens, is_split_into_words=True)
+        transformers_logging.set_verbosity_info()
         tokens = tokenizer.convert_ids_to_tokens(batch["input_ids"]) # type: ignore

         # words_ids is used to correspond post-tokenization word pieces
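A side note on the silencing pattern above: `set_verbosity_info()` resets transformers logging to INFO rather than restoring whatever level was active before. A minimal sketch of a save-and-restore variant of the same pattern, using only public helpers from `transformers.utils.logging` (`get_verbosity` / `set_verbosity`); this is an illustration, not part of the commit:

    from transformers import BertTokenizerFast
    from transformers.utils import logging as transformers_logging

    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
    tokens = ["I", "am", "PG"]  # example tokens from the comment in the hunk above

    # remember the current level, silence the length warning, restore afterwards
    previous_verbosity = transformers_logging.get_verbosity()
    transformers_logging.set_verbosity_error()
    batch = tokenizer(tokens, is_split_into_words=True)
    transformers_logging.set_verbosity(previous_verbosity)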

tibert/train.py
Lines changed: 13 additions & 2 deletions

@@ -51,7 +51,7 @@ def load_train_checkpoint(

     model_config = config_class(**checkpoint["model_config"])
     model = model_class(model_config)
-    model.load_state_dict(checkpoint["model"])
+    model.load_state_dict(checkpoint["model"], strict=False)

     optimizer = torch.optim.AdamW(
         [

@@ -68,6 +68,17 @@
     return model, optimizer


+def _optimizer_to_(
+    optimizer: torch.optim.AdamW, device: torch.device
+) -> torch.optim.AdamW:
+    """From https://github.com/pytorch/pytorch/issues/2830"""
+    for state in optimizer.state.values():
+        for k, v in state.items():
+            if isinstance(v, torch.Tensor):
+                state[k] = v.to(device)
+    return optimizer
+
+
 def train_coref_model(
     model: Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
     dataset: CoreferenceDataset,

@@ -161,8 +172,8 @@
         ],
         lr=task_lr,
     )
+    optimizer = _optimizer_to_(optimizer, device)

-    # Best model saving
     # -----------------
     best_f1 = 0
     best_model = model