1- use std:: process:: exit;
2-
31use chess:: {
42 definitions:: NumberOf ,
53 pieces:: { ALL_PIECES , PIECE_NAMES } ,
64} ;
7- use clap:: Parser ;
5+ use clap:: { Parser , Subcommand , ValueEnum } ;
86use indicatif:: ParallelProgressIterator ;
97use parameters:: Parameters ;
108use rayon:: iter:: { IndexedParallelIterator , IntoParallelIterator , ParallelIterator } ;
119use textplots:: { Chart , Plot , Shape } ;
1210use tuner:: Tuner ;
1311use tuner_score:: TuningScore ;
12+ use tuning_position:: TuningPosition ;
1413mod epd_parser;
1514mod math;
1615mod offsets;
@@ -22,24 +21,42 @@ mod tuning_position;
2221#[ derive( Parser , Debug ) ]
2322#[ command( version, about="Texel tuner for HCE in byte-knight" , long_about=None ) ]
2423struct Options {
25- #[ clap( short, long, help = "Filterd, marked EPD input data." ) ]
26- input_data : String ,
27- #[ clap( short, long, help = "Number of epochs to run." ) ]
28- epochs : Option < usize > ,
29- #[ clap(
30- long,
31- action,
32- default_value_t = false ,
33- help = "Plot k versus error for the given dataset"
34- ) ]
35- plot_k : bool ,
36- #[ clap(
37- long,
38- action,
39- default_value_t = false ,
40- help = "Compute error of current parameters"
41- ) ]
42- compute_error : bool ,
24+ #[ command( subcommand) ]
25+ command : Command ,
26+ }
27+ #[ derive( Copy , Clone , PartialEq , Eq , PartialOrd , Ord , ValueEnum , Debug ) ]
28+ enum ParameterStartType {
29+ Zero ,
30+ EngineValues ,
31+ PieceValues ,
32+ }
33+
34+ const INPUT_DATA_HELP : & str = "Filtered, marked EPD or 'book' input data." ;
35+ #[ derive( Subcommand , Debug ) ]
36+ enum Command {
37+ Tune {
38+ #[ clap( short, long, help = INPUT_DATA_HELP ) ]
39+ input_data : String ,
40+ #[ clap( short, long, help = "Number of epochs to run." ) ]
41+ epochs : Option < usize > ,
42+ #[ arg( value_enum, short, long, help = "How to start the parameters" , default_value_t = ParameterStartType :: Zero ) ]
43+ param_start_type : ParameterStartType ,
44+ } ,
45+ PlotK {
46+ #[ clap( short, long, help = INPUT_DATA_HELP ) ]
47+ input_data : String ,
48+ } ,
49+ ComputeError {
50+ #[ clap( short, long, help = INPUT_DATA_HELP ) ]
51+ input_data : String ,
52+ #[ clap(
53+ short,
54+ long,
55+ help = "k value to compute error for (0.009)" ,
56+ default_value_t = 0.009
57+ ) ]
58+ k : f64 ,
59+ } ,
4360}
4461
4562fn print_table ( indent : usize , table : & [ TuningScore ] ) {
@@ -76,8 +93,8 @@ fn print_params(params: &Parameters) {
7693
7794fn plot_k ( tuner : & Tuner ) {
7895 let mut points = Vec :: new ( ) ;
79- let data_point_count = 10_000 ;
80- let k_min = - 0.1 ;
96+ let data_point_count = 1_000 ;
97+ let k_min = 0. ;
8198 let k_max = 0.1 ;
8299 ( 0 ..data_point_count)
83100 . into_par_iter ( )
@@ -91,32 +108,52 @@ fn plot_k(tuner: &Tuner) {
91108
92109 Chart :: new ( 180 , 60 , k_min as f32 , k_max as f32 )
93110 . lineplot ( & Shape :: Points ( points. as_slice ( ) ) )
94- . display ( ) ;
111+ . nice ( ) ;
95112}
96113
97- fn main ( ) {
98- let options = Options :: parse ( ) ;
99- println ! ( "Reading data from: {}" , options. input_data) ;
100- let positions = epd_parser:: parse_epd_file ( options. input_data . as_str ( ) ) ;
114+ fn parse_data ( input_data : & str ) -> Vec < TuningPosition > {
115+ println ! ( "Reading data from: {}" , input_data) ;
116+ let positions = epd_parser:: parse_epd_file ( input_data) ;
101117 // let positions = get_positions();
102118 println ! ( "Read {} positions" , positions. len( ) ) ;
119+ positions
120+ }
103121
104- let epochs = options. epochs . unwrap_or ( 10_000 ) ;
105- let parameters = Parameters :: create_from_engine_values ( ) ;
106- let mut tuner = tuner:: Tuner :: new ( parameters, & positions, epochs) ;
107-
108- if options. plot_k {
109- plot_k ( & tuner) ;
110- exit ( 0 ) ;
111- }
112-
113- if options. compute_error {
114- let k = tuner. compute_k ( ) ;
115- let error = tuner. mean_square_error ( k) ;
116- println ! ( "k: {}, error: {}" , k, error) ;
117- exit ( 0 ) ;
122+ fn main ( ) {
123+ let options = Options :: parse ( ) ;
124+ match options. command {
125+ Command :: Tune {
126+ input_data,
127+ epochs,
128+ param_start_type,
129+ } => {
130+ let positions = parse_data ( & input_data) ;
131+ let parameters = match param_start_type {
132+ ParameterStartType :: Zero => Parameters :: default ( ) ,
133+ ParameterStartType :: EngineValues => Parameters :: create_from_engine_values ( ) ,
134+ ParameterStartType :: PieceValues => Parameters :: create_from_piece_values ( ) ,
135+ } ;
136+ let epchs = epochs. unwrap_or ( 10_000 ) ;
137+ println ! (
138+ "Tuning parameters from {:?} for {} epochs" ,
139+ param_start_type, epchs
140+ ) ;
141+ let mut tuner = tuner:: Tuner :: new ( parameters, & positions, epchs) ;
142+ let tuned_results = tuner. tune ( ) ;
143+ print_params ( & tuned_results) ;
144+ }
145+ Command :: PlotK { input_data } => {
146+ let positions = parse_data ( & input_data) ;
147+ let parameters = Parameters :: create_from_engine_values ( ) ;
148+ let tuner = tuner:: Tuner :: new ( parameters, & positions, 10_000 ) ;
149+ plot_k ( & tuner) ;
150+ }
151+ Command :: ComputeError { input_data, k } => {
152+ let positions = parse_data ( & input_data) ;
153+ let parameters = Parameters :: create_from_engine_values ( ) ;
154+ let tuner = tuner:: Tuner :: new ( parameters, & positions, 10_000 ) ;
155+ let error = tuner. mean_square_error ( k) ;
156+ println ! ( "Error for k {:.8}: {:.8}" , k, error) ;
157+ }
118158 }
119-
120- let tuned_result = tuner. tune ( ) ;
121- print_params ( tuned_result) ;
122159}
0 commit comments