Skip to content

Commit 8839dba

Browse files
committed
add example evolution tracking
1 parent de86dd3 commit 8839dba

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

tibert/run_train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def config():
3838
encoder: str = "bert-base-cased"
3939
out_model_dir: str = os.path.expanduser("~/tibert/model")
4040
checkpoint: Optional[str] = None
41+
example_tracking_path: Optional[str] = None
4142

4243

4344
@ex.main
@@ -60,6 +61,7 @@ def main(
6061
encoder: str,
6162
out_model_dir: str,
6263
checkpoint: Optional[str],
64+
example_tracking_path: Optional[str],
6365
):
6466
print_config(_run)
6567

@@ -125,6 +127,7 @@ def main(
125127
device_str="auto",
126128
_run=_run,
127129
optimizer=optimizer,
130+
example_tracking_path=example_tracking_path,
128131
)
129132

130133

tibert/train.py

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
from typing import Optional, Tuple, Type, Union, Literal
1+
import pickle
2+
from typing import List, Optional, Tuple, Type, Union, Literal
23
import traceback, copy, os
34
from statistics import mean
45
import torch
@@ -9,10 +10,11 @@
910
BertForCoreferenceResolution,
1011
CamembertForCoreferenceResolution,
1112
CoreferenceDataset,
13+
CoreferenceDocument,
1214
DataCollatorForSpanClassification,
1315
)
1416
from tibert.score import score_coref_predictions, score_mention_detection
15-
from tibert.predict import predict_coref
17+
from tibert.predict import predict_coref, predict_coref_simple
1618
from tibert.utils import gpu_memory_usage
1719

1820

@@ -78,6 +80,30 @@ def _optimizer_to_(
7880
return optimizer
7981

8082

83+
def _save_append_example_pred(
84+
path: str,
85+
model: Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
86+
tokenizer: Union[BertTokenizerFast, CamembertTokenizerFast],
87+
example_doc: CoreferenceDocument,
88+
):
89+
"""
90+
Save an example and its prediction to a file, keeping previous
91+
predictions. Useful to follow the evolution of predictions for an
92+
example.
93+
"""
94+
pred = predict_coref_simple(example_doc.tokens, model, tokenizer)
95+
96+
if os.path.exists(path):
97+
with open(path, "rb") as f:
98+
ex_dict = pickle.load(f)
99+
ex_dict["preds"] = ex_dict.get("preds", []) + [pred]
100+
else:
101+
ex_dict = {"ref": example_doc, "preds": [pred]}
102+
103+
with open(path, "wb") as f:
104+
pickle.dump(ex_dict, f)
105+
106+
81107
def train_coref_model(
82108
model: Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
83109
train_dataset: CoreferenceDataset,
@@ -91,6 +117,7 @@ def train_coref_model(
91117
device_str: Literal["cpu", "cuda", "auto"] = "auto",
92118
_run: Optional["sacred.run.Run"] = None,
93119
optimizer: Optional[torch.optim.AdamW] = None,
120+
example_tracking_path: Optional[str] = None,
94121
) -> BertForCoreferenceResolution:
95122
"""
96123
:param model: model to train
@@ -109,6 +136,9 @@ def train_coref_model(
109136
:param _run: sacred run, used to log metrics
110137
:param optimizer: a torch optimizer to use. Can be useful to
111138
resume training.
139+
:param example_tracking_path: if given, path to a file where an
140+
example and its prediction will be dumped each epoch. Usefull
141+
to track the evolution of predictions.
112142
113143
:return: the best trained model, according to CoNLL-F1 on the test
114144
set
@@ -218,6 +248,13 @@ def train_coref_model(
218248
f"mention detection metrics: (precision: {m_precision}, recall: {m_recall}, f1: {m_f1})"
219249
)
220250

251+
# Example evolution tracking
252+
# --------------------------
253+
if not example_tracking_path is None:
254+
_save_append_example_pred(
255+
example_tracking_path, model, tokenizer, test_dataset.documents[1]
256+
)
257+
221258
# Model saving
222259
# ------------
223260
if not model_save_dir is None:

0 commit comments

Comments
 (0)