Skip to content

Commit d8d06f5

Browse files
committed
split out rollhash impl
1 parent 7386c29 commit d8d06f5

File tree

3 files changed

+53
-45
lines changed

3 files changed

+53
-45
lines changed

src/encoding.rs

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use rustc_hash::FxHashMap as HashMap;
44
use rustc_hash::FxHashSet as HashSet;
55
use std::sync::Arc;
66
use thiserror::Error;
7-
use const_primes::is_prime;
7+
use crate::rollhash::{roll_hash, roll_hash_slice};
88

99
/// A struct that represents an encoding scheme based on byte-pair encoding (BPE).
1010
#[derive(Debug)]
@@ -497,43 +497,6 @@ impl Default for Encoding {
497497
}
498498
}
499499

500-
// Chose a prime number greater than 256 that minimizes hash collisions
501-
// for the prefixes of all mergeable ranks.
502-
// Modulus * prime must be less than 2^63-1 to avoid overflow.
503-
const PRIME: i64 = 997;
504-
const PRIME_INVERSE: i64 = 617853560682069;
505-
const MODULUS: i64 = 1e15 as i64 + 37;
506-
507-
const _: () = assert!(PRIME > 256, "PRIME must be greater than 256 for byte-wise rolling hash");
508-
const _: () = assert!(PRIME < MODULUS, "PRIME must be less than MODULUS");
509-
const _: () = assert!(
510-
MODULUS as i128 * PRIME as i128 <= i64::MAX as i128,
511-
"MODULUS * PRIME must not exceed i64::MAX to avoid overflow"
512-
);
513-
const _: () = assert!(
514-
(PRIME as i128 * PRIME_INVERSE as i128) % MODULUS as i128 == 1,
515-
"PRIME_INVERSE must be the modular multiplicative inverse of PRIME"
516-
);
517-
const _: () = assert!(is_prime(PRIME as u64), "PRIME must be a prime number");
518-
const _: () = assert!(is_prime(MODULUS as u64), "MODULUS must be a prime number");
519-
520-
521-
fn roll_hash(old: i64, new: u8) -> i64 {
522-
(((old * PRIME) % MODULUS) + (new as i64)) % MODULUS
523-
}
524-
525-
fn roll_hash_back(old: i64, new: u8) -> i64 {
526-
((((old + MODULUS) - (new as i64)) % MODULUS) * PRIME_INVERSE) % MODULUS
527-
}
528-
529-
530-
fn roll_hash_slice(slice: &[u8]) -> i64 {
531-
let mut hash = 0;
532-
for &byte in slice {
533-
hash = roll_hash(hash, byte);
534-
}
535-
hash
536-
}
537500
#[cfg(test)]
538501
mod tests {
539502
use crate::{EncodingFactory, EncodingFactoryError};
@@ -542,13 +505,6 @@ mod tests {
542505
use test_case::test_case;
543506
use memory_stats::memory_stats;
544507

545-
#[test]
546-
fn test_roll_hash() {
547-
let result = roll_hash_back(roll_hash(roll_hash(0, 10), 17), 17);
548-
let r2 = roll_hash(0, 10);
549-
assert_eq!(result, r2);
550-
}
551-
552508
#[test_case(EncodingFactory::llama3 ; "llama3")]
553509
#[test_case(EncodingFactory::codestral ; "codestral")]
554510
#[test_case(EncodingFactory::cl100k_im ; "cl100k_im")]

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ mod corebpe;
22
mod encoding;
33
mod load;
44
mod openai_public;
5+
mod rollhash;
56

67
#[cfg(test)]
78
mod tests;

src/rollhash.rs

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
use const_primes::is_prime;
2+
3+
// Chose a prime number greater than 256 that minimizes hash collisions
4+
// for the prefixes of all mergeable ranks.
5+
// Modulus * prime must be less than 2^63-1 to avoid overflow.
6+
const PRIME: i64 = 997;
7+
const PRIME_INVERSE: i64 = 617853560682069;
8+
const MODULUS: i64 = 1e15 as i64 + 37;
9+
10+
const _: () = assert!(PRIME > 256, "PRIME must be greater than 256 for byte-wise rolling hash");
11+
const _: () = assert!(PRIME < MODULUS, "PRIME must be less than MODULUS");
12+
const _: () = assert!(
13+
MODULUS as i128 * PRIME as i128 <= i64::MAX as i128,
14+
"MODULUS * PRIME must not exceed i64::MAX to avoid overflow"
15+
);
16+
const _: () = assert!(
17+
(PRIME as i128 * PRIME_INVERSE as i128) % MODULUS as i128 == 1,
18+
"PRIME_INVERSE must be the modular multiplicative inverse of PRIME"
19+
);
20+
const _: () = assert!(is_prime(PRIME as u64), "PRIME must be a prime number");
21+
const _: () = assert!(is_prime(MODULUS as u64), "MODULUS must be a prime number");
22+
23+
24+
pub fn roll_hash(old: i64, new: u8) -> i64 {
25+
(((old * PRIME) % MODULUS) + (new as i64)) % MODULUS
26+
}
27+
28+
#[allow(dead_code)]
29+
fn roll_hash_back(old: i64, new: u8) -> i64 {
30+
((((old + MODULUS) - (new as i64)) % MODULUS) * PRIME_INVERSE) % MODULUS
31+
}
32+
33+
pub fn roll_hash_slice(slice: &[u8]) -> i64 {
34+
let mut hash = 0;
35+
for &byte in slice {
36+
hash = roll_hash(hash, byte);
37+
}
38+
hash
39+
}
40+
41+
#[cfg(test)]
42+
mod tests {
43+
use super::*;
44+
45+
#[test]
46+
fn test_roll_hash() {
47+
let result = roll_hash_back(roll_hash(roll_hash(0, 10), 17), 17);
48+
let r2 = roll_hash(0, 10);
49+
assert_eq!(result, r2);
50+
}
51+
}

0 commit comments

Comments
 (0)