|
| 1 | +import tempfile |
| 2 | +import torch |
| 3 | +from torch.optim import optimizer |
| 4 | +from transformers import BertTokenizerFast |
| 5 | +from tibert.bertcoref import BertForCoreferenceResolutionConfig |
| 6 | +from tibert.train import _save_train_checkpoint, load_train_checkpoint |
| 7 | +from tibert import BertForCoreferenceResolution, predict_coref_simple |
| 8 | + |
| 9 | + |
def test_save_load_checkpoint():
    """Check that a training checkpoint round-trips through save/load.

    Builds a fresh coreference model plus an AdamW optimizer with separate
    learning rates for the BERT encoder and the task head, saves both to a
    temporary checkpoint file via ``_save_train_checkpoint``, reloads them
    with ``load_train_checkpoint``, and asserts that (a) the optimizer state
    dict is preserved exactly and (b) the reloaded model produces the same
    predictions as before saving.
    """
    model = BertForCoreferenceResolution(BertForCoreferenceResolutionConfig())
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

    bert_lr = 1e-5
    task_lr = 2e-4
    # Named `optim` (not `optimizer`) to avoid shadowing the
    # `torch.optim.optimizer` module imported at the top of the file.
    optim = torch.optim.AdamW(
        [
            {"params": model.bert_parameters(), "lr": bert_lr},
            {
                "params": model.task_parameters(),
                "lr": task_lr,
            },
        ],
        lr=task_lr,
    )

    text = "Sli did not want the earpods. He didn't like them."
    before_pred = predict_coref_simple(text, model, tokenizer)

    with tempfile.TemporaryDirectory() as d:
        checkpoint_f = f"{d}/checkpoint.pth"
        _save_train_checkpoint(checkpoint_f, model, 1, optim, bert_lr, task_lr)
        model, new_optim = load_train_checkpoint(
            checkpoint_f, BertForCoreferenceResolution
        )

        # Optimizer hyperparameters and state must survive the round-trip.
        assert new_optim.state_dict() == optim.state_dict()

        # The reloaded model must behave identically to the saved one.
        after_pred = predict_coref_simple(text, model, tokenizer)
        assert before_pred == after_pred
0 commit comments