Skip to content

Commit 050a7b0

Browse files
authored
Update train_eval.py
fix a bug
1 parent 76e585c commit 050a7b0

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

train_eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sklearn import metrics
77
import time
88
from utils import get_time_dif
9-
from pytorch_pretrained_bert.optimization import BertAdam
9+
from pytorch_pretrained.optimization import BertAdam
1010

1111

1212
# 权重初始化,默认xavier
@@ -118,4 +118,4 @@ def evaluate(config, model, data_iter, test=False):
118118
report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
119119
confusion = metrics.confusion_matrix(labels_all, predict_all)
120120
return acc, loss_total / len(data_iter), report, confusion
121-
return acc, loss_total / len(data_iter)
121+
return acc, loss_total / len(data_iter)

0 commit comments

Comments
 (0)