
Commit 95ac76d

Add the possibility to specify train and test datasets for training
1 parent 982b0ca commit 95ac76d


3 files changed: +33 -39 lines


tibert/bertcoref.py

Lines changed: 21 additions & 1 deletion
@@ -21,7 +21,12 @@
 from transformers.tokenization_utils_base import BatchEncoding, PreTrainedTokenizerBase
 from transformers.utils import logging as transformers_logging
 from tqdm import tqdm
-from tibert.utils import spans_indexs, batch_index_select, spans
+from tibert.utils import (
+    spans_indexs,
+    batch_index_select,
+    spans,
+    split_coreference_document,
+)
 
 
 @dataclass
@@ -579,6 +584,21 @@ def merged_datasets(datasets: List[CoreferenceDataset]) -> CoreferenceDataset:
             datasets[0].max_span_size,
         )
 
+    def splitted(self, ratio: float) -> Tuple[CoreferenceDataset, CoreferenceDataset]:
+        first_docs = self.documents[: int(ratio * len(self))]
+        second_docs = self.documents[int(ratio * len(self)) :]
+        return (
+            CoreferenceDataset(first_docs, self.tokenizer, self.max_span_size),
+            CoreferenceDataset(second_docs, self.tokenizer, self.max_span_size),
+        )
+
+    def limit_doc_size_(self, sents_nb: int):
+        self.documents = list(
+            flatten(
+                [split_coreference_document(doc, sents_nb) for doc in self.documents]
+            )
+        )
+
     def __len__(self) -> int:
         return len(self.documents)
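
The two new methods move dataset preparation out of train_coref_model and onto CoreferenceDataset itself: splitted() cuts the document list at the given ratio and returns two new datasets, while limit_doc_size_() replaces each document in place with chunks of at most sents_nb sentences (via split_coreference_document). A minimal usage sketch, assuming `dataset` is a CoreferenceDataset already loaded by one of the library's loading functions:

# Sketch only: `dataset` is assumed to be an already-loaded CoreferenceDataset.
train_dataset, test_dataset = dataset.splitted(0.9)  # first 90% of documents / remaining 10%
train_dataset.limit_doc_size_(11)  # split each document into chunks of at most 11 sentences, in place
test_dataset.limit_doc_size_(11)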

tibert/run_train.py

Lines changed: 9 additions & 3 deletions
@@ -13,6 +13,7 @@
     train_coref_model,
     load_train_checkpoint,
 )
+from tibert.bertcoref import CoreferenceDataset
 
 ex = Experiment()
 
@@ -99,15 +100,20 @@ def main(
 
     tokenizer = config["tokenizer_class"].from_pretrained(encoder)
 
-    dataset = config["loading_function"](dataset_path, tokenizer, max_span_size)
+    dataset: CoreferenceDataset = config["loading_function"](
+        dataset_path, tokenizer, max_span_size
+    )
+    train_dataset, test_dataset = dataset.splitted(0.9)
+    train_dataset.limit_doc_size_(sents_per_documents_train)
+    test_dataset.limit_doc_size_(11)
 
     train_coref_model(
         model,
-        dataset,
+        train_dataset,
+        test_dataset,
         tokenizer,
         batch_size=batch_size,
         epochs_nb=epochs_nb,
-        sents_per_documents_train=sents_per_documents_train,
         bert_lr=bert_lr,
         task_lr=task_lr,
         model_save_dir=out_model_dir,
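
Note that the test split keeps the previous behavior: where train.py used to hardcode split_coreference_document(doc, 11) for test documents (the removed "HACK" comment in the diff below), run_train.py now calls test_dataset.limit_doc_size_(11) explicitly, and sents_per_documents_train only governs the training chunks. A hedged sketch of how a caller could pick different values now that the choice lives outside the training loop (the 0.8 ratio is illustrative, not something this commit uses):

# Hypothetical variant: the ratio and chunk sizes are ordinary caller-side choices now.
train_dataset, test_dataset = dataset.splitted(0.8)  # e.g. an 80/20 split instead of 90/10
train_dataset.limit_doc_size_(sents_per_documents_train)
test_dataset.limit_doc_size_(sents_per_documents_train)  # chunk test docs like the training docs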

tibert/train.py

Lines changed: 3 additions & 35 deletions
@@ -1,7 +1,6 @@
 from typing import Optional, Tuple, Type, Union, Literal
 import traceback, copy, os
 from statistics import mean
-from more_itertools.recipes import flatten
 import torch
 from torch.utils.data.dataloader import DataLoader
 from transformers import BertTokenizerFast, CamembertTokenizerFast  # type: ignore
@@ -14,7 +13,7 @@
 )
 from tibert.score import score_coref_predictions, score_mention_detection
 from tibert.predict import predict_coref
-from tibert.utils import gpu_memory_usage, split_coreference_document
+from tibert.utils import gpu_memory_usage
 
 
 def _save_train_checkpoint(
@@ -81,11 +80,11 @@ def _optimizer_to_(
 
 def train_coref_model(
     model: Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
-    dataset: CoreferenceDataset,
+    train_dataset: CoreferenceDataset,
+    test_dataset: CoreferenceDataset,
     tokenizer: Union[BertTokenizerFast, CamembertTokenizerFast],
     batch_size: int = 1,
     epochs_nb: int = 30,
-    sents_per_documents_train: int = 11,
     bert_lr: float = 1e-5,
     task_lr: float = 2e-4,
     model_save_dir: Optional[str] = None,
@@ -121,37 +120,6 @@ def train_coref_model(
     device = torch.device(device_str)
     model = model.to(device)
 
-    # Prepare datasets
-    # ----------------
-    train_dataset = CoreferenceDataset(
-        dataset.documents[: int(0.9 * len(dataset))],
-        dataset.tokenizer,
-        dataset.max_span_size,
-    )
-    train_dataset.documents = list(
-        flatten(
-            [
-                split_coreference_document(doc, sents_per_documents_train)
-                for doc in train_dataset.documents
-            ]
-        )
-    )
-
-    test_dataset = CoreferenceDataset(
-        dataset.documents[int(0.9 * len(dataset)) :],
-        dataset.tokenizer,
-        dataset.max_span_size,
-    )
-    test_dataset.documents = list(
-        flatten(
-            [
-                # HACK: test on full documents
-                split_coreference_document(doc, 11)
-                for doc in test_dataset.documents
-            ]
-        )
-    )
-
     data_collator = DataCollatorForSpanClassification(
         tokenizer, model.config.max_span_size
     )
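
With the dataset preparation block removed, train_coref_model no longer knows about split ratios or sentence limits; callers pass in ready-made train and test datasets. A minimal sketch of the post-commit call, assuming `dataset`, `model` and `tokenizer` already exist, with the signature's defaults used as illustrative hyperparameter values:

# Sketch of the updated API; `dataset`, `model` and `tokenizer` are assumed to exist.
train_dataset, test_dataset = dataset.splitted(0.9)
train_dataset.limit_doc_size_(11)
test_dataset.limit_doc_size_(11)
train_coref_model(
    model,
    train_dataset,
    test_dataset,
    tokenizer,
    batch_size=1,   # defaults from the signature above
    epochs_nb=30,
    bert_lr=1e-5,
    task_lr=2e-4,
)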
