Skip to content

Commit 9378679

Browse files
committed
fix panic in reader for large score diffs
1 parent 4007947 commit 9378679

File tree

3 files changed

+103
-2
lines changed

3 files changed

+103
-2
lines changed

src/reader/compressed_reader.rs

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,7 @@ impl<T: Read + Seek> CompressedTrainingDataEntryReader<T> {
204204

205205
#[cfg(test)]
206206
mod tests {
207-
use std::fs::OpenOptions;
207+
use std::{fs::OpenOptions, io::Cursor};
208208

209209
use crate::chess::{
210210
coords::Square,
@@ -280,4 +280,52 @@ mod tests {
280280

281281
assert_eq!(entries, expected);
282282
}
283+
284+
#[test]
285+
fn test_reader_big_score_diff() {
286+
let cursor: Cursor<Vec<u8>> = Cursor::new(Vec::from([
287+
66, 73, 78, 80, 37, 0, 0, 0, 130, 130, 144, 210, 8, 192, 70, 82, 72, 58, 64, 0, 81, 16,
288+
18, 113, 155, 5, 0, 0, 0, 0, 0, 0, 10, 104, 249, 253, 0, 68, 0, 0, 0, 1, 29, 83, 79,
289+
]));
290+
291+
let mut reader = CompressedTrainingDataEntryReader::new(cursor).unwrap();
292+
293+
let mut entries: Vec<TrainingDataEntry> = Vec::new();
294+
while reader.has_next() {
295+
let entry = reader.next();
296+
297+
entries.push(entry);
298+
}
299+
300+
let expected = vec![
301+
TrainingDataEntry {
302+
pos: Position::from_fen("1q5b/1r5k/4p2p/1b2P1pN/3p4/6PP/1nP3B1/1Q2B1K1 w - - 0 35")
303+
.unwrap(),
304+
mv: Move::new(
305+
Square::new(10),
306+
Square::new(26),
307+
MoveType::Normal,
308+
Piece::none(),
309+
),
310+
score: -31999,
311+
ply: 68,
312+
result: 0,
313+
},
314+
TrainingDataEntry {
315+
pos: Position::from_fen("1q5b/1r5k/4p2p/1b2P1pN/2Pp4/6PP/1n4B1/1Q2B1K1 b - - 0 35")
316+
.unwrap(),
317+
mv: Move::new(
318+
Square::new(27),
319+
Square::new(19),
320+
MoveType::Normal,
321+
Piece::none(),
322+
),
323+
score: -1500,
324+
ply: 69,
325+
result: 0,
326+
},
327+
];
328+
329+
assert_eq!(entries, expected);
330+
}
283331
}

src/writer/compressed_writer.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,4 +320,57 @@ mod tests {
320320
let expected_bytes = fs::read("test/ep1.binpack").unwrap();
321321
assert_eq!(read_bytes, expected_bytes);
322322
}
323+
324+
#[test]
325+
fn test_compressed_writer_big_score_diff() {
326+
let entries = vec![
327+
TrainingDataEntry {
328+
pos: Position::from_fen("1q5b/1r5k/4p2p/1b2P1pN/3p4/6PP/1nP3B1/1Q2B1K1 w - - 0 35")
329+
.unwrap(),
330+
mv: Move::new(
331+
Square::new(10),
332+
Square::new(26),
333+
MoveType::Normal,
334+
Piece::none(),
335+
),
336+
score: -31999,
337+
ply: 68,
338+
result: 0,
339+
},
340+
TrainingDataEntry {
341+
pos: Position::from_fen("1q5b/1r5k/4p2p/1b2P1pN/2Pp4/6PP/1n4B1/1Q2B1K1 b - - 0 35")
342+
.unwrap(),
343+
mv: Move::new(
344+
Square::new(27),
345+
Square::new(19),
346+
MoveType::Normal,
347+
Piece::none(),
348+
),
349+
score: -1500,
350+
ply: 69,
351+
result: 0,
352+
},
353+
];
354+
355+
let cursor = Cursor::new(Vec::new());
356+
let mut writer = CompressedTrainingDataEntryWriter::new(cursor).unwrap();
357+
358+
for entry in entries.iter() {
359+
writer.write_entry(entry).unwrap();
360+
}
361+
362+
writer.flush().unwrap();
363+
364+
let mut cursor = writer.into_inner().unwrap();
365+
cursor.seek(io::SeekFrom::Start(0)).unwrap();
366+
367+
let mut read_bytes = vec![];
368+
cursor.read_to_end(&mut read_bytes).unwrap();
369+
370+
let expected_bytes = [
371+
66, 73, 78, 80, 37, 0, 0, 0, 130, 130, 144, 210, 8, 192, 70, 82, 72, 58, 64, 0, 81, 16,
372+
18, 113, 155, 5, 0, 0, 0, 0, 0, 0, 10, 104, 249, 253, 0, 68, 0, 0, 0, 1, 29, 83, 79,
373+
];
374+
assert_eq!(read_bytes, expected_bytes);
375+
}
323376
}

src/writer/move_score_list.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ impl PackedMoveScoreList {
6060
self.writer
6161
.add_bits_le8(move_id as u8, used_bits_safe(num_moves));
6262

63-
let score_delta = signed_to_unsigned(score - self.last_score);
63+
let score_delta: u16 = signed_to_unsigned(score.wrapping_sub(self.last_score));
6464

6565
self.writer
6666
.add_bits_vle16(score_delta, SCORE_VLE_BLOCK_SIZE);

0 commit comments

Comments
 (0)