Skip to content

Commit 760608b

Browse files
committed
specialize parallel search for 1 thread
1 parent e713e48 commit 760608b

File tree

9 files changed

+227
-122
lines changed

9 files changed

+227
-122
lines changed

Cargo.lock

Lines changed: 11 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ ruzstd = { version = "0.7.0", default-features = false, features = ["std"] }
3838

3939
[dev-dependencies]
4040
criterion = { version = "0.5.1", default-features = false, features = ["rayon"] }
41+
criterion-macro = { version = "0.4.0", default-features = false }
4142
proptest = { version = "1.4.0", default-features = false, features = ["std"] }
4243
test-strategy = { version = "0.3.1", default-features = false }
4344

@@ -66,4 +67,3 @@ bench = false
6667

6768
[[bench]]
6869
name = "search"
69-
harness = false

benches/search.rs

Lines changed: 46 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,81 +1,60 @@
1-
use criterion::{criterion_group, criterion_main, Criterion, SamplingMode, Throughput};
2-
use lib::search::{Depth, Engine, Limits, Options, ThreadCount};
1+
#![feature(custom_test_frameworks)]
2+
#![test_runner(criterion::runner)]
3+
4+
use criterion::{Criterion, SamplingMode, Throughput};
5+
use criterion_macro::criterion;
6+
use lib::search::{Depth, Engine, Limits, Options};
37
use lib::{nnue::Evaluator, util::Integer};
4-
use std::thread::available_parallelism;
58
use std::time::{Duration, Instant};
9+
use std::{str::FromStr, thread::available_parallelism};
10+
11+
#[ctor::ctor]
12+
static POSITION: Evaluator =
13+
Evaluator::from_str("6br/1KNp1n1r/2p2p2/P1ppRP2/1kP3pP/3PBB2/PN1P4/8 w - - 0 1").unwrap();
614

7-
fn bencher(reps: u64, positions: &[Evaluator], options: Options, limits: Limits) -> Duration {
15+
fn bench(reps: u64, options: Options, limits: Limits) -> Duration {
816
let mut time = Duration::ZERO;
917

10-
for pos in positions {
11-
for _ in 0..reps {
12-
let mut e = Engine::with_options(options);
13-
let timer = Instant::now();
14-
e.search(pos, limits);
15-
time += timer.elapsed();
16-
}
18+
for _ in 0..reps {
19+
let mut e = Engine::with_options(options);
20+
let timer = Instant::now();
21+
e.search(&POSITION, limits);
22+
time += timer.elapsed();
1723
}
1824

1925
time
2026
}
2127

22-
fn bench(c: &mut Criterion) {
23-
let positions: Vec<Evaluator> = FENS.iter().map(|p| p.parse().unwrap()).collect();
24-
let options = match available_parallelism() {
25-
Err(_) => Options::default(),
26-
Ok(cores) => match cores.get() / 2 {
27-
0 => Options::default(),
28-
threads => Options {
29-
threads: ThreadCount::new(threads),
30-
..Options::default()
31-
},
32-
},
28+
#[criterion]
29+
fn crit(c: &mut Criterion) {
30+
let thread_limit = match available_parallelism() {
31+
Ok(cores) => cores.get().div_ceil(2),
32+
Err(_) => 1,
3333
};
3434

35-
let depth = Depth::new(7);
36-
c.benchmark_group("search")
37-
.sampling_mode(SamplingMode::Flat)
38-
.bench_function("ttd", |b| {
39-
b.iter_custom(|i| bencher(i, &positions, options, depth.into()))
40-
});
35+
let options = Vec::from_iter((0..=thread_limit.ilog2()).map(|threads| Options {
36+
threads: 2usize.pow(threads).saturate(),
37+
..Options::default()
38+
}));
39+
40+
for &o in &options {
41+
let depth = Depth::new(14);
42+
c.benchmark_group("ttd")
43+
.sampling_mode(SamplingMode::Flat)
44+
.sample_size(10 * o.threads.get())
45+
.bench_function(o.threads.to_string(), |b| {
46+
b.iter_custom(|i| bench(i, o, depth.into()))
47+
});
48+
}
4149

42-
let nodes = 10000;
43-
c.benchmark_group("search")
44-
.sampling_mode(SamplingMode::Flat)
45-
.throughput(Throughput::Elements(nodes * positions.len() as u64))
46-
.bench_function("nps", |b| {
47-
b.iter_custom(|i| bencher(i, &positions, options, nodes.into()))
48-
});
50+
for &o in &options {
51+
let nodes = 500_000;
52+
c.benchmark_group("nps")
53+
.sampling_mode(SamplingMode::Flat)
54+
.sample_size(10 * o.threads.get())
55+
.throughput(Throughput::Elements(nodes))
56+
.bench_function(o.threads.to_string(), |b| {
57+
b.iter_custom(|i| bench(i, o, nodes.into()))
58+
});
59+
}
4960
}
50-
51-
criterion_group!(benches, bench);
52-
criterion_main!(benches);
53-
54-
// https://www.chessprogramming.org/CCR_One_Hour_Test
55-
const FENS: &[&str] = &[
56-
"rn1qkb1r/pp2pppp/5n2/3p1b2/3P4/2N1P3/PP3PPP/R1BQKBNR w KQkq - 0 1",
57-
"rn1qkb1r/pp2pppp/5n2/3p1b2/3P4/1QN1P3/PP3PPP/R1B1KBNR b KQkq - 1 1",
58-
"r1bqk2r/ppp2ppp/2n5/4P3/2Bp2n1/5N1P/PP1N1PP1/R2Q1RK1 b kq - 1 10",
59-
"r1bqrnk1/pp2bp1p/2p2np1/3p2B1/3P4/2NBPN2/PPQ2PPP/1R3RK1 w - - 1 12",
60-
"rnbqkb1r/ppp1pppp/5n2/8/3PP3/2N5/PP3PPP/R1BQKBNR b KQkq - 3 5",
61-
"rnbq1rk1/pppp1ppp/4pn2/8/1bPP4/P1N5/1PQ1PPPP/R1B1KBNR b KQ - 1 5",
62-
"r4rk1/3nppbp/bq1p1np1/2pP4/8/2N2NPP/PP2PPB1/R1BQR1K1 b - - 1 12",
63-
"rn1qkb1r/pb1p1ppp/1p2pn2/2p5/2PP4/5NP1/PP2PPBP/RNBQK2R w KQkq c6 1 6",
64-
"r1bq1rk1/1pp2pbp/p1np1np1/3Pp3/2P1P3/2N1BP2/PP4PP/R1NQKB1R b KQ - 1 9",
65-
"rnbqr1k1/1p3pbp/p2p1np1/2pP4/4P3/2N5/PP1NBPPP/R1BQ1RK1 w - - 1 11",
66-
"rnbqkb1r/pppp1ppp/5n2/4p3/4PP2/2N5/PPPP2PP/R1BQKBNR b KQkq f3 1 3",
67-
"r1bqk1nr/pppnbppp/3p4/8/2BNP3/8/PPP2PPP/RNBQK2R w KQkq - 2 6",
68-
"rnbq1b1r/ppp2kpp/3p1n2/8/3PP3/8/PPP2PPP/RNBQKB1R b KQ d3 1 5",
69-
"rnbqkb1r/pppp1ppp/3n4/8/2BQ4/5N2/PPP2PPP/RNB2RK1 b kq - 1 6",
70-
"r2q1rk1/2p1bppp/p2p1n2/1p2P3/4P1b1/1nP1BN2/PP3PPP/RN1QR1K1 w - - 1 12",
71-
"r1bqkb1r/2pp1ppp/p1n5/1p2p3/3Pn3/1B3N2/PPP2PPP/RNBQ1RK1 b kq - 2 7",
72-
"r2qkbnr/2p2pp1/p1pp4/4p2p/4P1b1/5N1P/PPPP1PP1/RNBQ1RK1 w kq - 1 8",
73-
"r1bqkb1r/pp3ppp/2np1n2/4p1B1/3NP3/2N5/PPP2PPP/R2QKB1R w KQkq e6 1 7",
74-
"rn1qk2r/1b2bppp/p2ppn2/1p6/3NP3/1BN5/PPP2PPP/R1BQR1K1 w kq - 5 10",
75-
"r1b1kb1r/1pqpnppp/p1n1p3/8/3NP3/2N1B3/PPP1BPPP/R2QK2R w KQkq - 3 8",
76-
"r1bqnr2/pp1ppkbp/4N1p1/n3P3/8/2N1B3/PPP2PPP/R2QK2R b KQ - 2 11",
77-
"r3kb1r/pp1n1ppp/1q2p3/n2p4/3P1Bb1/2PB1N2/PPQ2PPP/RN2K2R w KQkq - 3 11",
78-
"r1bq1rk1/pppnnppp/4p3/3pP3/1b1P4/2NB3N/PPP2PPP/R1BQK2R w KQ - 3 7",
79-
"r2qkbnr/ppp1pp1p/3p2p1/3Pn3/4P1b1/2N2N2/PPP2PPP/R1BQKB1R w KQkq - 2 6",
80-
"rn2kb1r/pp2pppp/1qP2n2/8/6b1/1Q6/PP1PPPBP/RNB1K1NR b KQkq - 1 6",
81-
];

lib/search.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
mod depth;
2+
mod driver;
23
mod engine;
34
mod killers;
45
mod limits;
@@ -9,6 +10,7 @@ mod score;
910
mod transposition;
1011

1112
pub use depth::*;
13+
pub use driver::*;
1214
pub use engine::*;
1315
pub use killers::*;
1416
pub use limits::*;

lib/search/driver.rs

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
use crate::search::{Pv, ThreadCount};
2+
use crate::util::{Binary, Bits, Integer};
3+
use derive_more::{Deref, Display, Error, From};
4+
use rayon::{prelude::*, ThreadPool, ThreadPoolBuilder};
5+
use std::sync::atomic::{AtomicU64, Ordering};
6+
7+
/// Indicates the search was interrupted upon reaching the configured [`crate::search::Limits`].
8+
#[derive(Debug, Display, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Error)]
9+
#[display("the search was interrupted")]
10+
pub struct Interrupted;
11+
12+
/// Whether the search should be [`Interrupted`] or exited early.
13+
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, From)]
14+
pub enum ControlFlow {
15+
Interrupt(Interrupted),
16+
Break,
17+
}
18+
19+
/// A parallel search driver.
20+
#[derive(Debug)]
21+
pub enum Driver {
22+
Parallel(ThreadPool),
23+
Sequential,
24+
}
25+
26+
impl Driver {
27+
/// Constructs a parallel search driver with the given [`ThreadCount`].
28+
pub fn new(threads: ThreadCount) -> Self {
29+
match threads.get() {
30+
1 => Self::Sequential,
31+
n => Self::Parallel(ThreadPoolBuilder::new().num_threads(n).build().unwrap()),
32+
}
33+
}
34+
35+
/// Drive the search, possibly across multiple threads in parallel.
36+
///
37+
/// The order in which elements are processed and on which thread is unspecified.
38+
#[inline(always)]
39+
pub fn drive<M, F>(&self, mut best: Pv, moves: &[M], f: F) -> Result<Pv, Interrupted>
40+
where
41+
M: Sync,
42+
F: Fn(&Pv, &M) -> Result<Pv, ControlFlow> + Sync,
43+
{
44+
match self {
45+
Self::Sequential => {
46+
for m in moves.iter().rev() {
47+
best = match f(&best, m) {
48+
Ok(pv) => pv.max(best),
49+
Err(ControlFlow::Break) => break,
50+
Err(ControlFlow::Interrupt(e)) => return Err(e),
51+
};
52+
}
53+
54+
Ok(best)
55+
}
56+
57+
Self::Parallel(e) => e.install(|| {
58+
use Ordering::Relaxed;
59+
let best = AtomicU64::new(IndexedPv(best, u32::MAX).encode().get());
60+
let result = moves.par_iter().enumerate().rev().try_for_each(|(idx, m)| {
61+
let pv = f(&IndexedPv::decode(Bits::new(best.load(Relaxed))), m)?;
62+
best.fetch_max(IndexedPv(pv, idx.saturate()).encode().get(), Relaxed);
63+
Ok(())
64+
});
65+
66+
if matches!(result, Ok(()) | Err(ControlFlow::Break)) {
67+
Ok(*IndexedPv::decode(Bits::new(best.into_inner())))
68+
} else {
69+
Err(Interrupted)
70+
}
71+
}),
72+
}
73+
}
74+
}
75+
76+
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Deref)]
77+
#[cfg_attr(test, derive(test_strategy::Arbitrary))]
78+
struct IndexedPv(#[deref] Pv, u32);
79+
80+
impl Binary for IndexedPv {
81+
type Bits = Bits<u64, 64>;
82+
83+
#[inline(always)]
84+
fn encode(&self) -> Self::Bits {
85+
let mut bits = Bits::default();
86+
bits.push(self.score().encode());
87+
bits.push(Bits::<u32, 32>::new(self.1));
88+
bits.push(self.best().encode());
89+
bits
90+
}
91+
92+
#[inline(always)]
93+
fn decode(mut bits: Self::Bits) -> Self {
94+
let best = Binary::decode(bits.pop());
95+
let idx = bits.pop::<u32, 32>().get();
96+
let score = Binary::decode(bits.pop());
97+
Self(Pv::new(score, best), idx)
98+
}
99+
}
100+
101+
#[cfg(test)]
102+
mod tests {
103+
use super::*;
104+
use crate::{chess::Move, nnue::Value};
105+
use std::cmp::max;
106+
use test_strategy::proptest;
107+
108+
#[proptest]
109+
fn decoding_encoded_indexed_pv_is_an_identity(pv: IndexedPv) {
110+
assert_eq!(IndexedPv::decode(pv.encode()), pv);
111+
}
112+
113+
#[proptest]
114+
fn indexed_pv_with_higher_score_is_larger(a: Pv, b: Pv, i: u32) {
115+
assert_eq!(a < b, IndexedPv(a, i) < IndexedPv(b, i));
116+
}
117+
118+
#[proptest]
119+
fn indexed_pv_with_same_score_but_higher_index_is_larger(pv: Pv, a: u32, b: u32) {
120+
assert_eq!(a < b, IndexedPv(pv, a) < IndexedPv(pv, b));
121+
}
122+
123+
#[proptest]
124+
fn driver_finds_max_indexed_pv(c: ThreadCount, pv: Pv, ms: Vec<(Move, Value)>) {
125+
assert_eq!(
126+
Driver::new(c).drive(pv, &ms, |_, &(m, v)| Ok(Pv::new(v.saturate(), Some(m)))),
127+
Ok(*ms
128+
.into_iter()
129+
.enumerate()
130+
.map(|(i, (m, v))| IndexedPv(Pv::new(v.saturate(), Some(m)), i as _))
131+
.fold(IndexedPv(pv, u32::MAX), max))
132+
)
133+
}
134+
}

0 commit comments

Comments
 (0)