File tree Expand file tree Collapse file tree 3 files changed +15
-3
lines changed Expand file tree Collapse file tree 3 files changed +15
-3
lines changed Original file line number Diff line number Diff line change @@ -120,6 +120,7 @@ The following parameters can be set (taken from `./tibert/run_train.py` config f
120120| ` segment_size ` | ` 128 ` |
121121| ` encoder ` | ` "bert-base-cased" ` |
122122| ` out_model_dir ` | ` "~/tibert/model" ` |
123+ | ` checkpoint ` | ` None ` |
123124
124125
125126One can monitor training metrics by adding run observers using command line flags - see ` sacred ` documentation for more details.
Original file line number Diff line number Diff line change 11[tool .poetry ]
22name = " tibert"
3- version = " 0.2.1 "
3+ version = " 0.2.2 "
44description = " BERT for Coreference Resolution"
55authors = [
" Arthur Amalvy <[email protected] >" ]
66license = " GPL-3.0-only"
Original file line number Diff line number Diff line change @@ -51,7 +51,7 @@ def load_train_checkpoint(
5151
5252 model_config = config_class (** checkpoint ["model_config" ])
5353 model = model_class (model_config )
54- model .load_state_dict (checkpoint ["model" ])
54+ model .load_state_dict (checkpoint ["model" ], strict = False )
5555
5656 optimizer = torch .optim .AdamW (
5757 [
@@ -68,6 +68,17 @@ def load_train_checkpoint(
6868 return model , optimizer
6969
7070
71+ def _optimizer_to_ (
72+ optimizer : torch .optim .AdamW , device : torch .device
73+ ) -> torch .optim .AdamW :
74+ """From https://github.com/pytorch/pytorch/issues/2830"""
75+ for state in optimizer .state .values ():
76+ for k , v in state .items ():
77+ if isinstance (v , torch .Tensor ):
78+ state [k ] = v .cuda ()
79+ return optimizer
80+
81+
7182def train_coref_model (
7283 model : Union [BertForCoreferenceResolution , CamembertForCoreferenceResolution ],
7384 dataset : CoreferenceDataset ,
@@ -161,8 +172,8 @@ def train_coref_model(
161172 ],
162173 lr = task_lr ,
163174 )
175+ optimizer = _optimizer_to_ (optimizer , device )
164176
165- # Best model saving
166177 # -----------------
167178 best_f1 = 0
168179 best_model = model
You can’t perform that action at this time.
0 commit comments