Skip to content

Commit 753f781

Browse files
committed
Refactor evaluation classes and improve type annotations
- Enhanced type hints in RussianMathEval, RussianPhysicsEval, and MathDemonEval classes for better clarity and type safety. - Added detailed docstrings to methods and classes to improve code documentation and understanding. - Updated the SamplerBase and Eval protocols to include return types and method descriptions. - Improved handling of optional parameters and default values across various classes. - Cleaned up code formatting and removed unnecessary comments for better readability.
1 parent fee9123 commit 753f781

File tree

8 files changed

+1312
-894
lines changed

8 files changed

+1312
-894
lines changed

requirements.txt

-1.12 KB
Binary file not shown.

runner.py

Lines changed: 91 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import yaml
2+
from typing import Dict, Any, Optional
23
from src.equality_checker import MathEqualityChecker
34
from src.leaderboard import Leaderboard
45
import argparse
@@ -7,13 +8,20 @@
78
import os
89
import shutil
910

10-
def main():
11+
12+
def main() -> None:
13+
"""
14+
Основная функция приложения для оценки языковых моделей на математических и физических задачах.
15+
16+
Обрабатывает аргументы командной строки, запускает оценку моделей на выбранных датасетах
17+
и выводит результаты в виде отформатированной таблицы.
18+
"""
1119
# Установка кодировки вывода в UTF-8
12-
if sys.platform == 'win32':
13-
os.system('chcp 65001') # Установка кодировки UTF-8 для Windows консоли
14-
20+
if sys.platform == "win32":
21+
os.system("chcp 65001") # Установка кодировки UTF-8 для Windows консоли
22+
1523
parser = argparse.ArgumentParser(
16-
description='Оценка языковых моделей на математических и физических задачах',
24+
description="Оценка языковых моделей на математических и физических задачах",
1725
formatter_class=argparse.RawDescriptionHelpFormatter,
1826
epilog="""
1927
Примеры использования:
@@ -22,111 +30,138 @@ def main():
2230
python runner.py --dataset physics # Запустить только на датасете RussianPhysics
2331
python runner.py --no-cache # Игнорировать кэш и переоценить все модели
2432
python runner.py --max-workers 8 # Использовать 8 параллельных потоков
25-
"""
33+
""",
34+
)
35+
parser.add_argument(
36+
"--max-workers",
37+
type=int,
38+
default=4,
39+
help="Максимальное количество параллельных потоков (по умолчанию: 4)",
40+
)
41+
parser.add_argument(
42+
"--config",
43+
type=str,
44+
default="configs/run.yaml",
45+
help="Путь к файлу конфигурации (по умолчанию: configs/run.yaml)",
46+
)
47+
parser.add_argument(
48+
"--no-cache",
49+
action="store_true",
50+
help="Игнорировать кэш и переоценить все модели",
51+
)
52+
parser.add_argument(
53+
"--dataset",
54+
choices=["all", "russianmath", "physics"],
55+
default="all",
56+
help="Выбор датасета для оценки: all (все), russianmath, physics (по умолчанию: all)",
2657
)
27-
parser.add_argument('--max-workers', type=int, default=4,
28-
help='Максимальное количество параллельных потоков (по умолчанию: 4)')
29-
parser.add_argument('--config', type=str, default='configs/run.yaml',
30-
help='Путь к файлу конфигурации (по умолчанию: configs/run.yaml)')
31-
parser.add_argument('--no-cache', action='store_true',
32-
help='Игнорировать кэш и переоценить все модели')
33-
parser.add_argument('--dataset', choices=['all', 'russianmath', 'physics'], default='all',
34-
help='Выбор датасета для оценки: all (все), russianmath, physics (по умолчанию: all)')
3558
args = parser.parse_args()
3659

3760
# Загружаем конфиг
38-
with open(args.config, 'r', encoding='utf-8') as f:
39-
config = yaml.safe_load(f)
40-
61+
with open(args.config, "r", encoding="utf-8") as f:
62+
config: Dict[str, Any] = yaml.safe_load(f)
63+
4164
# Если указан --no-cache, отключаем использование кэша
4265
if args.no_cache:
43-
# Очищаем директорию с кэшем
4466
cache_dir = Path("results/cache")
4567
if cache_dir.exists():
4668
for cache_file in cache_dir.glob("*.json"):
4769
cache_file.unlink()
4870

49-
# Создаем equality checker
71+
# Создаем equality checker для проверки равенства математических выражений
5072
equality_checker = MathEqualityChecker()
5173

52-
# Создаем и обновляем лидерборд
74+
# Создаем и инициализируем лидерборд
5375
leaderboard = Leaderboard(args.config, max_workers=args.max_workers)
54-
55-
# Определяем system prompts для моделей
56-
system_prompts = {
57-
model: config.get(model, {}).get('system_prompt')
58-
for model in config['model_list']
76+
77+
# Определяем системные промпты для каждой модели из конфига
78+
system_prompts: Dict[str, Optional[str]] = {
79+
model: config.get(model, {}).get("system_prompt")
80+
for model in config["model_list"]
5981
}
60-
82+
6183
# Запуск оценки в зависимости от выбранного датасета
62-
if args.dataset == 'all' or args.dataset == 'russianmath':
84+
if args.dataset == "all" or args.dataset == "russianmath":
6385
print("\nЗапуск оценки на датасете RussianMath")
6486
leaderboard.evaluate_all_models(system_prompts)
65-
66-
if args.dataset == 'all' or args.dataset == 'physics':
87+
88+
if args.dataset == "all" or args.dataset == "physics":
6789
print("\nЗапуск оценки на датасете RussianPhysics")
6890
leaderboard.evaluate_physics_models(system_prompts)
69-
91+
7092
# Вычисляем общий скор для моделей (полусумма по обоим датасетам)
71-
if args.dataset == 'all':
93+
if args.dataset == "all":
7294
print("\nВычисление общего скора по всем датасетам")
7395
leaderboard.calculate_combined_scores()
74-
75-
# Получаем ширину терминала
96+
97+
# Получаем ширину терминала для форматирования вывода
7698
terminal_width = shutil.get_terminal_size().columns
77-
78-
# Генерируем и выводим лидерборд с корректным форматированием
99+
100+
# Выводим красивый заголовок лидерборда
79101
header = " LEADERBOARD "
80102
padding = "=" * ((terminal_width - len(header)) // 2)
81103
print(f"\n{padding}{header}{padding}")
82-
83-
# Генерируем markdown таблицу
104+
105+
# Генерируем markdown таблицу с результатами
84106
md = leaderboard.generate_markdown()
85-
86-
# Выводим красиво форматированную таблицу
87-
lines = md.split('\n')
88-
table_lines = [line for line in lines if line.startswith('|')]
89-
107+
108+
# Форматируем и выводим таблицу в терминал
109+
lines = md.split("\n")
110+
table_lines = [line for line in lines if line.startswith("|")]
111+
90112
if len(table_lines) >= 2: # Есть заголовок и разделитель
91113
header_line = table_lines[0]
92114
separator_line = table_lines[1]
93115
data_lines = table_lines[2:] if len(table_lines) > 2 else []
94-
116+
95117
# Анализируем ширину каждого столбца из заголовка
96-
columns = header_line.split('|')
118+
columns = header_line.split("|")
97119
columns = [col.strip() for col in columns if col] # Убираем пустые элементы
98-
120+
99121
# Находим максимальную ширину для каждого столбца
100122
column_widths = [len(col) for col in columns]
101-
123+
102124
# Учитываем ширину данных в каждой строке
103125
for line in data_lines:
104-
cells = line.split('|')
126+
cells = line.split("|")
105127
cells = [cell.strip() for cell in cells if cell]
106128
for i, cell in enumerate(cells):
107129
if i < len(column_widths):
108130
column_widths[i] = max(column_widths[i], len(cell))
109-
131+
110132
# Форматируем и выводим заголовок
111-
formatted_header = '| ' + ' | '.join(f"{col:<{column_widths[i]}}" for i, col in enumerate(columns)) + ' |'
133+
formatted_header = (
134+
"| "
135+
+ " | ".join(f"{col:<{column_widths[i]}}" for i, col in enumerate(columns))
136+
+ " |"
137+
)
112138
print(f"\n{formatted_header}")
113-
139+
114140
# Форматируем и выводим разделитель
115-
formatted_separator = '|-' + '-|-'.join('-' * width for width in column_widths) + '-|'
141+
formatted_separator = (
142+
"|-" + "-|-".join("-" * width for width in column_widths) + "-|"
143+
)
116144
print(formatted_separator)
117-
145+
118146
# Форматируем и выводим данные
119147
for line in data_lines:
120-
cells = line.split('|')
148+
cells = line.split("|")
121149
cells = [cell.strip() for cell in cells if cell]
122-
formatted_line = '| ' + ' | '.join(f"{cell:<{column_widths[i]}}" for i, cell in enumerate(cells)) + ' |'
150+
formatted_line = (
151+
"| "
152+
+ " | ".join(
153+
f"{cell:<{column_widths[i]}}" for i, cell in enumerate(cells)
154+
)
155+
+ " |"
156+
)
123157
print(formatted_line)
124158
else:
125-
# Если не смогли обработать таблицу, просто выводим как есть
159+
# Если не смогли разобрать таблицу, выводим строки как есть
126160
for line in table_lines:
127161
print(line)
128162

129163
print(f"\nДетальные результаты сохранены в: {leaderboard.output_dir}")
130164

165+
131166
if __name__ == "__main__":
132-
main()
167+
main()

0 commit comments

Comments
 (0)