Skip to content

Commit ea50451

Browse files
authored
Merge pull request #2 from Technolog796/release
Ready code base for release
2 parents 873c210 + 753f781 commit ea50451

File tree

8 files changed

+1748
-641
lines changed

8 files changed

+1748
-641
lines changed

requirements.txt

-1.12 KB
Binary file not shown.

runner.py

Lines changed: 126 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,98 +1,167 @@
11
import yaml
2+
from typing import Dict, Any, Optional
23
from src.equality_checker import MathEqualityChecker
34
from src.leaderboard import Leaderboard
45
import argparse
56
from pathlib import Path
67
import sys
78
import os
9+
import shutil
810

9-
def main():
11+
12+
def main() -> None:
13+
"""
14+
Основная функция приложения для оценки языковых моделей на математических и физических задачах.
15+
16+
Обрабатывает аргументы командной строки, запускает оценку моделей на выбранных датасетах
17+
и выводит результаты в виде отформатированной таблицы.
18+
"""
1019
# Установка кодировки вывода в UTF-8
11-
if sys.platform == 'win32':
12-
os.system('chcp 65001') # Установка кодировки UTF-8 для Windows консоли
13-
20+
if sys.platform == "win32":
21+
os.system("chcp 65001") # Установка кодировки UTF-8 для Windows консоли
22+
1423
parser = argparse.ArgumentParser(
15-
description='Оценка языковых моделей на математических задачах',
24+
description="Оценка языковых моделей на математических и физических задачах",
1625
formatter_class=argparse.RawDescriptionHelpFormatter,
1726
epilog="""
1827
Примеры использования:
1928
python runner.py # Запустить оценку на всех датасетах (по умолчанию)
2029
python runner.py --dataset russianmath # Запустить только на датасете RussianMath
21-
python runner.py --dataset mathdemon # Запустить только на датасете MathDemon_Demidovich
30+
python runner.py --dataset physics # Запустить только на датасете RussianPhysics
2231
python runner.py --no-cache # Игнорировать кэш и переоценить все модели
2332
python runner.py --max-workers 8 # Использовать 8 параллельных потоков
24-
"""
33+
""",
34+
)
35+
parser.add_argument(
36+
"--max-workers",
37+
type=int,
38+
default=4,
39+
help="Максимальное количество параллельных потоков (по умолчанию: 4)",
40+
)
41+
parser.add_argument(
42+
"--config",
43+
type=str,
44+
default="configs/run.yaml",
45+
help="Путь к файлу конфигурации (по умолчанию: configs/run.yaml)",
46+
)
47+
parser.add_argument(
48+
"--no-cache",
49+
action="store_true",
50+
help="Игнорировать кэш и переоценить все модели",
51+
)
52+
parser.add_argument(
53+
"--dataset",
54+
choices=["all", "russianmath", "physics"],
55+
default="all",
56+
help="Выбор датасета для оценки: all (все), russianmath, physics (по умолчанию: all)",
2557
)
26-
parser.add_argument('--max-workers', type=int, default=4,
27-
help='Максимальное количество параллельных потоков (по умолчанию: 4)')
28-
parser.add_argument('--config', type=str, default='configs/run.yaml',
29-
help='Путь к файлу конфигурации (по умолчанию: configs/run.yaml)')
30-
parser.add_argument('--no-cache', action='store_true',
31-
help='Игнорировать кэш и переоценить все модели')
32-
parser.add_argument('--dataset', choices=['all', 'russianmath', 'mathdemon'], default='all',
33-
help='Выбор датасета для оценки: all (все), russianmath, mathdemon (по умолчанию: all)')
3458
args = parser.parse_args()
3559

3660
# Загружаем конфиг
37-
with open(args.config, 'r', encoding='utf-8') as f:
38-
config = yaml.safe_load(f)
39-
61+
with open(args.config, "r", encoding="utf-8") as f:
62+
config: Dict[str, Any] = yaml.safe_load(f)
63+
4064
# Если указан --no-cache, отключаем использование кэша
4165
if args.no_cache:
42-
# Очищаем директорию с кэшем
4366
cache_dir = Path("results/cache")
4467
if cache_dir.exists():
4568
for cache_file in cache_dir.glob("*.json"):
4669
cache_file.unlink()
4770

48-
# Создаем equality checker
71+
# Создаем equality checker для проверки равенства математических выражений
4972
equality_checker = MathEqualityChecker()
5073

51-
# Создаем и обновляем лидерборд
74+
# Создаем и инициализируем лидерборд
5275
leaderboard = Leaderboard(args.config, max_workers=args.max_workers)
53-
54-
# Определяем system prompts для моделей
55-
system_prompts = {
56-
model: config.get(model, {}).get('system_prompt')
57-
for model in config['model_list']
76+
77+
# Определяем системные промпты для каждой модели из конфига
78+
system_prompts: Dict[str, Optional[str]] = {
79+
model: config.get(model, {}).get("system_prompt")
80+
for model in config["model_list"]
5881
}
59-
82+
6083
# Запуск оценки в зависимости от выбранного датасета
61-
if args.dataset == 'all':
62-
print("\nЗапуск оценки на всех датасетах (RussianMath и MathDemon_Demidovich)")
84+
if args.dataset == "all" or args.dataset == "russianmath":
85+
print("\nЗапуск оценки на датасете RussianMath")
6386
leaderboard.evaluate_all_models(system_prompts)
64-
leaderboard.evaluate_math_demon_subsets()
65-
elif args.dataset == 'russianmath':
66-
print("\nЗапуск оценки только на датасете RussianMath")
67-
leaderboard.evaluate_all_models(system_prompts)
68-
elif args.dataset == 'mathdemon':
69-
print("\nЗапуск оценки только на датасете MathDemon_Dемидович")
70-
leaderboard.evaluate_math_demon_subsets()
71-
72-
# Вычисляем комбинированные оценки для моделей
73-
if args.dataset == 'all':
74-
print("\nВычисление комбинированной оценки по всем датасетам")
87+
88+
if args.dataset == "all" or args.dataset == "physics":
89+
print("\nЗапуск оценки на датасете RussianPhysics")
90+
leaderboard.evaluate_physics_models(system_prompts)
91+
92+
# Вычисляем общий скор для моделей (полусумма по обоим датасетам)
93+
if args.dataset == "all":
94+
print("\nВычисление общего скора по всем датасетам")
7595
leaderboard.calculate_combined_scores()
76-
77-
# Генерируем и выводим лидерборд
78-
TERMINAL_WIDTH = 100 # Примерная ширина терминала
96+
97+
# Получаем ширину терминала для форматирования вывода
98+
terminal_width = shutil.get_terminal_size().columns
99+
100+
# Выводим красивый заголовок лидерборда
79101
header = " LEADERBOARD "
80-
padding = "=" * ((TERMINAL_WIDTH - len(header)) // 2)
81-
print(f"\n{padding}{header}{padding}\n")
82-
102+
padding = "=" * ((terminal_width - len(header)) // 2)
103+
print(f"\n{padding}{header}{padding}")
104+
105+
# Генерируем markdown таблицу с результатами
83106
md = leaderboard.generate_markdown()
84-
85-
# Выводим упрощенную версию лидерборда в консоль
86-
lines = md.split('\n')
87-
for line in lines:
88-
if line.startswith('|'):
89-
# Убираем ссылки [Details](path) из вывода в консоль
90-
cleaned_line = line.split('|')
91-
if len(cleaned_line) > 5: # Проверяем, что это строка с данными
92-
cleaned_line = cleaned_line[:-1] # Убираем последнюю колонку с ссылками
93-
print('|'.join(cleaned_line))
107+
108+
# Форматируем и выводим таблицу в терминал
109+
lines = md.split("\n")
110+
table_lines = [line for line in lines if line.startswith("|")]
111+
112+
if len(table_lines) >= 2: # Есть заголовок и разделитель
113+
header_line = table_lines[0]
114+
separator_line = table_lines[1]
115+
data_lines = table_lines[2:] if len(table_lines) > 2 else []
116+
117+
# Анализируем ширину каждого столбца из заголовка
118+
columns = header_line.split("|")
119+
columns = [col.strip() for col in columns if col] # Убираем пустые элементы
120+
121+
# Находим максимальную ширину для каждого столбца
122+
column_widths = [len(col) for col in columns]
123+
124+
# Учитываем ширину данных в каждой строке
125+
for line in data_lines:
126+
cells = line.split("|")
127+
cells = [cell.strip() for cell in cells if cell]
128+
for i, cell in enumerate(cells):
129+
if i < len(column_widths):
130+
column_widths[i] = max(column_widths[i], len(cell))
131+
132+
# Форматируем и выводим заголовок
133+
formatted_header = (
134+
"| "
135+
+ " | ".join(f"{col:<{column_widths[i]}}" for i, col in enumerate(columns))
136+
+ " |"
137+
)
138+
print(f"\n{formatted_header}")
139+
140+
# Форматируем и выводим разделитель
141+
formatted_separator = (
142+
"|-" + "-|-".join("-" * width for width in column_widths) + "-|"
143+
)
144+
print(formatted_separator)
145+
146+
# Форматируем и выводим данные
147+
for line in data_lines:
148+
cells = line.split("|")
149+
cells = [cell.strip() for cell in cells if cell]
150+
formatted_line = (
151+
"| "
152+
+ " | ".join(
153+
f"{cell:<{column_widths[i]}}" for i, cell in enumerate(cells)
154+
)
155+
+ " |"
156+
)
157+
print(formatted_line)
158+
else:
159+
# Если не смогли разобрать таблицу, выводим строки как есть
160+
for line in table_lines:
161+
print(line)
94162

95163
print(f"\nДетальные результаты сохранены в: {leaderboard.output_dir}")
96164

165+
97166
if __name__ == "__main__":
98-
main()
167+
main()

0 commit comments

Comments
 (0)