@@ -1,4 +1,4 @@
-from typing import Optional, Union
+from typing import Optional, Union, Literal
 import traceback, copy
 from statistics import mean
 from more_itertools.recipes import flatten
@@ -8,14 +8,14 @@
 from tqdm import tqdm
 from tibert import (
     BertForCoreferenceResolution,
-    BertCoreferenceResolutionOutput,
     CamembertForCoreferenceResolution,
     CoreferenceDataset,
     split_coreference_document,
     DataCollatorForSpanClassification,
     score_coref_predictions,
     score_mention_detection,
 )
+from tibert.predict import predict_coref
 from tibert.utils import gpu_memory_usage
 
 
@@ -29,9 +29,12 @@ def train_coref_model(
     bert_lr: float = 1e-5,
     task_lr: float = 2e-4,
     model_save_path: Optional[str] = None,
+    device_str: Literal["cpu", "cuda", "auto"] = "auto",
     _run: Optional["sacred.run.Run"] = None,
 ) -> BertForCoreferenceResolution:
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if device_str == "auto":
+        device_str = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device(device_str)
 
     train_dataset = CoreferenceDataset(
         dataset.documents[: int(0.9 * len(dataset))],
@@ -68,9 +71,6 @@ def train_coref_model(
     train_dataloader = DataLoader(
         train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator
     )
-    test_dataloader = DataLoader(
-        test_dataset, batch_size=batch_size, shuffle=False, collate_fn=data_collator
-    )
 
     optimizer = torch.optim.AdamW(
         [
@@ -88,7 +88,7 @@ def train_coref_model(
 
     model = model.to(device)
 
-    for epoch_i in range(epochs_nb):
+    for _ in range(epochs_nb):
         model = model.train()
 
         epoch_losses = []
@@ -120,73 +120,48 @@ def train_coref_model(
         if _run:
             _run.log_scalar("epoch_mean_loss", mean(epoch_losses))
 
-        # metrics computation
-        model = model.eval()
-
-        with torch.no_grad():
-            try:
-                preds = []
-                losses = []
-
-                for batch in tqdm(test_dataloader):
-                    local_batch_size = batch["input_ids"].shape[0]
-                    batch = batch.to(device)
-                    out: BertCoreferenceResolutionOutput = model(**batch)
-                    batch_preds = out.coreference_documents(
-                        [
-                            [tokenizer.decode(t) for t in batch["input_ids"][i]]
-                            for i in range(local_batch_size)
-                        ]
-                    )
-                    preds += batch_preds
-
-                    assert not out.loss is None
-                    losses.append(out.loss.item())
-
-                _ = _run and _run.log_scalar("epoch_mean_test_loss", mean(losses))
-
-                refs = [
-                    doc.prepared_document(
-                        test_dataset.tokenizer, model.config.max_span_size
-                    )[0]
-                    for doc in test_dataset.documents
-                ]
-
-                metrics = score_coref_predictions(preds, refs)
-                conll_f1 = mean(
-                    [metrics["MUC"]["f1"], metrics["B3"]["f1"], metrics["CEAF"]["f1"]]
-                )
-                if _run:
-                    _run.log_scalar("muc_precision", metrics["MUC"]["precision"])
-                    _run.log_scalar("muc_recall", metrics["MUC"]["recall"])
-                    _run.log_scalar("muc_f1", metrics["MUC"]["f1"])
-                    _run.log_scalar("b3_precision", metrics["B3"]["precision"])
-                    _run.log_scalar("b3_recall", metrics["B3"]["recall"])
-                    _run.log_scalar("b3_f1", metrics["B3"]["f1"])
-                    _run.log_scalar("ceaf_precision", metrics["CEAF"]["precision"])
-                    _run.log_scalar("ceaf_recall", metrics["CEAF"]["recall"])
-                    _run.log_scalar("ceaf_f1", metrics["CEAF"]["f1"])
-                    _run.log_scalar("conll_f1", conll_f1)
-                print(metrics)
-
-                m_precision, m_recall, m_f1 = score_mention_detection(preds, refs)
-                if _run:
-                    _run.log_scalar("mention_detection_precision", m_precision)
-                    _run.log_scalar("mention_detection_recall", m_recall)
-                    _run.log_scalar("mention_detection_f1", m_f1)
-                print(
-                    f"mention detection metrics: (precision: {m_precision}, recall: {m_recall}, f1: {m_f1})"
-                )
+        # Metrics Computation
+        # -------------------
+        preds = predict_coref(
+            [doc.tokens for doc in test_dataset.documents],
+            model,
+            tokenizer,
+            batch_size=batch_size,
+            device_str=device_str,
+        )
+        metrics = score_coref_predictions(preds, test_dataset.documents)
 
-            except Exception as e:
-                print(e)
-                traceback.print_exc()
-                conll_f1 = 0
+        conll_f1 = mean(
+            [metrics["MUC"]["f1"], metrics["B3"]["f1"], metrics["CEAF"]["f1"]]
+        )
+        if _run:
+            _run.log_scalar("muc_precision", metrics["MUC"]["precision"])
+            _run.log_scalar("muc_recall", metrics["MUC"]["recall"])
+            _run.log_scalar("muc_f1", metrics["MUC"]["f1"])
+            _run.log_scalar("b3_precision", metrics["B3"]["precision"])
+            _run.log_scalar("b3_recall", metrics["B3"]["recall"])
+            _run.log_scalar("b3_f1", metrics["B3"]["f1"])
+            _run.log_scalar("ceaf_precision", metrics["CEAF"]["precision"])
+            _run.log_scalar("ceaf_recall", metrics["CEAF"]["recall"])
+            _run.log_scalar("ceaf_f1", metrics["CEAF"]["f1"])
+            _run.log_scalar("conll_f1", conll_f1)
+        print(metrics)
+
+        m_precision, m_recall, m_f1 = score_mention_detection(
+            preds, test_dataset.documents
+        )
+        if _run:
+            _run.log_scalar("mention_detection_precision", m_precision)
+            _run.log_scalar("mention_detection_recall", m_recall)
+            _run.log_scalar("mention_detection_f1", m_f1)
+        print(
+            f"mention detection metrics: (precision: {m_precision}, recall: {m_recall}, f1: {m_f1})"
+        )
 
-            if conll_f1 > best_f1 or best_f1 == 0:
-                best_model = copy.deepcopy(model).to("cpu")
-                if not model_save_path is None:
-                    best_model.save_pretrained(model_save_path)
-                best_f1 = conll_f1
+        if conll_f1 > best_f1 or best_f1 == 0:
+            best_model = copy.deepcopy(model).to("cpu")
+            if not model_save_path is None:
+                best_model.save_pretrained(model_save_path)
+            best_f1 = conll_f1
 
     return best_model
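
Below is a minimal usage sketch of the standalone predict_coref entry point that the evaluation loop now delegates to. Only the call shape (a list of token lists, then model, tokenizer, batch_size, device_str) is taken from the diff above; the checkpoint names, the example document, and printing the raw predictions are illustrative assumptions, not part of this change.

# Hypothetical inference sketch (not part of this commit). The argument order and
# the batch_size/device_str keywords mirror the predict_coref call added above;
# checkpoint names and the example tokens are assumptions.
from transformers import BertTokenizerFast

from tibert import BertForCoreferenceResolution
from tibert.predict import predict_coref

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
model = BertForCoreferenceResolution.from_pretrained("bert-base-cased")

# Each document is a plain list of tokens, matching [doc.tokens for doc in ...] above.
documents = [["Jane", "told", "Mary", "that", "she", "was", "late", "."]]

preds = predict_coref(documents, model, tokenizer, batch_size=1, device_str="cpu")
print(preds)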