diff --git a/src/lib.rs b/src/lib.rs index 80f4acc..b0ef4d8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,60 +15,129 @@ mod py; pub type Rank = u32; fn _byte_pair_merge(ranks: &HashMap, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> { - // This is a vector of (start, rank). - // The rank is of the pair starting at position start. - let mut parts = Vec::with_capacity(piece.len() + 1); - - // Note that we hash bytes when indexing into `ranks`, not token pairs. As long as we train BPE - // the way we currently do, this is equivalent. An easy way to break this would be to decouple - // merge priority from token index or to prevent specific token merges. - let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX); - for i in 0..piece.len() - 1 { - let rank = *ranks.get(&piece[i..i + 2]).unwrap_or(&Rank::MAX); - if rank < min_rank.0 { - min_rank = (rank, i); + use std::cmp::Ordering; + use std::collections::BinaryHeap; + + #[derive(Clone, Copy)] + struct Node { + prev: Option, + next: Option, + alive: bool, + } + + #[derive(Eq, Clone, Copy)] + struct Cand { + rank: Rank, + left: usize, + ver: u32, + } + + impl PartialEq for Cand { + fn eq(&self, other: &Self) -> bool { + self.rank == other.rank && self.left == other.left && self.ver == other.ver } - parts.push((i, rank)); } - parts.push((piece.len() - 1, Rank::MAX)); - parts.push((piece.len(), Rank::MAX)); - - let get_rank = { - #[inline(always)] - |parts: &Vec<(usize, Rank)>, i: usize| { - if (i + 3) < parts.len() { - // Similar to `piece[i..i + 2]` above. The +3 is because we haven't yet deleted - // parts[i + 1], see comment in the main loop. - *ranks - .get(&piece[parts[i].0..parts[i + 3].0]) - .unwrap_or(&Rank::MAX) - } else { - Rank::MAX + impl PartialOrd for Cand { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + impl Ord for Cand { + fn cmp(&self, other: &Self) -> Ordering { + other.rank.cmp(&self.rank) + .then_with(|| other.left.cmp(&self.left)) + .then_with(|| other.ver.cmp(&self.ver)) + } + } + + #[inline(always)] + fn compute_rank_at( + ranks: &HashMap, Rank>, + piece: &[u8], + nodes: &Vec, + i: usize, + ) -> Rank { + if let Some(j) = nodes[i].next { + if let Some(k) = nodes[j].next { + return *ranks.get(&piece[i..k]).unwrap_or(&Rank::MAX); } } - }; + Rank::MAX + } + + let n_bytes = piece.len(); + if n_bytes == 0 { + return vec![(0, Rank::MAX)]; + } + if n_bytes == 1 { + return vec![(0, Rank::MAX), (1, Rank::MAX)]; + } + + let num_nodes = n_bytes + 1; + let mut nodes: Vec = (0..num_nodes) + .map(|i| Node { + prev: if i > 0 { Some(i - 1) } else { None }, + next: if i + 1 < num_nodes { Some(i + 1) } else { None }, + alive: true, + }) + .collect(); + + let mut ver: Vec = vec![0; num_nodes]; - // If you have n parts and m merges, this does O(mn) work. - // We could do something with a heap and do O(m log n) work. - // n is often very small so considerations like cache-locality outweigh the algorithmic - // complexity downsides of the `parts` vector. - while min_rank.0 != Rank::MAX { - let i = min_rank.1; - // Update parts[i] and parts[i - 1] before removing parts[i + 1], since - // `parts.remove(i + 1)` will thrash the cache. - if i > 0 { - parts[i - 1].1 = get_rank(&parts, i - 1); + let mut heap = BinaryHeap::new(); + for i in 0..(num_nodes - 2) { + let rank = compute_rank_at(ranks, piece, &nodes, i); + if rank != Rank::MAX { + heap.push(Cand { rank, left: i, ver: ver[i] }); } - parts[i].1 = get_rank(&parts, i); - parts.remove(i + 1); + } + + while let Some(c) = heap.pop() { + if !nodes[c.left].alive { continue; } + if ver[c.left] != c.ver { continue; } + + let j = match nodes[c.left].next { Some(j) => j, None => continue }; + if !nodes[j].alive { continue; } + let k = match nodes[j].next { Some(k) => k, None => continue }; + if !nodes[k].alive { continue; } + + nodes[c.left].next = Some(k); + nodes[k].prev = Some(c.left); + nodes[j].alive = false; + + ver[c.left] = ver[c.left].wrapping_add(1); - min_rank = (Rank::MAX, usize::MAX); - for (i, &(_, rank)) in parts[..parts.len() - 1].iter().enumerate() { - if rank < min_rank.0 { - min_rank = (rank, i); + if let Some(p) = nodes[c.left].prev { + ver[p] = ver[p].wrapping_add(1); + let prank = compute_rank_at(ranks, piece, &nodes, p); + if prank != Rank::MAX { + heap.push(Cand { rank: prank, left: p, ver: ver[p] }); } } + + let crank = compute_rank_at(ranks, piece, &nodes, c.left); + if crank != Rank::MAX { + heap.push(Cand { rank: crank, left: c.left, ver: ver[c.left] }); + } } + + let mut parts: Vec<(usize, Rank)> = Vec::new(); + let mut cur = 0usize; + loop { + if nodes[cur].alive { + let r = compute_rank_at(ranks, piece, &nodes, cur); + parts.push((cur, r)); + } + match nodes[cur].next { + Some(n) => cur = n, + None => break, + } + } + + if parts.is_empty() || parts.last().unwrap().0 != n_bytes { + parts.push((n_bytes, Rank::MAX)); + } + parts } @@ -571,4 +640,14 @@ mod tests { let res = byte_pair_split(b"abab", &ranks); assert_eq!(res, vec![b"ab", b"ab"]); } + + #[test] + fn test__byte_pair_merge_boundaries() { + let ranks = setup_ranks(); + let piece = b"abcd"; + let parts = super::_byte_pair_merge(&ranks, piece); + let positions: Vec = parts.iter().map(|(i, _)| *i).collect(); + assert_eq!(positions, vec![0, 2, 4]); + assert_eq!(parts.last().unwrap().1, Rank::MAX); + } }