Skip to content

Commit 648f153

Browse files
feat: rework tuner commands to be a bit cleaner
Also added options to select where the params start from bench: 1583604
1 parent 21b0554 commit 648f153

File tree

1 file changed

+82
-45
lines changed

1 file changed

+82
-45
lines changed

src/bin/hce-tuner/main.rs

Lines changed: 82 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,15 @@
1-
use std::process::exit;
2-
31
use chess::{
42
definitions::NumberOf,
53
pieces::{ALL_PIECES, PIECE_NAMES},
64
};
7-
use clap::Parser;
5+
use clap::{Parser, Subcommand, ValueEnum};
86
use indicatif::ParallelProgressIterator;
97
use parameters::Parameters;
108
use rayon::iter::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator};
119
use textplots::{Chart, Plot, Shape};
1210
use tuner::Tuner;
1311
use tuner_score::TuningScore;
12+
use tuning_position::TuningPosition;
1413
mod epd_parser;
1514
mod math;
1615
mod offsets;
@@ -22,24 +21,42 @@ mod tuning_position;
2221
#[derive(Parser, Debug)]
2322
#[command(version, about="Texel tuner for HCE in byte-knight", long_about=None)]
2423
struct Options {
25-
#[clap(short, long, help = "Filterd, marked EPD input data.")]
26-
input_data: String,
27-
#[clap(short, long, help = "Number of epochs to run.")]
28-
epochs: Option<usize>,
29-
#[clap(
30-
long,
31-
action,
32-
default_value_t = false,
33-
help = "Plot k versus error for the given dataset"
34-
)]
35-
plot_k: bool,
36-
#[clap(
37-
long,
38-
action,
39-
default_value_t = false,
40-
help = "Compute error of current parameters"
41-
)]
42-
compute_error: bool,
24+
#[command(subcommand)]
25+
command: Command,
26+
}
27+
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, ValueEnum, Debug)]
28+
enum ParameterStartType {
29+
Zero,
30+
EngineValues,
31+
PieceValues,
32+
}
33+
34+
const INPUT_DATA_HELP: &str = "Filtered, marked EPD or 'book' input data.";
35+
#[derive(Subcommand, Debug)]
36+
enum Command {
37+
Tune {
38+
#[clap(short, long, help = INPUT_DATA_HELP)]
39+
input_data: String,
40+
#[clap(short, long, help = "Number of epochs to run.")]
41+
epochs: Option<usize>,
42+
#[arg(value_enum, short, long, help = "How to start the parameters", default_value_t = ParameterStartType::Zero)]
43+
param_start_type: ParameterStartType,
44+
},
45+
PlotK {
46+
#[clap(short, long, help = INPUT_DATA_HELP)]
47+
input_data: String,
48+
},
49+
ComputeError {
50+
#[clap(short, long, help = INPUT_DATA_HELP)]
51+
input_data: String,
52+
#[clap(
53+
short,
54+
long,
55+
help = "k value to compute error for (0.009)",
56+
default_value_t = 0.009
57+
)]
58+
k: f64,
59+
},
4360
}
4461

4562
fn print_table(indent: usize, table: &[TuningScore]) {
@@ -76,8 +93,8 @@ fn print_params(params: &Parameters) {
7693

7794
fn plot_k(tuner: &Tuner) {
7895
let mut points = Vec::new();
79-
let data_point_count = 10_000;
80-
let k_min = -0.1;
96+
let data_point_count = 1_000;
97+
let k_min = 0.;
8198
let k_max = 0.1;
8299
(0..data_point_count)
83100
.into_par_iter()
@@ -91,32 +108,52 @@ fn plot_k(tuner: &Tuner) {
91108

92109
Chart::new(180, 60, k_min as f32, k_max as f32)
93110
.lineplot(&Shape::Points(points.as_slice()))
94-
.display();
111+
.nice();
95112
}
96113

97-
fn main() {
98-
let options = Options::parse();
99-
println!("Reading data from: {}", options.input_data);
100-
let positions = epd_parser::parse_epd_file(options.input_data.as_str());
114+
fn parse_data(input_data: &str) -> Vec<TuningPosition> {
115+
println!("Reading data from: {}", input_data);
116+
let positions = epd_parser::parse_epd_file(input_data);
101117
// let positions = get_positions();
102118
println!("Read {} positions", positions.len());
119+
positions
120+
}
103121

104-
let epochs = options.epochs.unwrap_or(10_000);
105-
let parameters = Parameters::create_from_engine_values();
106-
let mut tuner = tuner::Tuner::new(parameters, &positions, epochs);
107-
108-
if options.plot_k {
109-
plot_k(&tuner);
110-
exit(0);
111-
}
112-
113-
if options.compute_error {
114-
let k = tuner.compute_k();
115-
let error = tuner.mean_square_error(k);
116-
println!("k: {}, error: {}", k, error);
117-
exit(0);
122+
fn main() {
123+
let options = Options::parse();
124+
match options.command {
125+
Command::Tune {
126+
input_data,
127+
epochs,
128+
param_start_type,
129+
} => {
130+
let positions = parse_data(&input_data);
131+
let parameters = match param_start_type {
132+
ParameterStartType::Zero => Parameters::default(),
133+
ParameterStartType::EngineValues => Parameters::create_from_engine_values(),
134+
ParameterStartType::PieceValues => Parameters::create_from_piece_values(),
135+
};
136+
let epchs = epochs.unwrap_or(10_000);
137+
println!(
138+
"Tuning parameters from {:?} for {} epochs",
139+
param_start_type, epchs
140+
);
141+
let mut tuner = tuner::Tuner::new(parameters, &positions, epchs);
142+
let tuned_results = tuner.tune();
143+
print_params(&tuned_results);
144+
}
145+
Command::PlotK { input_data } => {
146+
let positions = parse_data(&input_data);
147+
let parameters = Parameters::create_from_engine_values();
148+
let tuner = tuner::Tuner::new(parameters, &positions, 10_000);
149+
plot_k(&tuner);
150+
}
151+
Command::ComputeError { input_data, k } => {
152+
let positions = parse_data(&input_data);
153+
let parameters = Parameters::create_from_engine_values();
154+
let tuner = tuner::Tuner::new(parameters, &positions, 10_000);
155+
let error = tuner.mean_square_error(k);
156+
println!("Error for k {:.8}: {:.8}", k, error);
157+
}
118158
}
119-
120-
let tuned_result = tuner.tune();
121-
print_params(tuned_result);
122159
}

0 commit comments

Comments
 (0)