|
| 1 | +import json |
| 2 | +import numpy as np |
| 3 | +import os |
| 4 | + |
| 5 | + |
| 6 | +def compute_overall_accuracy(output_path, model_name, prompt_style): |
| 7 | + category_accuracy = {} |
| 8 | + |
| 9 | + with open(f"outputs/{output_path}") as file: |
| 10 | + for line in file: |
| 11 | + data = json.loads(line) |
| 12 | + |
| 13 | + category = data["Category"] |
| 14 | + |
| 15 | + if category not in category_accuracy: |
| 16 | + category_accuracy[category] = [] |
| 17 | + |
| 18 | + if data["Result"] == "Correct": |
| 19 | + category_accuracy[category].append(1) |
| 20 | + else: |
| 21 | + category_accuracy[category].append(0) |
| 22 | + |
| 23 | + # Compute average and standard deviation for each category |
| 24 | + category_stats = {} |
| 25 | + all_results = [] |
| 26 | + |
| 27 | + for cat, results in category_accuracy.items(): |
| 28 | + results_array = np.array(results) |
| 29 | + category_mean = np.mean(results_array) |
| 30 | + category_std = round(np.sqrt(category_mean * (1-category_mean) / len(results_array)), 2) |
| 31 | + category_stats[cat] = { |
| 32 | + "average": round(category_mean * 100, 2), |
| 33 | + "std": category_std |
| 34 | + } |
| 35 | + all_results.extend(results) |
| 36 | + |
| 37 | + # Compute overall average and standard deviation |
| 38 | + all_results_array = np.array(all_results) |
| 39 | + overall_average = np.mean(all_results_array) |
| 40 | + overall_std = round(np.sqrt(overall_average * (1-overall_average) / 1047), 2) |
| 41 | + |
| 42 | + category_stats["overall"] = { |
| 43 | + "average": round(overall_average * 100, 2), |
| 44 | + "std": overall_std |
| 45 | + } |
| 46 | + |
| 47 | + if not os.path.exists("results"): |
| 48 | + os.makedirs("results") |
| 49 | + |
| 50 | + if "/" in model_name: |
| 51 | + model_name = model_name.split('/')[1] |
| 52 | + |
| 53 | + with open(f"results/results_{model_name}_{prompt_style}.json", "w") as file: |
| 54 | + json.dump(category_stats, file, indent=4) |
| 55 | + |
| 56 | + return category_stats |
| 57 | + |
0 commit comments