Skip to content

Commit 45b7015

Browse files
committed
Push code
1 parent 09b0d14 commit 45b7015

File tree

2 files changed

+57
-2
lines changed

2 files changed

+57
-2
lines changed
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import json
2+
import numpy as np
3+
import os
4+
5+
6+
def compute_overall_accuracy(output_path, model_name, prompt_style):
7+
category_accuracy = {}
8+
9+
with open(f"outputs/{output_path}") as file:
10+
for line in file:
11+
data = json.loads(line)
12+
13+
category = data["Category"]
14+
15+
if category not in category_accuracy:
16+
category_accuracy[category] = []
17+
18+
if data["Result"] == "Correct":
19+
category_accuracy[category].append(1)
20+
else:
21+
category_accuracy[category].append(0)
22+
23+
# Compute average and standard deviation for each category
24+
category_stats = {}
25+
all_results = []
26+
27+
for cat, results in category_accuracy.items():
28+
results_array = np.array(results)
29+
category_mean = np.mean(results_array)
30+
category_std = round(np.sqrt(category_mean * (1-category_mean) / len(results_array)), 2)
31+
category_stats[cat] = {
32+
"average": round(category_mean * 100, 2),
33+
"std": category_std
34+
}
35+
all_results.extend(results)
36+
37+
# Compute overall average and standard deviation
38+
all_results_array = np.array(all_results)
39+
overall_average = np.mean(all_results_array)
40+
overall_std = round(np.sqrt(overall_average * (1-overall_average) / 1047), 2)
41+
42+
category_stats["overall"] = {
43+
"average": round(overall_average * 100, 2),
44+
"std": overall_std
45+
}
46+
47+
if not os.path.exists("results"):
48+
os.makedirs("results")
49+
50+
if "/" in model_name:
51+
model_name = model_name.split('/')[1]
52+
53+
with open(f"results/results_{model_name}_{prompt_style}.json", "w") as file:
54+
json.dump(category_stats, file, indent=4)
55+
56+
return category_stats
57+

evaluation/table_stats.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@ def compute_overall_accuracy(output_path, model_name, prompt_style):
4747
if not os.path.exists("results"):
4848
os.makedirs("results")
4949

50-
model_name = model_name.replace("-", "_")
51-
5250
if "/" in model_name:
5351
model_name = model_name.split('/')[1]
5452

0 commit comments

Comments
 (0)