
Commit 1a2dbc8

Fix erroneous metrics reporting in training
1 parent 84e77ea commit 1a2dbc8

File tree

1 file changed: +48 -73 lines

tibert/train.py

Lines changed: 48 additions & 73 deletions
@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional, Union, Literal
 import traceback, copy
 from statistics import mean
 from more_itertools.recipes import flatten
@@ -8,14 +8,14 @@
 from tqdm import tqdm
 from tibert import (
     BertForCoreferenceResolution,
-    BertCoreferenceResolutionOutput,
     CamembertForCoreferenceResolution,
     CoreferenceDataset,
     split_coreference_document,
     DataCollatorForSpanClassification,
     score_coref_predictions,
     score_mention_detection,
 )
+from tibert.predict import predict_coref
 from tibert.utils import gpu_memory_usage
 
 
@@ -29,9 +29,12 @@ def train_coref_model(
     bert_lr: float = 1e-5,
     task_lr: float = 2e-4,
     model_save_path: Optional[str] = None,
+    device_str: Literal["cpu", "cuda", "auto"] = "auto",
     _run: Optional["sacred.run.Run"] = None,
 ) -> BertForCoreferenceResolution:
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if device_str == "auto":
+        device_str = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device(device_str)
 
     train_dataset = CoreferenceDataset(
         dataset.documents[: int(0.9 * len(dataset))],
@@ -68,9 +71,6 @@ def train_coref_model(
     train_dataloader = DataLoader(
         train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator
     )
-    test_dataloader = DataLoader(
-        test_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator
-    )
 
     optimizer = torch.optim.AdamW(
         [
@@ -88,7 +88,7 @@
 
     model = model.to(device)
 
-    for epoch_i in range(epochs_nb):
+    for _ in range(epochs_nb):
         model = model.train()
 
         epoch_losses = []
@@ -120,73 +120,48 @@
         if _run:
             _run.log_scalar("epoch_mean_loss", mean(epoch_losses))
 
-        # metrics computation
-        model = model.eval()
-
-        with torch.no_grad():
-            try:
-                preds = []
-                losses = []
-
-                for batch in tqdm(test_dataloader):
-                    local_batch_size = batch["input_ids"].shape[0]
-                    batch = batch.to(device)
-                    out: BertCoreferenceResolutionOutput = model(**batch)
-                    batch_preds = out.coreference_documents(
-                        [
-                            [tokenizer.decode(t) for t in batch["input_ids"][i]]
-                            for i in range(local_batch_size)
-                        ]
-                    )
-                    preds += batch_preds
-
-                    assert not out.loss is None
-                    losses.append(out.loss.item())
-
-                _ = _run and _run.log_scalar("epoch_mean_test_loss", mean(losses))
-
-                refs = [
-                    doc.prepared_document(
-                        test_dataset.tokenizer, model.config.max_span_size
-                    )[0]
-                    for doc in test_dataset.documents
-                ]
-
-                metrics = score_coref_predictions(preds, refs)
-                conll_f1 = mean(
-                    [metrics["MUC"]["f1"], metrics["B3"]["f1"], metrics["CEAF"]["f1"]]
-                )
-                if _run:
-                    _run.log_scalar("muc_precision", metrics["MUC"]["precision"])
-                    _run.log_scalar("muc_recall", metrics["MUC"]["recall"])
-                    _run.log_scalar("muc_f1", metrics["MUC"]["f1"])
-                    _run.log_scalar("b3_precision", metrics["B3"]["precision"])
-                    _run.log_scalar("b3_recall", metrics["B3"]["recall"])
-                    _run.log_scalar("b3_f1", metrics["B3"]["f1"])
-                    _run.log_scalar("ceaf_precision", metrics["CEAF"]["precision"])
-                    _run.log_scalar("ceaf_recall", metrics["CEAF"]["recall"])
-                    _run.log_scalar("ceaf_f1", metrics["CEAF"]["f1"])
-                    _run.log_scalar("conll_f1", conll_f1)
-                print(metrics)
-
-                m_precision, m_recall, m_f1 = score_mention_detection(preds, refs)
-                if _run:
-                    _run.log_scalar("mention_detection_precision", m_precision)
-                    _run.log_scalar("mention_detection_recall", m_recall)
-                    _run.log_scalar("mention_detection_f1", m_f1)
-                print(
-                    f"mention detection metrics: (precision: {m_precision}, recall: {m_recall}, f1: {m_f1})"
-                )
+        # Metrics Computation
+        # -------------------
+        preds = predict_coref(
+            [doc.tokens for doc in test_dataset.documents],
+            model,
+            tokenizer,
+            batch_size=batch_size,
+            device_str=device_str,
+        )
+        metrics = score_coref_predictions(preds, test_dataset.documents)
 
-            except Exception as e:
-                print(e)
-                traceback.print_exc()
-                conll_f1 = 0
+        conll_f1 = mean(
+            [metrics["MUC"]["f1"], metrics["B3"]["f1"], metrics["CEAF"]["f1"]]
+        )
+        if _run:
+            _run.log_scalar("muc_precision", metrics["MUC"]["precision"])
+            _run.log_scalar("muc_recall", metrics["MUC"]["recall"])
+            _run.log_scalar("muc_f1", metrics["MUC"]["f1"])
+            _run.log_scalar("b3_precision", metrics["B3"]["precision"])
+            _run.log_scalar("b3_recall", metrics["B3"]["recall"])
+            _run.log_scalar("b3_f1", metrics["B3"]["f1"])
+            _run.log_scalar("ceaf_precision", metrics["CEAF"]["precision"])
+            _run.log_scalar("ceaf_recall", metrics["CEAF"]["recall"])
+            _run.log_scalar("ceaf_f1", metrics["CEAF"]["f1"])
+            _run.log_scalar("conll_f1", conll_f1)
+        print(metrics)
+
+        m_precision, m_recall, m_f1 = score_mention_detection(
+            preds, test_dataset.documents
+        )
+        if _run:
+            _run.log_scalar("mention_detection_precision", m_precision)
+            _run.log_scalar("mention_detection_recall", m_recall)
+            _run.log_scalar("mention_detection_f1", m_f1)
+        print(
+            f"mention detection metrics: (precision: {m_precision}, recall: {m_recall}, f1: {m_f1})"
+        )
 
-            if conll_f1 > best_f1 or best_f1 == 0:
-                best_model = copy.deepcopy(model).to("cpu")
-                if not model_save_path is None:
-                    best_model.save_pretrained(model_save_path)
-                best_f1 = conll_f1
+        if conll_f1 > best_f1 or best_f1 == 0:
+            best_model = copy.deepcopy(model).to("cpu")
+            if not model_save_path is None:
+                best_model.save_pretrained(model_save_path)
+            best_f1 = conll_f1
 
     return best_model
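
The substance of the fix sits in the last hunk: the hand-rolled evaluation loop, whose broad try/except reset conll_f1 to 0 on any error, is replaced by a call to tibert's predict_coref, and predictions are scored directly against test_dataset.documents rather than against re-prepared reference documents. As a rough sketch of the new evaluation path, assuming a trained model, its tokenizer, and a held-out test_dataset are already in hand (all three are placeholders here, as is batch_size=4):

    import torch
    from statistics import mean

    from tibert import score_coref_predictions, score_mention_detection
    from tibert.predict import predict_coref

    # Mirror the device resolution this commit adds to train_coref_model:
    # "auto" becomes "cuda" when available, otherwise "cpu".
    device_str = "cuda" if torch.cuda.is_available() else "cpu"

    # Predict on the raw tokens of each held-out document...
    preds = predict_coref(
        [doc.tokens for doc in test_dataset.documents],
        model,          # placeholder: a trained BertForCoreferenceResolution
        tokenizer,      # placeholder: the matching tokenizer
        batch_size=4,   # placeholder value
        device_str=device_str,
    )

    # ...and score against the documents themselves, as the new code does.
    metrics = score_coref_predictions(preds, test_dataset.documents)
    conll_f1 = mean(
        [metrics["MUC"]["f1"], metrics["B3"]["f1"], metrics["CEAF"]["f1"]]
    )
    m_precision, m_recall, m_f1 = score_mention_detection(
        preds, test_dataset.documents
    )

Note that, compared with the removed code, the new path also drops the separate test DataLoader and the epoch_mean_test_loss scalar: evaluation loss is no longer logged, only the coreference and mention detection metrics.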
