11import yaml
2+ from typing import Dict , Any , Optional
23from src .equality_checker import MathEqualityChecker
34from src .leaderboard import Leaderboard
45import argparse
56from pathlib import Path
67import sys
78import os
9+ import shutil
810
9- def main ():
11+
12+ def main () -> None :
13+ """
14+ Основная функция приложения для оценки языковых моделей на математических и физических задачах.
15+
16+ Обрабатывает аргументы командной строки, запускает оценку моделей на выбранных датасетах
17+ и выводит результаты в виде отформатированной таблицы.
18+ """
1019 # Установка кодировки вывода в UTF-8
11- if sys .platform == ' win32' :
12- os .system (' chcp 65001' ) # Установка кодировки UTF-8 для Windows консоли
13-
20+ if sys .platform == " win32" :
21+ os .system (" chcp 65001" ) # Установка кодировки UTF-8 для Windows консоли
22+
1423 parser = argparse .ArgumentParser (
15- description = ' Оценка языковых моделей на математических задачах' ,
24+ description = " Оценка языковых моделей на математических и физических задачах" ,
1625 formatter_class = argparse .RawDescriptionHelpFormatter ,
1726 epilog = """
1827Примеры использования:
1928 python runner.py # Запустить оценку на всех датасетах (по умолчанию)
2029 python runner.py --dataset russianmath # Запустить только на датасете RussianMath
21- python runner.py --dataset mathdemon # Запустить только на датасете MathDemon_Demidovich
30+ python runner.py --dataset physics # Запустить только на датасете RussianPhysics
2231 python runner.py --no-cache # Игнорировать кэш и переоценить все модели
2332 python runner.py --max-workers 8 # Использовать 8 параллельных потоков
24- """
33+ """ ,
34+ )
35+ parser .add_argument (
36+ "--max-workers" ,
37+ type = int ,
38+ default = 4 ,
39+ help = "Максимальное количество параллельных потоков (по умолчанию: 4)" ,
40+ )
41+ parser .add_argument (
42+ "--config" ,
43+ type = str ,
44+ default = "configs/run.yaml" ,
45+ help = "Путь к файлу конфигурации (по умолчанию: configs/run.yaml)" ,
46+ )
47+ parser .add_argument (
48+ "--no-cache" ,
49+ action = "store_true" ,
50+ help = "Игнорировать кэш и переоценить все модели" ,
51+ )
52+ parser .add_argument (
53+ "--dataset" ,
54+ choices = ["all" , "russianmath" , "physics" ],
55+ default = "all" ,
56+ help = "Выбор датасета для оценки: all (все), russianmath, physics (по умолчанию: all)" ,
2557 )
26- parser .add_argument ('--max-workers' , type = int , default = 4 ,
27- help = 'Максимальное количество параллельных потоков (по умолчанию: 4)' )
28- parser .add_argument ('--config' , type = str , default = 'configs/run.yaml' ,
29- help = 'Путь к файлу конфигурации (по умолчанию: configs/run.yaml)' )
30- parser .add_argument ('--no-cache' , action = 'store_true' ,
31- help = 'Игнорировать кэш и переоценить все модели' )
32- parser .add_argument ('--dataset' , choices = ['all' , 'russianmath' , 'mathdemon' ], default = 'all' ,
33- help = 'Выбор датасета для оценки: all (все), russianmath, mathdemon (по умолчанию: all)' )
3458 args = parser .parse_args ()
3559
3660 # Загружаем конфиг
37- with open (args .config , 'r' , encoding = ' utf-8' ) as f :
38- config = yaml .safe_load (f )
39-
61+ with open (args .config , "r" , encoding = " utf-8" ) as f :
62+ config : Dict [ str , Any ] = yaml .safe_load (f )
63+
4064 # Если указан --no-cache, отключаем использование кэша
4165 if args .no_cache :
42- # Очищаем директорию с кэшем
4366 cache_dir = Path ("results/cache" )
4467 if cache_dir .exists ():
4568 for cache_file in cache_dir .glob ("*.json" ):
4669 cache_file .unlink ()
4770
48- # Создаем equality checker
71+ # Создаем equality checker для проверки равенства математических выражений
4972 equality_checker = MathEqualityChecker ()
5073
51- # Создаем и обновляем лидерборд
74+ # Создаем и инициализируем лидерборд
5275 leaderboard = Leaderboard (args .config , max_workers = args .max_workers )
53-
54- # Определяем system prompts для моделей
55- system_prompts = {
56- model : config .get (model , {}).get (' system_prompt' )
57- for model in config [' model_list' ]
76+
77+ # Определяем системные промпты для каждой модели из конфига
78+ system_prompts : Dict [ str , Optional [ str ]] = {
79+ model : config .get (model , {}).get (" system_prompt" )
80+ for model in config [" model_list" ]
5881 }
59-
82+
6083 # Запуск оценки в зависимости от выбранного датасета
61- if args .dataset == ' all' :
62- print ("\n Запуск оценки на всех датасетах ( RussianMath и MathDemon_Demidovich) " )
84+ if args .dataset == " all" or args . dataset == "russianmath" :
85+ print ("\n Запуск оценки на датасете RussianMath" )
6386 leaderboard .evaluate_all_models (system_prompts )
64- leaderboard .evaluate_math_demon_subsets ()
65- elif args .dataset == 'russianmath' :
66- print ("\n Запуск оценки только на датасете RussianMath" )
67- leaderboard .evaluate_all_models (system_prompts )
68- elif args .dataset == 'mathdemon' :
69- print ("\n Запуск оценки только на датасете MathDemon_Dемидович" )
70- leaderboard .evaluate_math_demon_subsets ()
71-
72- # Вычисляем комбинированные оценки для моделей
73- if args .dataset == 'all' :
74- print ("\n Вычисление комбинированной оценки по всем датасетам" )
87+
88+ if args .dataset == "all" or args .dataset == "physics" :
89+ print ("\n Запуск оценки на датасете RussianPhysics" )
90+ leaderboard .evaluate_physics_models (system_prompts )
91+
92+ # Вычисляем общий скор для моделей (полусумма по обоим датасетам)
93+ if args .dataset == "all" :
94+ print ("\n Вычисление общего скора по всем датасетам" )
7595 leaderboard .calculate_combined_scores ()
76-
77- # Генерируем и выводим лидерборд
78- TERMINAL_WIDTH = 100 # Примерная ширина терминала
96+
97+ # Получаем ширину терминала для форматирования вывода
98+ terminal_width = shutil .get_terminal_size ().columns
99+
100+ # Выводим красивый заголовок лидерборда
79101 header = " LEADERBOARD "
80- padding = "=" * ((TERMINAL_WIDTH - len (header )) // 2 )
81- print (f"\n { padding } { header } { padding } \n " )
82-
102+ padding = "=" * ((terminal_width - len (header )) // 2 )
103+ print (f"\n { padding } { header } { padding } " )
104+
105+ # Генерируем markdown таблицу с результатами
83106 md = leaderboard .generate_markdown ()
84-
85- # Выводим упрощенную версию лидерборда в консоль
86- lines = md .split ('\n ' )
87- for line in lines :
88- if line .startswith ('|' ):
89- # Убираем ссылки [Details](path) из вывода в консоль
90- cleaned_line = line .split ('|' )
91- if len (cleaned_line ) > 5 : # Проверяем, что это строка с данными
92- cleaned_line = cleaned_line [:- 1 ] # Убираем последнюю колонку с ссылками
93- print ('|' .join (cleaned_line ))
107+
108+ # Форматируем и выводим таблицу в терминал
109+ lines = md .split ("\n " )
110+ table_lines = [line for line in lines if line .startswith ("|" )]
111+
112+ if len (table_lines ) >= 2 : # Есть заголовок и разделитель
113+ header_line = table_lines [0 ]
114+ separator_line = table_lines [1 ]
115+ data_lines = table_lines [2 :] if len (table_lines ) > 2 else []
116+
117+ # Анализируем ширину каждого столбца из заголовка
118+ columns = header_line .split ("|" )
119+ columns = [col .strip () for col in columns if col ] # Убираем пустые элементы
120+
121+ # Находим максимальную ширину для каждого столбца
122+ column_widths = [len (col ) for col in columns ]
123+
124+ # Учитываем ширину данных в каждой строке
125+ for line in data_lines :
126+ cells = line .split ("|" )
127+ cells = [cell .strip () for cell in cells if cell ]
128+ for i , cell in enumerate (cells ):
129+ if i < len (column_widths ):
130+ column_widths [i ] = max (column_widths [i ], len (cell ))
131+
132+ # Форматируем и выводим заголовок
133+ formatted_header = (
134+ "| "
135+ + " | " .join (f"{ col :<{column_widths [i ]}} " for i , col in enumerate (columns ))
136+ + " |"
137+ )
138+ print (f"\n { formatted_header } " )
139+
140+ # Форматируем и выводим разделитель
141+ formatted_separator = (
142+ "|-" + "-|-" .join ("-" * width for width in column_widths ) + "-|"
143+ )
144+ print (formatted_separator )
145+
146+ # Форматируем и выводим данные
147+ for line in data_lines :
148+ cells = line .split ("|" )
149+ cells = [cell .strip () for cell in cells if cell ]
150+ formatted_line = (
151+ "| "
152+ + " | " .join (
153+ f"{ cell :<{column_widths [i ]}} " for i , cell in enumerate (cells )
154+ )
155+ + " |"
156+ )
157+ print (formatted_line )
158+ else :
159+ # Если не смогли разобрать таблицу, выводим строки как есть
160+ for line in table_lines :
161+ print (line )
94162
95163 print (f"\n Детальные результаты сохранены в: { leaderboard .output_dir } " )
96164
165+
97166if __name__ == "__main__" :
98- main ()
167+ main ()
0 commit comments