Commit 2a78a4d

Add test_train
1 parent 06c0bcd

1 file changed (+40, -0)

tests/test_train.py

@@ -0,0 +1,40 @@
import tempfile
import torch
from transformers import BertTokenizerFast
from tibert.bertcoref import BertForCoreferenceResolutionConfig
from tibert.train import _save_train_checkpoint, load_train_checkpoint
from tibert import BertForCoreferenceResolution, predict_coref_simple


def test_save_load_checkpoint():
    model = BertForCoreferenceResolution(BertForCoreferenceResolutionConfig())
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

    bert_lr = 1e-5
    task_lr = 2e-4
    # Two parameter groups: a smaller learning rate for the BERT encoder, a
    # larger one for the task-specific coreference layers.
    optimizer = torch.optim.AdamW(
        [
            {"params": model.bert_parameters(), "lr": bert_lr},
            {
                "params": model.task_parameters(),
                "lr": task_lr,
            },
        ],
        lr=task_lr,
    )

    text = "Sli did not want the earpods. He didn't like them."
    # Reference prediction taken before checkpointing.
    before_pred = predict_coref_simple(text, model, tokenizer)

    with tempfile.TemporaryDirectory() as d:
        checkpoint_f = f"{d}/checkpoint.pth"
        # Round-trip model and optimizer state through a checkpoint file.
        _save_train_checkpoint(checkpoint_f, model, 1, optimizer, bert_lr, task_lr)
        model, new_optimizer = load_train_checkpoint(
            checkpoint_f, BertForCoreferenceResolution
        )

        # The reloaded optimizer must carry exactly the same state...
        assert new_optimizer.state_dict() == optimizer.state_dict()

        # ...and the reloaded model must predict exactly as before saving.
        after_pred = predict_coref_simple(text, model, tokenizer)
        assert before_pred == after_pred
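
For reference, a minimal sketch of reusing these helpers outside the test, assuming a checkpoint written by _save_train_checkpoint already exists on disk; the "my_run/checkpoint.pth" path below is a placeholder, not part of this commit. load_train_checkpoint returns the (model, optimizer) pair in the order shown in the test above:

from transformers import BertTokenizerFast
from tibert import BertForCoreferenceResolution, predict_coref_simple
from tibert.train import load_train_checkpoint

# Placeholder path: point this at a checkpoint produced by _save_train_checkpoint.
model, optimizer = load_train_checkpoint(
    "my_run/checkpoint.pth", BertForCoreferenceResolution
)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

# The restored model can be used for inference right away.
print(
    predict_coref_simple(
        "Sli did not want the earpods. He didn't like them.", model, tokenizer
    )
)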
