44from more_itertools .recipes import flatten
55import torch
66from torch .utils .data .dataloader import DataLoader
7- from transformers import BertTokenizerFast # type: ignore
7+ from transformers import BertTokenizerFast , CamembertTokenizerFast # type: ignore
88from tqdm import tqdm
99from tibert import (
1010 BertForCoreferenceResolution ,
1414 split_coreference_document ,
1515 DataCollatorForSpanClassification ,
1616 score_coref_predictions ,
17+ score_mention_detection ,
1718)
1819from tibert .utils import gpu_memory_usage
1920
2021
2122def train_coref_model (
2223 model : Union [BertForCoreferenceResolution , CamembertForCoreferenceResolution ],
2324 dataset : CoreferenceDataset ,
24- tokenizer : BertTokenizerFast ,
25+ tokenizer : Union [ BertTokenizerFast , CamembertTokenizerFast ] ,
2526 batch_size : int = 1 ,
2627 epochs_nb : int = 30 ,
2728 sents_per_documents_train : int = 11 ,
@@ -150,8 +151,11 @@ def train_coref_model(
150151 )[0 ]
151152 for doc in test_dataset .documents
152153 ]
153- metrics = score_coref_predictions (preds , refs )
154154
155+ metrics = score_coref_predictions (preds , refs )
156+ conll_f1 = mean (
157+ [metrics ["MUC" ]["f1" ], metrics ["B3" ]["f1" ], metrics ["CEAF" ]["f1" ]]
158+ )
155159 if _run :
156160 _run .log_scalar ("muc_precision" , metrics ["MUC" ]["precision" ])
157161 _run .log_scalar ("muc_recall" , metrics ["MUC" ]["recall" ])
@@ -162,23 +166,27 @@ def train_coref_model(
162166 _run .log_scalar ("ceaf_precision" , metrics ["CEAF" ]["precision" ])
163167 _run .log_scalar ("ceaf_recall" , metrics ["CEAF" ]["recall" ])
164168 _run .log_scalar ("ceaf_f1" , metrics ["CEAF" ]["f1" ])
165-
169+ _run . log_scalar ( "conll_f1" , conll_f1 )
166170 print (metrics )
167171
168- # keep the best model
169- model_f1 = mean (
170- [metrics ["MUC" ]["f1" ], metrics ["B3" ]["f1" ], metrics ["CEAF" ]["f1" ]]
172+ m_precision , m_recall , m_f1 = score_mention_detection (preds , refs )
173+ if _run :
174+ _run .log_scalar ("mention_detection_precision" , m_precision )
175+ _run .log_scalar ("mention_detection_recall" , m_recall )
176+ _run .log_scalar ("mention_detection_f1" , m_f1 )
177+ print (
178+ f"mention detection metrics: (precision: { m_precision } , recall: { m_recall } , f1: { m_f1 } )"
171179 )
172180
173181 except Exception as e :
174182 print (e )
175183 traceback .print_exc ()
176- model_f1 = 0
184+ conll_f1 = 0
177185
178- if model_f1 > best_f1 or best_f1 == 0 :
186+ if conll_f1 > best_f1 or best_f1 == 0 :
179187 best_model = copy .deepcopy (model ).to ("cpu" )
180188 if not model_save_path is None :
181189 best_model .save_pretrained (model_save_path )
182- best_f1 = model_f1
190+ best_f1 = conll_f1
183191
184192 return best_model
0 commit comments