1- from typing import Optional , Tuple , Type , Union , Literal
1+ import pickle
2+ from typing import List , Optional , Tuple , Type , Union , Literal
23import traceback , copy , os
34from statistics import mean
45import torch
910 BertForCoreferenceResolution ,
1011 CamembertForCoreferenceResolution ,
1112 CoreferenceDataset ,
13+ CoreferenceDocument ,
1214 DataCollatorForSpanClassification ,
1315)
1416from tibert .score import score_coref_predictions , score_mention_detection
15- from tibert .predict import predict_coref
17+ from tibert .predict import predict_coref , predict_coref_simple
1618from tibert .utils import gpu_memory_usage
1719
1820
@@ -78,6 +80,30 @@ def _optimizer_to_(
7880 return optimizer
7981
8082
def _save_append_example_pred(
    path: str,
    model: Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
    tokenizer: Union[BertTokenizerFast, CamembertTokenizerFast],
    example_doc: CoreferenceDocument,
):
    """
    Save an example and its prediction to a file, keeping previous
    predictions.  Useful to follow the evolution of predictions for an
    example.

    The file is a pickled dict with two keys: ``"ref"`` (the document
    stored when the file is first created) and ``"preds"`` (the list
    of predictions, in the order they were appended).

    :param path: path of the pickle file to create or update
    :param model: model used to compute the new prediction
    :param tokenizer: tokenizer passed along with ``model`` to
        ``predict_coref_simple``
    :param example_doc: document whose tokens are fed to the model.
        It is stored under ``"ref"`` only when the file does not
        exist yet.
    """
    pred = predict_coref_simple(example_doc.tokens, model, tokenizer)

    try:
        # NOTE: unpickling executes arbitrary code, so ``path`` must
        # only ever point to a file produced by this function.
        with open(path, "rb") as f:
            ex_dict = pickle.load(f)
    except FileNotFoundError:
        # First call: start a fresh history for this example.
        ex_dict = {"ref": example_doc, "preds": [pred]}
    else:
        ex_dict["preds"] = ex_dict.get("preds", []) + [pred]

    # Write to a sibling temporary file, then swap it in place.  If
    # the process is interrupted while dumping, the previous history
    # file is left intact instead of being truncated.
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(ex_dict, f)
    os.replace(tmp_path, path)
105+
106+
81107def train_coref_model (
82108 model : Union [BertForCoreferenceResolution , CamembertForCoreferenceResolution ],
83109 train_dataset : CoreferenceDataset ,
@@ -91,6 +117,7 @@ def train_coref_model(
91117 device_str : Literal ["cpu" , "cuda" , "auto" ] = "auto" ,
92118 _run : Optional ["sacred.run.Run" ] = None ,
93119 optimizer : Optional [torch .optim .AdamW ] = None ,
120+ example_tracking_path : Optional [str ] = None ,
94121) -> BertForCoreferenceResolution :
95122 """
96123 :param model: model to train
@@ -109,6 +136,9 @@ def train_coref_model(
109136 :param _run: sacred run, used to log metrics
110137 :param optimizer: a torch optimizer to use. Can be useful to
111138 resume training.
139+ :param example_tracking_path: if given, path to a file where an
140+ example and its prediction will be dumped each epoch. Useful
141+ to track the evolution of predictions.
112142
113143 :return: the best trained model, according to CoNLL-F1 on the test
114144 set
@@ -218,6 +248,13 @@ def train_coref_model(
218248 f"mention detection metrics: (precision: { m_precision } , recall: { m_recall } , f1: { m_f1 } )"
219249 )
220250
251+ # Example evolution tracking
252+ # --------------------------
253+ if not example_tracking_path is None :
254+ _save_append_example_pred (
255+ example_tracking_path , model , tokenizer , test_dataset .documents [1 ]
256+ )
257+
221258 # Model saving
222259 # ------------
223260 if not model_save_dir is None :
0 commit comments