Skip to content

Commit d8c0754

Browse files
committed
Support CLI-based Full Evaluation
1 parent 85dd893 commit d8c0754

File tree

4 files changed

+104
-10
lines changed

4 files changed

+104
-10
lines changed

evaluation/classification/eval_classification.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Modified from: https://github.com/mrpeerat/Thai-Sentence-Vector-Benchmark/blob/main/Transfer_Evaluation/transfer.py
22

33
from dataclasses import dataclass
4+
import json
5+
import os
46

57
from datargs import parse
68
from datasets import load_dataset
@@ -19,9 +21,12 @@ class Args:
1921
text_column: str = "tweet"
2022
label_column: str = "label"
2123
encode_batch_size: int = 128
24+
output_folder: str = "results"
2225

2326

2427
def main(args: Args):
28+
os.makedirs(args.output_folder, exist_ok=True)
29+
2530
model = SentenceTransformer(args.model_name)
2631

2732
dataset = load_dataset(args.dataset_name, args.dataset_config)
@@ -46,9 +51,7 @@ def main(args: Args):
4651
predictions = classifier.predict(test_text_encoded)
4752

4853
acc = accuracy_score(test_ds[args.label_column], predictions)
49-
precision, recall, f1, _ = precision_recall_fscore_support(
50-
test_ds[args.label_column], predictions, average="macro"
51-
)
54+
precision, recall, f1, _ = precision_recall_fscore_support(test_ds[args.label_column], predictions, average="macro")
5255

5356
results = {
5457
"accuracy": acc,
@@ -57,7 +60,9 @@ def main(args: Args):
5760
"f1": f1,
5861
}
5962

60-
print(results)
63+
task_name = f"{args.dataset_name.split('/')[-1]}_{args.dataset_config}"
64+
with open(f"{args.output_folder}/{task_name}_{args.test_split_name}.json", "w") as f:
65+
json.dump(results, f, indent=4)
6166

6267

6368
if __name__ == "__main__":

evaluation/pair_classification/eval_pair_classification.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Modified from: https://github.com/embeddings-benchmark/mteb/blob/main/mteb/evaluation/evaluators/PairClassificationEvaluator.py
22

33
from dataclasses import dataclass
4+
import json
5+
import os
46

57
import numpy as np
68
from datargs import parse
@@ -22,6 +24,7 @@ class Args:
2224
neutral_label: int = 1
2325
contradiction_label: int = 2
2426
encode_batch_size: int = 128
27+
output_folder: str = "results"
2528

2629

2730
def compute_metrics(model, sentences_1, sentences_2, labels, batch_size):
@@ -136,6 +139,8 @@ def ap_score(scores, labels, high_score_more_similar: bool):
136139

137140

138141
def main(args: Args):
142+
os.makedirs(args.output_folder, exist_ok=True)
143+
139144
model = SentenceTransformer(args.model_name)
140145

141146
test_ds = load_dataset(args.dataset_name, split=args.test_split_name, trust_remote_code=True)
@@ -157,7 +162,8 @@ def main(args: Args):
157162
main_score = max(scores[short_name]["ap"] for short_name in scores)
158163
scores["main_score"] = main_score
159164

160-
print(scores)
165+
with open(f"{args.output_folder}/{args.dataset_name}_{args.test_split_name}.json", "w") as f:
166+
json.dump(scores, f, indent=4)
161167

162168

163169
if __name__ == "__main__":

evaluation/run_evaluation.sh

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
#!/usr/bin/env bash
#
# Run every evaluation task (retrieval, pair classification, classification,
# MTEB, and semantic textual similarity) for a single model.
#
# Usage:
#   ./run_evaluation.sh <model>
#
# <model> is passed verbatim to each evaluator as the model name. Results are
# written under <task>/results/<model_name>, where <model_name> is <model>
# with everything up to and including the first "/" removed
# (e.g. "org/name" -> "name").
#
# NOTE(review): all script and output paths below are relative, so this
# presumably has to be launched from the directory that contains the
# retrieval/, pair_classification/, classification/ and sts/ folders -- confirm.

# Fail early with a usage message instead of letting an empty model name make
# `--model-name` swallow the next flag as its value.
if [ "$#" -lt 1 ] || [ -z "$1" ]; then
    echo "Usage: $0 <model>" >&2
    exit 1
fi

model=$1
# Strip the shortest prefix ending in "/" so results are grouped by the bare
# model name rather than by "namespace/name".
model_name="${model#*/}"

###############################
# RETRIEVAL
###############################

python retrieval/eval_tydiqa.py \
    --model-name "$model" \
    --test-dataset-name khalidalt/tydiqa-goldp \
    --test-dataset-config indonesian \
    --test-dataset-split validation \
    --test-batch-size 32 \
    --output-folder "retrieval/results/$model_name"

python retrieval/eval_miracl.py \
    --model-name "$model" \
    --test-dataset-name miracl/miracl \
    --test-dataset-config id \
    --test-dataset-split dev \
    --test-batch-size 32 \
    --output-folder "retrieval/results/$model_name"

###############################
# PAIR CLASSIFICATION
###############################

# The same dataset is evaluated once per listed test split.
for split in test_lay test_expert
do
    python pair_classification/eval_pair_classification.py \
        --model-name "$model" \
        --dataset-name indonli \
        --test-split-name "$split" \
        --text-column-1 premise \
        --text-column-2 hypothesis \
        --label-column label \
        --output-folder "pair_classification/results/$model_name"
done

###############################
# CLASSIFICATION
###############################

python classification/eval_classification.py \
    --model-name "$model" \
    --dataset-name indonlp/indonlu \
    --dataset-config emot \
    --train-split-name train \
    --test-split-name test \
    --text-column tweet \
    --label-column label \
    --output-folder "classification/results/$model_name"

python classification/eval_classification.py \
    --model-name "$model" \
    --dataset-name indonlp/indonlu \
    --dataset-config smsa \
    --train-split-name train \
    --test-split-name test \
    --text-column text \
    --label-column label \
    --output-folder "classification/results/$model_name"

# MTEB command-line run for the same model.
# NOTE(review): `-l id` presumably restricts the run to Indonesian tasks --
# confirm against the installed mteb CLI version.
mteb \
    -m "$model" \
    -l id \
    --output_folder "mteb/results/$model_name"

###############################
# SEMANTIC TEXTUAL SIMILARITY
###############################

python sts/eval_sts.py \
    --model-name "$model" \
    --test-dataset-name LazarusNLP/stsb_mt_id \
    --test-dataset-split test \
    --test-text-column-1 text_1 \
    --test-text-column-2 text_2 \
    --test-label-column correlation \
    --test-batch-size 32 \
    --output-folder "sts/results/$model_name"

evaluation/sts/eval_sts.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from dataclasses import dataclass
2+
import os
23

34
from datargs import parse
45
from datasets import load_dataset
@@ -15,9 +16,12 @@ class Args:
1516
test_text_column_2: str = "text_2"
1617
test_label_column: str = "correlation"
1718
test_batch_size: int = 32
19+
output_folder: str = "results"
1820

1921

2022
def main(args: Args):
23+
os.makedirs(args.output_folder, exist_ok=True)
24+
2125
model = SentenceTransformer(args.model_name)
2226

2327
# Load dataset
@@ -31,11 +35,8 @@ def main(args: Args):
3135
for data in test_ds
3236
]
3337

34-
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(
35-
test_data, batch_size=args.test_batch_size
36-
)
37-
38-
print(evaluator(model))
38+
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(test_data, batch_size=args.test_batch_size)
39+
evaluator(model, output_path=args.output_folder)
3940

4041

4142
if __name__ == "__main__":

0 commit comments

Comments
 (0)