-from typing import Optional, Union, Literal
-import traceback, copy
+from typing import Optional, Tuple, Type, Union, Literal
+import traceback, copy, os
 from statistics import mean
 from more_itertools.recipes import flatten
 import torch
 from torch.utils.data.dataloader import DataLoader
 from transformers import BertTokenizerFast, CamembertTokenizerFast  # type: ignore
 from tqdm import tqdm
-from tibert import (
+from tibert.bertcoref import (
     BertForCoreferenceResolution,
     CamembertForCoreferenceResolution,
     CoreferenceDataset,
-    split_coreference_document,
     DataCollatorForSpanClassification,
-    score_coref_predictions,
-    score_mention_detection,
 )
+from tibert.score import score_coref_predictions, score_mention_detection
 from tibert.predict import predict_coref
-from tibert.utils import gpu_memory_usage
+from tibert.utils import gpu_memory_usage, split_coreference_document
+
+
+def _save_train_checkpoint(
+    path: str,
+    model: Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
+    epoch: int,
+    optimizer: torch.optim.AdamW,
+    bert_lr: float,
+    task_lr: float,
+):
+    checkpoint = {
+        "model": model.state_dict(),
+        "model_config": vars(model.config),
+        "epoch": epoch,
+        "optimizer": optimizer.state_dict(),
+        "bert_lr": bert_lr,
+        "task_lr": task_lr,
+    }
+    torch.save(checkpoint, path)
+
+
+def load_train_checkpoint(
+    checkpoint_path: str,
+    model_class: Union[
+        Type[BertForCoreferenceResolution], Type[CamembertForCoreferenceResolution]
+    ],
+) -> Tuple[
+    Union[BertForCoreferenceResolution, CamembertForCoreferenceResolution],
+    torch.optim.AdamW,
+]:
+    config_class = model_class.config_class
+
+    checkpoint = torch.load(checkpoint_path)
+
+    model_config = config_class(**checkpoint["model_config"])
+    model = model_class(model_config)
+    model.load_state_dict(checkpoint["model"])
+
+    optimizer = torch.optim.AdamW(
+        [
+            {"params": model.bert_parameters(), "lr": checkpoint["bert_lr"]},
+            {
+                "params": model.task_parameters(),
+                "lr": checkpoint["task_lr"],
+            },
+        ],
+        lr=checkpoint["task_lr"],
+    )
+    optimizer.load_state_dict(checkpoint["optimizer"])
+
+    return model, optimizer
 
 
 def train_coref_model(
@@ -28,14 +77,41 @@ def train_coref_model(
     sents_per_documents_train: int = 11,
     bert_lr: float = 1e-5,
     task_lr: float = 2e-4,
-    model_save_path: Optional[str] = None,
+    model_save_dir: Optional[str] = None,
     device_str: Literal["cpu", "cuda", "auto"] = "auto",
     _run: Optional["sacred.run.Run"] = None,
+    optimizer: Optional[torch.optim.AdamW] = None,
 ) -> BertForCoreferenceResolution:
+    """
+    :param model: model to train
+    :param dataset: dataset to train on. 90% of that dataset will be
+        used for training, 10% for testing
+    :param tokenizer: tokenizer associated with ``model``
+    :param batch_size: batch size used during training and testing
+    :param epochs_nb: number of epochs to train for
+    :param sents_per_documents_train: max number of sentences in each
+        train document
+    :param bert_lr: learning rate of the BERT encoder
+    :param task_lr: learning rate for the other parts of the network
+    :param model_save_dir: directory in which to save the final model
+        (under 'model') and checkpoints ('checkpoint.pth')
+    :param device_str: device to train on ("cpu", "cuda" or "auto")
+    :param _run: sacred run, used to log metrics
+    :param optimizer: a torch optimizer to use. Can be useful to
+        resume training.
+
+    :return: the best trained model, according to CoNLL-F1 on the test
+        set
+    """
+    # Get torch device and send model to it
+    # -------------------------------------
     if device_str == "auto":
         device_str = "cuda" if torch.cuda.is_available() else "cpu"
     device = torch.device(device_str)
+    model = model.to(device)
 
+    # Prepare datasets
+    # ----------------
     train_dataset = CoreferenceDataset(
         dataset.documents[: int(0.9 * len(dataset))],
         dataset.tokenizer,
@@ -72,23 +148,28 @@ def train_coref_model(
         train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator
     )
 
-    optimizer = torch.optim.AdamW(
-        [
-            {"params": model.bert_parameters(), "lr": bert_lr},
-            {
-                "params": model.task_parameters(),
-                "lr": task_lr,
-            },
-        ],
-        lr=task_lr,
-    )
+    # Optimizer initialization
+    # ------------------------
+    if optimizer is None:
+        optimizer = torch.optim.AdamW(
+            [
+                {"params": model.bert_parameters(), "lr": bert_lr},
+                {
+                    "params": model.task_parameters(),
+                    "lr": task_lr,
+                },
+            ],
+            lr=task_lr,
+        )
 
+    # Best model saving
+    # -----------------
     best_f1 = 0
     best_model = model
 
-    model = model.to(device)
-
-    for _ in range(epochs_nb):
+    # Training loop
+    # -------------
+    for epoch_i in range(epochs_nb):
         model = model.train()
 
         epoch_losses = []
@@ -158,10 +239,22 @@ def train_coref_model(
             f"mention detection metrics: (precision: {m_precision}, recall: {m_recall}, f1: {m_f1})"
         )
 
+        # Model saving
+        # ------------
+        if not model_save_dir is None:
+            os.makedirs(model_save_dir, exist_ok=True)
+            _save_train_checkpoint(
+                os.path.join(model_save_dir, "checkpoint.pth"),
+                model,
+                epoch_i,
+                optimizer,
+                bert_lr,
+                task_lr,
+            )
         if conll_f1 > best_f1 or best_f1 == 0:
             best_model = copy.deepcopy(model).to("cpu")
-            if not model_save_path is None:
-                best_model.save_pretrained(model_save_path)
             best_f1 = conll_f1
+            if not model_save_dir is None:
+                model.save_pretrained(os.path.join(model_save_dir, "model"))
 
     return best_model
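
A minimal sketch of how the checkpoint helpers introduced above could be used to resume an interrupted run. Only load_train_checkpoint, train_coref_model, model_save_dir and optimizer come from this diff; the tibert.train module path, the run directory and the dataset construction are assumptions made for illustration.

from transformers import BertTokenizerFast
from tibert.bertcoref import BertForCoreferenceResolution
from tibert.train import load_train_checkpoint, train_coref_model  # assumed module path

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

# The same CoreferenceDataset used for the original run (construction elided).
dataset = ...

# Restore the model and optimizer state written by _save_train_checkpoint
# at the end of each epoch under model_save_dir.
model, optimizer = load_train_checkpoint(
    "./runs/my-run/checkpoint.pth", BertForCoreferenceResolution
)

# Passing the restored optimizer keeps the saved learning rates and optimizer
# state; new checkpoints keep being written to the same directory.
best_model = train_coref_model(
    model,
    dataset,
    tokenizer,
    model_save_dir="./runs/my-run",
    optimizer=optimizer,
)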