|
@@ -1,7 +1,6 @@
 from typing import Optional, Tuple, Type, Union, Literal
 import traceback, copy, os
 from statistics import mean
-from more_itertools.recipes import flatten
 import torch
 from torch.utils.data.dataloader import DataLoader
 from transformers import BertTokenizerFast, CamembertTokenizerFast  # type: ignore
@@ -14,7 +13,7 @@
 )
 from tibert.score import score_coref_predictions, score_mention_detection
 from tibert.predict import predict_coref
-from tibert.utils import gpu_memory_usage, split_coreference_document
+from tibert.utils import gpu_memory_usage


 def _save_train_checkpoint(
@@ -81,11 +80,11 @@ def _optimizer_to_(

 def train_coref_model(
     model: Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
-    dataset: CoreferenceDataset,
+    train_dataset: CoreferenceDataset,
+    test_dataset: CoreferenceDataset,
     tokenizer: Union[BertTokenizerFast, CamembertTokenizerFast],
     batch_size: int = 1,
     epochs_nb: int = 30,
-    sents_per_documents_train: int = 11,
     bert_lr: float = 1e-5,
     task_lr: float = 2e-4,
     model_save_dir: Optional[str] = None,
@@ -121,37 +120,6 @@ def train_coref_model(
     device = torch.device(device_str)
     model = model.to(device)

-    # Prepare datasets
-    # ----------------
-    train_dataset = CoreferenceDataset(
-        dataset.documents[: int(0.9 * len(dataset))],
-        dataset.tokenizer,
-        dataset.max_span_size,
-    )
-    train_dataset.documents = list(
-        flatten(
-            [
-                split_coreference_document(doc, sents_per_documents_train)
-                for doc in train_dataset.documents
-            ]
-        )
-    )
-
-    test_dataset = CoreferenceDataset(
-        dataset.documents[int(0.9 * len(dataset)) :],
-        dataset.tokenizer,
-        dataset.max_span_size,
-    )
-    test_dataset.documents = list(
-        flatten(
-            [
-                # HACK: test on full documents
-                split_coreference_document(doc, 11)
-                for doc in test_dataset.documents
-            ]
-        )
-    )
-
     data_collator = DataCollatorForSpanClassification(
         tokenizer, model.config.max_span_size
    )
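
With this change, the 90/10 train/test split and the per-document sentence chunking are no longer done inside train_coref_model; callers now pass a ready-made train_dataset and test_dataset. Below is a minimal sketch of how a caller could reproduce the previous behaviour, assuming CoreferenceDataset still exposes documents, tokenizer and max_span_size, that split_coreference_document remains exported from tibert.utils, and that CoreferenceDataset can be imported from the top-level tibert package; split_train_test is a hypothetical helper name.

# Sketch only: moves the removed in-function split to the caller side.
from more_itertools.recipes import flatten

from tibert import CoreferenceDataset  # import path assumed
from tibert.utils import split_coreference_document


def split_train_test(dataset: CoreferenceDataset, sents_per_documents_train: int = 11):
    cut = int(0.9 * len(dataset))

    # first 90% of documents for training, chunked into short sub-documents
    train_dataset = CoreferenceDataset(
        dataset.documents[:cut], dataset.tokenizer, dataset.max_span_size
    )
    train_dataset.documents = list(
        flatten(
            split_coreference_document(doc, sents_per_documents_train)
            for doc in train_dataset.documents
        )
    )

    # remaining 10% for testing, chunked as in the removed code
    test_dataset = CoreferenceDataset(
        dataset.documents[cut:], dataset.tokenizer, dataset.max_span_size
    )
    test_dataset.documents = list(
        flatten(
            split_coreference_document(doc, 11) for doc in test_dataset.documents
        )
    )

    return train_dataset, test_dataset


# usage:
# train_dataset, test_dataset = split_train_test(dataset)
# model = train_coref_model(model, train_dataset, test_dataset, tokenizer, ...)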
|