Commit 0ba18bf

Set train mode after evaluation in training (#311)
1 parent 95f3d64 commit 0ba18bf

3 files changed, 20 additions and 8 deletions

examples/bert/README.md

Lines changed: 18 additions & 8 deletions
@@ -120,14 +120,24 @@ After convergence, the evaluation performance is around the following. Due to ce
 initialization of the classification layer), the evaluation accuracy is reasonable as long as it's `>0.84`.
 An example output is as follows:
 ```
-Using cached pre-trained BERT checkpoint from texar_download/BERT/bert-base-uncased.
-INFO:root:step: 50; loss: 0.448667
-INFO:root:step: 100; loss: 0.400019
-INFO:root:step: 150; loss: 0.340314
-INFO:root:step: 200; loss: 0.151271
-INFO:root:step: 250; loss: 0.093740
-INFO:root:step: 300; loss: 0.161118
-INFO:root:eval accu: 0.8554; loss: 0.4524; nsamples: 408
+Using cached pre-trained BERT checkpoint from /home/centos/texar_data/BERT/bert-base-uncased.
+INFO:root:step: 50; loss: 0.646327
+INFO:root:step: 100; loss: 0.281063
+INFO:root:eval accu: 0.8260; loss: 0.4123; nsamples: 408
+INFO:root:step: 150; loss: 0.231236
+INFO:root:step: 200; loss: 0.175780
+INFO:root:eval accu: 0.8431; loss: 0.4503; nsamples: 408
+INFO:root:step: 250; loss: 0.077983
+INFO:root:step: 300; loss: 0.009281
+INFO:root:eval accu: 0.8578; loss: 0.5342; nsamples: 408
+INFO:root:step: 350; loss: 0.021876
+INFO:root:step: 400; loss: 0.005707
+INFO:root:eval accu: 0.8676; loss: 0.5084; nsamples: 408
+INFO:root:step: 450; loss: 0.003567
+INFO:root:step: 500; loss: 0.034953
+INFO:root:eval accu: 0.8701; loss: 0.4743; nsamples: 408
+INFO:root:step: 550; loss: 0.008626
+INFO:root:eval accu: 0.8627; loss: 0.5593; nsamples: 408
 ```

 ### Restore and Test

examples/bert/bert_classifier_main.py

Lines changed: 1 addition & 0 deletions
@@ -162,6 +162,7 @@ def _train_epoch():
             eval_steps = config_data.eval_steps
             if eval_steps > 0 and step % eval_steps == 0:
                 _eval_epoch()
+                model.train()

     @torch.no_grad()
     def _eval_epoch():

examples/gpt-2/gpt2_train_main.py

Lines changed: 1 addition & 0 deletions
@@ -157,6 +157,7 @@ def _train_epoch():

             if eval_steps > 0 and step % eval_steps == 0:
                 _eval_epoch()
+                model.train()

             step += 1
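
Why the added `model.train()` matters can be sketched in a few lines. This is not the code from either example script: the tiny `nn.Sequential` model and the helper names `eval_epoch` and `train_epoch` are hypothetical, and the sketch assumes, as is conventional, that the evaluation routine calls `model.eval()`. Since `eval()` flips a persistent flag on the module, every training step after the first mid-epoch evaluation would otherwise run with dropout disabled.

```python
import torch
from torch import nn

# Hypothetical stand-in for the classifier; only the Dropout layer matters here.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 2))


@torch.no_grad()
def eval_epoch():
    # Assumption: evaluation switches the module to eval mode.
    # The flag persists after this function returns.
    model.eval()


def train_epoch(num_steps, eval_steps, restore_train_mode):
    """Return the value of model.training seen at each training step."""
    modes = []
    model.train()
    for step in range(1, num_steps + 1):
        # A real loop would run the forward and backward pass here.
        modes.append(model.training)
        if eval_steps > 0 and step % eval_steps == 0:
            eval_epoch()
            if restore_train_mode:
                model.train()  # the line this commit adds
    return modes


modes_without_fix = train_epoch(num_steps=4, eval_steps=2, restore_train_mode=False)
modes_with_fix = train_epoch(num_steps=4, eval_steps=2, restore_train_mode=True)

print(modes_without_fix)  # [True, True, False, False]
print(modes_with_fix)     # [True, True, True, True]
```

Without the fix, steps 3 and 4 train in eval mode. With it, every step sees `model.training == True`.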
