Skip to content

Commit 65dd4bd

Browse files
committed
fixed some problem detected by CI
1 parent 14292cc commit 65dd4bd

File tree

1 file changed

+0
-35
lines changed

1 file changed

+0
-35
lines changed

examples/few_shot/pet/evaluate.py

Lines changed: 0 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -87,41 +87,6 @@ def do_evaluate(model, tokenizer, data_loader, label_normalize_dict):
8787
return 100 * correct_num / total_num, total_num
8888

8989

90-
@paddle.no_grad()
91-
def do_evaluate_cluewsc(model, tokenizer, data_loader, label_normalize_dict):
92-
93-
model.eval()
94-
95-
total_num = 0
96-
correct_num = 0
97-
98-
normed_labels = [
99-
normalized_lable
100-
for origin_lable, normalized_lable in label_normalize_dict.items()
101-
]
102-
103-
label_length = len(normed_labels[0])
104-
src_ids, token_type_ids, masked_positions, masked_lm_labels
105-
for batch in data_loader:
106-
src_ids, token_type_ids, label_idx = batch
107-
# [bs * label_length, vocab_size]
108-
prediction_probs = model(
109-
input_ids=src_ids, token_type_ids=token_type_ids)
110-
111-
# Get max probs label's index
112-
y_pred_index = paddle.argmax(prediction_probs, axis=-1).numpy()
113-
y_true_index = []
114-
115-
for label_i in label_idx.numpy():
116-
y_true_index.append(label_i)
117-
y_true_index = np.array(y_true_index)
118-
119-
total_num += len(y_true_index)
120-
correct_num += (y_true_index == y_pred_index).sum()
121-
122-
return 100 * correct_num / total_num, total_num
123-
124-
12590
@paddle.no_grad()
12691
def do_evaluate_chid(model, tokenizer, data_loader, label_normalize_dict):
12792
"""

0 commit comments

Comments
 (0)