
Commit d6fa8ce (merge commit, 2 parents: a86a6ae + 6b14e04)

Merge branch 'main' of github.com:CompNet/Tibert

File tree

4 files changed (+159 / -52 lines changed)

README.md
Lines changed: 2 additions & 2 deletions

@@ -97,7 +97,7 @@ Aside from the `tibert.train.train_coref_model` function, it is possible to trai
 ```sh
 python -m tibert.run_train with\
 dataset_path=/path/to/litbank/repository\
-out_model_path=/path/to/output/model/directory
+out_model_dir=/path/to/output/model/directory
 ```
 
 The following parameters can be set (taken from `./tibert/run_train.py` config function):
@@ -119,7 +119,7 @@ The following parameters can be set (taken from `./tibert/run_train.py` config f
 | `dropout` | `0.3` |
 | `segment_size` | `128` |
 | `encoder` | `"bert-base-cased"` |
-| `out_model_path` | `"~/tibert/model"` |
+| `out_model_dir` | `"~/tibert/model"` |
 
 
 One can monitor training metrics by adding run observers using command line flags - see `sacred` documentation for more details.
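The observer mechanism referenced in the last context line above is standard `sacred` functionality, not something introduced by this commit. As a hedged illustration, an observer can also be attached in code rather than through command line flags (the `training_runs` directory below is a placeholder):

```python
# Illustration only (plain sacred API, not part of this commit): persist each
# run's config and logged metrics to disk with a FileStorageObserver.
from sacred.observers import FileStorageObserver

from tibert.run_train import ex  # the Experiment behind `python -m tibert.run_train`

ex.observers.append(FileStorageObserver("training_runs"))
ex.run(config_updates={"dataset_path": "/path/to/litbank/repository"})
```

On the command line, the same effect is typically achieved with sacred's `-F`/`--file_storage` flag.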

tibert/bertcoref.py
Lines changed: 8 additions & 1 deletion

@@ -218,7 +218,14 @@ def from_wpieced_to_tokenized(
         for mention in chain:
 
             new_start_idx = wp_to_token[mention.start_idx]
-            new_end_idx = wp_to_token[mention.end_idx - 1] + 1
+            new_end_idx = wp_to_token[mention.end_idx - 1]
+            # NOTE: this happens in case the model has predicted
+            # an erroneous mention such as '[CLS]' or '[SEP]'. In
+            # that case, we simply ignore the mention.
+            if new_start_idx is None or new_end_idx is None:
+                continue
+            new_end_idx += 1
+
             new_mention = Mention(
                 tokens[new_start_idx:new_end_idx],
                 new_start_idx,
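For context on the new `None` guard: HuggingFace fast tokenizers map special-token positions such as `[CLS]` and `[SEP]` to no source token at all, so a wordpiece-to-token lookup can legitimately return `None` at those positions. The sketch below only illustrates that tokenizer behaviour; it is not taken from `tibert`, whose `wp_to_token` construction lies outside this diff.

```python
# Illustration only: fast tokenizers report None for special-token positions,
# which is the situation the new guard in from_wpieced_to_tokenized skips over.
from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
encoding = tokenizer(["Sir", "Percival", "arrived"], is_split_into_words=True)

# word_ids() maps each wordpiece to the index of its source token,
# or to None for [CLS] / [SEP].
print(encoding.word_ids())  # e.g. [None, 0, 1, 1, 2, None] (exact split may vary)
```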

tibert/run_train.py
Lines changed: 33 additions & 26 deletions

@@ -1,5 +1,6 @@
-from typing import Literal, cast
+from typing import Literal, Optional, cast
 import os
+from torch.optim import optimizer
 from transformers import BertTokenizerFast, CamembertTokenizerFast  # type: ignore
 from sacred.experiment import Experiment
 from sacred.run import Run
@@ -10,6 +11,7 @@
     BertForCoreferenceResolution,
     CamembertForCoreferenceResolution,
     train_coref_model,
+    load_train_checkpoint,
 )
 
 ex = Experiment()
@@ -33,7 +35,8 @@ def config():
     dropout: float = 0.3
     segment_size: int = 128
     encoder: str = "bert-base-cased"
-    out_model_path: str = os.path.expanduser("~/tibert/model")
+    out_model_dir: str = os.path.expanduser("~/tibert/model")
+    checkpoint: Optional[str] = None
 
 
 @ex.main
@@ -54,7 +57,8 @@ def main(
     dropout: float,
     segment_size: int,
     encoder: str,
-    out_model_path: str,
+    out_model_dir: str,
+    checkpoint: Optional[str],
 ):
     print_config(_run)
 
@@ -76,39 +80,42 @@ def main(
 
     config = dataset_configs[dataset_name]
 
-    model = config["model_class"].from_pretrained(
-        encoder,
-        mentions_per_tokens=mentions_per_tokens,
-        antecedents_nb=antecedents_nb,
-        max_span_size=max_span_size,
-        segment_size=segment_size,
-        mention_scorer_hidden_size=mention_scorer_hidden_size,
-        mention_scorer_dropout=dropout,
-        hidden_dropout_prob=dropout,
-        attention_probs_dropout_prob=dropout,
-        mention_loss_coeff=mention_loss_coeff,
-    )
+    if not checkpoint is None:
+        model, optimizer = load_train_checkpoint(checkpoint, config["model_class"])
+    else:
+        model = config["model_class"].from_pretrained(
+            encoder,
+            mentions_per_tokens=mentions_per_tokens,
+            antecedents_nb=antecedents_nb,
+            max_span_size=max_span_size,
+            segment_size=segment_size,
+            mention_scorer_hidden_size=mention_scorer_hidden_size,
+            mention_scorer_dropout=dropout,
+            hidden_dropout_prob=dropout,
+            attention_probs_dropout_prob=dropout,
+            mention_loss_coeff=mention_loss_coeff,
+        )
+        optimizer = None
 
     tokenizer = config["tokenizer_class"].from_pretrained(encoder)
 
     dataset = config["loading_function"](dataset_path, tokenizer, max_span_size)
 
-    model = train_coref_model(
+    train_coref_model(
         model,
         dataset,
         tokenizer,
-        batch_size,
-        epochs_nb,
-        sents_per_documents_train,
-        bert_lr,
-        task_lr,
-        out_model_path,
-        "auto",
-        _run,
+        batch_size=batch_size,
+        epochs_nb=epochs_nb,
+        sents_per_documents_train=sents_per_documents_train,
+        bert_lr=bert_lr,
+        task_lr=task_lr,
+        model_save_dir=out_model_dir,
+        device_str="auto",
+        _run=_run,
+        optimizer=optimizer,
     )
 
-    model.save_pretrained(out_model_path)
-
 
 if __name__ == "__main__":
     ex.run_commandline()
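A hedged usage sketch for the new `checkpoint` config key (all paths below are placeholders): when it is set, `main` restores the model and optimizer through `load_train_checkpoint` instead of calling `from_pretrained` on the encoder, and forwards the optimizer to `train_coref_model`.

```python
# Sketch of resuming a run through the sacred entry point; paths are
# placeholders. checkpoint.pth is the file train_coref_model writes into
# out_model_dir after each epoch.
from tibert.run_train import ex

ex.run(
    config_updates={
        "dataset_path": "/path/to/litbank/repository",
        "out_model_dir": "/path/to/output/model/directory",
        "checkpoint": "/path/to/output/model/directory/checkpoint.pth",
    }
)
```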

tibert/train.py
Lines changed: 116 additions & 23 deletions

@@ -1,22 +1,71 @@
-from typing import Optional, Union, Literal
-import traceback, copy
+from typing import Optional, Tuple, Type, Union, Literal
+import traceback, copy, os
 from statistics import mean
 from more_itertools.recipes import flatten
 import torch
 from torch.utils.data.dataloader import DataLoader
 from transformers import BertTokenizerFast, CamembertTokenizerFast  # type: ignore
 from tqdm import tqdm
-from tibert import (
+from tibert.bertcoref import (
     BertForCoreferenceResolution,
     CamembertForCoreferenceResolution,
     CoreferenceDataset,
-    split_coreference_document,
     DataCollatorForSpanClassification,
-    score_coref_predictions,
-    score_mention_detection,
 )
+from tibert.score import score_coref_predictions, score_mention_detection
 from tibert.predict import predict_coref
-from tibert.utils import gpu_memory_usage
+from tibert.utils import gpu_memory_usage, split_coreference_document
+
+
+def _save_train_checkpoint(
+    path: str,
+    model: Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
+    epoch: int,
+    optimizer: torch.optim.AdamW,
+    bert_lr: float,
+    task_lr: float,
+):
+    checkpoint = {
+        "model": model.state_dict(),
+        "model_config": vars(model.config),
+        "epoch": epoch,
+        "optimizer": optimizer.state_dict(),
+        "bert_lr": bert_lr,
+        "task_lr": task_lr,
+    }
+    torch.save(checkpoint, path)
+
+
+def load_train_checkpoint(
+    checkpoint_path: str,
+    model_class: Union[
+        Type[BertForCoreferenceResolution], Type[CamembertForCoreferenceResolution]
+    ],
+) -> Tuple[
+    Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
+    torch.optim.AdamW,
+]:
+    config_class = model_class.config_class
+
+    checkpoint = torch.load(checkpoint_path)
+
+    model_config = config_class(**checkpoint["model_config"])
+    model = model_class(model_config)
+    model.load_state_dict(checkpoint["model"])
+
+    optimizer = torch.optim.AdamW(
+        [
+            {"params": model.bert_parameters(), "lr": checkpoint["bert_lr"]},
+            {
+                "params": model.task_parameters(),
+                "lr": checkpoint["task_lr"],
+            },
+        ],
+        lr=checkpoint["task_lr"],
+    )
+    optimizer.load_state_dict(checkpoint["optimizer"])
+
+    return model, optimizer
 
 
 def train_coref_model(
@@ -28,14 +77,41 @@ def train_coref_model(
     sents_per_documents_train: int = 11,
     bert_lr: float = 1e-5,
     task_lr: float = 2e-4,
-    model_save_path: Optional[str] = None,
+    model_save_dir: Optional[str] = None,
     device_str: Literal["cpu", "cuda", "auto"] = "auto",
     _run: Optional["sacred.run.Run"] = None,
+    optimizer: Optional[torch.optim.AdamW] = None,
 ) -> BertForCoreferenceResolution:
+    """
+    :param model: model to train
+    :param dataset: dataset to train on. 90% of that dataset will be
+        used for training, 10% for testing
+    :param tokenizer: tokenizer associated with ``model``
+    :param batch_size: batch_size during training and testing
+    :param epochs_nb: number of epochs to train for
+    :param sents_per_documents_train: max number of sentences in each
+        train document
+    :param bert_lr: learning rate of the BERT encoder
+    :param task_lr: learning rate for other parts of the network
+    :param model_save_dir: directory in which to save the final model
+        (under 'model') and checkpoints ('checkpoint.pth')
+    :param device_str:
+    :param _run: sacred run, used to log metrics
+    :param optimizer: a torch optimizer to use. Can be useful to
+        resume training.
+
+    :return: the best trained model, according to CoNLL-F1 on the test
+        set
+    """
+    # Get torch device and send model to it
+    # -------------------------------------
     if device_str == "auto":
         device_str = "cuda" if torch.cuda.is_available() else "cpu"
     device = torch.device(device_str)
+    model = model.to(device)
 
+    # Prepare datasets
+    # ----------------
     train_dataset = CoreferenceDataset(
         dataset.documents[: int(0.9 * len(dataset))],
         dataset.tokenizer,
@@ -72,23 +148,28 @@ def train_coref_model(
         train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator
     )
 
-    optimizer = torch.optim.AdamW(
-        [
-            {"params": model.bert_parameters(), "lr": bert_lr},
-            {
-                "params": model.task_parameters(),
-                "lr": task_lr,
-            },
-        ],
-        lr=task_lr,
-    )
+    # Optimizer initialization
+    # ------------------------
+    if optimizer is None:
+        optimizer = torch.optim.AdamW(
+            [
+                {"params": model.bert_parameters(), "lr": bert_lr},
+                {
+                    "params": model.task_parameters(),
+                    "lr": task_lr,
+                },
+            ],
+            lr=task_lr,
+        )
 
+    # Best model saving
+    # -----------------
     best_f1 = 0
     best_model = model
 
-    model = model.to(device)
-
-    for _ in range(epochs_nb):
+    # Training loop
+    # -------------
+    for epoch_i in range(epochs_nb):
         model = model.train()
 
         epoch_losses = []
@@ -158,10 +239,22 @@ def train_coref_model(
             f"mention detection metrics: (precision: {m_precision}, recall: {m_recall}, f1: {m_f1})"
         )
 
+        # Model saving
+        # ------------
+        if not model_save_dir is None:
+            os.makedirs(model_save_dir, exist_ok=True)
+            _save_train_checkpoint(
+                os.path.join(model_save_dir, "checkpoint.pth"),
+                model,
+                epoch_i,
+                optimizer,
+                bert_lr,
+                task_lr,
+            )
         if conll_f1 > best_f1 or best_f1 == 0:
             best_model = copy.deepcopy(model).to("cpu")
-            if not model_save_path is None:
-                best_model.save_pretrained(model_save_path)
             best_f1 = conll_f1
+            if not model_save_dir is None:
+                model.save_pretrained(os.path.join(model_save_dir, "model"))
 
     return best_model
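With this change, a `model_save_dir` passed to `train_coref_model` ends up holding two artifacts: `checkpoint.pth`, rewritten after every epoch, and `model/`, the best model so far according to CoNLL-F1. A hedged sketch of consuming both is given below; paths and hyperparameter values are placeholders, and dataset/tokenizer preparation is elided.

```python
# Sketch only: resume training from the per-epoch checkpoint, then reload the
# best saved model. Paths and hyperparameters are placeholders.
from tibert.bertcoref import BertForCoreferenceResolution
from tibert.train import load_train_checkpoint, train_coref_model

dataset = ...    # a CoreferenceDataset, prepared as in tibert/run_train.py
tokenizer = ...  # the tokenizer matching the encoder (e.g. BertTokenizerFast)

# Resume: restore model weights and AdamW state from checkpoint.pth, and hand
# the optimizer back to train_coref_model so it is not re-created from scratch.
model, optimizer = load_train_checkpoint(
    "/path/to/model_dir/checkpoint.pth", BertForCoreferenceResolution
)
best_model = train_coref_model(
    model,
    dataset,
    tokenizer,
    batch_size=1,
    epochs_nb=10,
    model_save_dir="/path/to/model_dir",
    optimizer=optimizer,
)

# Later, the best model is a regular HuggingFace checkpoint saved under
# <model_save_dir>/model by save_pretrained, so from_pretrained reloads it.
best_model = BertForCoreferenceResolution.from_pretrained("/path/to/model_dir/model")
```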
