Skip to content

Commit e4b97cd

Browse files
nathanrchn and Saibo-creator
authored and committed
add the continuous_batch_encode method and a config file to build the library on mac
1 parent 8f1da31 commit e4b97cd

File tree

3 files changed

+153
-69
lines changed

3 files changed

+153
-69
lines changed

.cargo/config.toml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
[target.x86_64-apple-darwin]
2+
rustflags = [
3+
"-C", "link-arg=-undefined",
4+
"-C", "link-arg=dynamic_lookup",
5+
]
6+
7+
[target.aarch64-apple-darwin]
8+
rustflags = [
9+
"-C", "link-arg=-undefined",
10+
"-C", "link-arg=dynamic_lookup",
11+
]

src/lib.rs

Lines changed: 118 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use bumpalo::{collections::Vec as BumpVec, Bump};
2-
use std::collections::{HashMap, HashSet, BTreeMap};
32
use itertools::Itertools;
43
use pyo3::prelude::*;
54
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
5+
use std::collections::{BTreeMap, HashMap, HashSet};
66

77
/// This is the config for the compression.
88
#[pyclass]
@@ -73,14 +73,10 @@ pub struct Codebook {
7373
}
7474

7575
impl Codebook {
76-
7776
pub fn new(config: CodebookConfig) -> Self {
7877
Self {
7978
base_ids2hyper_id_map: HashMap::with_capacity(config.max_codebook_size),
80-
merges: vec![
81-
usize::MAX;
82-
config.max_codebook_size * config.max_subtokens
83-
],
79+
merges: vec![usize::MAX; config.max_codebook_size * config.max_subtokens],
8480
active_hyper_ids: HashSet::with_capacity(config.max_codebook_size),
8581
updates: HashSet::with_capacity(config.max_codebook_size),
8682
buffer_ids_to_merge: Vec::with_capacity(config.max_subtokens),
@@ -105,7 +101,8 @@ impl Codebook {
105101
} else {
106102
hyper_id
107103
};
108-
self.base_ids2hyper_id_map.insert(base_ids.clone(), hyper_id);
104+
self.base_ids2hyper_id_map
105+
.insert(base_ids.clone(), hyper_id);
109106

110107
let index = hyper_id - self.config.initial_vocab_size;
111108
let start_index = index * self.config.max_subtokens;
@@ -140,9 +137,9 @@ impl Codebook {
140137

141138
let end_index = start_index + entry_length;
142139

143-
let target_range = update_index * self.config.max_subtokens..update_index * self.config.max_subtokens + entry_length;
144-
updates_vec[target_range.clone()]
145-
.copy_from_slice(&self.merges[start_index..end_index]);
140+
let target_range = update_index * self.config.max_subtokens
141+
..update_index * self.config.max_subtokens + entry_length;
142+
updates_vec[target_range.clone()].copy_from_slice(&self.merges[start_index..end_index]);
146143
updates_indices.push(index);
147144
}
148145

@@ -197,8 +194,7 @@ impl Codebook {
197194
for i in 0..size {
198195
let start_index = i * self.config.max_subtokens;
199196
let end_index = start_index + self.config.max_subtokens;
200-
let mut entry_vec: Vec<usize> =
201-
self.merges[start_index..end_index].to_vec();
197+
let mut entry_vec: Vec<usize> = self.merges[start_index..end_index].to_vec();
202198

203199
while entry_vec.last() == Some(&usize::MAX) {
204200
entry_vec.pop();
@@ -315,7 +311,10 @@ impl LZWCompressor {
315311
if buffer_ids_to_merge.len() > 0 {
316312
get_and_push(&mut compressed_ids, &codebook, &buffer_ids_to_merge);
317313
buffer_ids_to_merge.clear();
318-
log::debug!("force emitting buffer_ids_to_merge because of disabled id: {}", id);
314+
log::debug!(
315+
"force emitting buffer_ids_to_merge because of disabled id: {}",
316+
id
317+
);
319318
}
320319
get_and_push(&mut compressed_ids, &codebook, &vec![id]);
321320
log::debug!("emitting disabled id: {}", id);
@@ -342,7 +341,10 @@ impl LZWCompressor {
342341

343342
// reach the max number of subtokens, emit the buffer without adding new code
344343
if buffer_ids_to_merge.len() == self.config.max_subtokens {
345-
log::debug!("force emitting buffer_ids_to_merge because of max_subtokens reached: {:?}", buffer_ids_to_merge);
344+
log::debug!(
345+
"force emitting buffer_ids_to_merge because of max_subtokens reached: {:?}",
346+
buffer_ids_to_merge
347+
);
346348
get_and_push(&mut compressed_ids, &codebook, &buffer_ids_to_merge);
347349
buffer_ids_to_merge.clear();
348350
}
@@ -522,14 +524,20 @@ impl LZWCompressor {
522524
// so it must not be a new hyper id but an existing one
523525
// we just clear the buffer and continue
524526
if previous_ids.len() == self.config.max_subtokens {
525-
assert!(codebook.contains_key(&previous_ids), "previous_ids: {:?} not in codebook: {:?}", previous_ids, codebook);
526-
log::debug!("force emitting buffer_ids_to_merge because of max_subtokens reached: {:?}", previous_ids);
527+
assert!(
528+
codebook.contains_key(&previous_ids),
529+
"previous_ids: {:?} not in codebook: {:?}",
530+
previous_ids,
531+
codebook
532+
);
533+
log::debug!(
534+
"force emitting buffer_ids_to_merge because of max_subtokens reached: {:?}",
535+
previous_ids
536+
);
527537
previous_ids = decoded_ids.clone();
528538
continue;
529539
} else {
530-
531-
while decoded_ids.len() > 0
532-
{
540+
while decoded_ids.len() > 0 {
533541
previous_ids.push(decoded_ids[0]);
534542

535543
if !codebook.contains_key(&previous_ids) {
@@ -552,7 +560,10 @@ impl LZWCompressor {
552560
}
553561

554562
//print the codebook
555-
log::debug!("Final codebook built from fuzzy decode: {:?}", codebook.to_dict());
563+
log::debug!(
564+
"Final codebook built from fuzzy decode: {:?}",
565+
codebook.to_dict()
566+
);
556567

557568
(output_ids, codebook)
558569
}
@@ -624,7 +635,7 @@ impl LZWCompressor {
624635
max_codebook_size,
625636
max_subtokens,
626637
pad_token_id,
627-
Some(disabled_ids)
638+
Some(disabled_ids),
628639
),
629640
}
630641
}
@@ -672,10 +683,7 @@ impl LZWCompressor {
672683
/// compressed ids.
673684
///
674685
/// Returns a list of ids.
675-
pub fn decode(
676-
&self,
677-
compressed_ids: Vec<usize>
678-
) -> (Vec<usize>, Codebook) {
686+
pub fn decode(&self, compressed_ids: Vec<usize>) -> (Vec<usize>, Codebook) {
679687
self.internal_fuzzy_decode(&compressed_ids)
680688
}
681689

@@ -745,15 +753,56 @@ impl LZWCompressor {
745753
/// compressed ids.
746754
///
747755
/// Returns a list of ids.
748-
pub fn batch_decode(
749-
&self,
750-
compressed_ids: Vec<Vec<usize>>
751-
) -> Vec<(Vec<usize>, Codebook)> {
756+
pub fn batch_decode(&self, compressed_ids: Vec<Vec<usize>>) -> Vec<(Vec<usize>, Codebook)> {
752757
compressed_ids
753-
.par_iter()
754-
// .map(|ids| self.internal_decode(ids))
755-
.map(|ids| self.internal_fuzzy_decode(ids))
756-
.collect()
758+
.par_iter()
759+
// .map(|ids| self.internal_decode(ids))
760+
.map(|ids| self.internal_fuzzy_decode(ids))
761+
.collect()
762+
}
763+
764+
pub fn continuous_batch_encode(
765+
&self,
766+
ids: Vec<Vec<usize>>,
767+
max_length: usize,
768+
min_length: Option<usize>,
769+
use_padding: Option<bool>,
770+
) -> (Vec<Vec<usize>>, Vec<Vec<Vec<usize>>>) {
771+
let min_length = min_length.unwrap_or(0);
772+
let use_padding = use_padding.unwrap_or(true);
773+
let padding_strategy = if use_padding {
774+
PaddingStrategy::MaxLength
775+
} else {
776+
PaddingStrategy::DoNotPad
777+
};
778+
779+
let (compressed_ids, codebooks): (Vec<Vec<usize>>, Vec<Codebook>) = ids
780+
.par_iter()
781+
.flat_map(|ids| {
782+
let mut offset = 0;
783+
let mut chunks = Vec::new();
784+
while min_length < (ids.len() - offset) {
785+
let ((c_ids, codebook), new_offset) = self.internal_encode(
786+
&ids,
787+
offset,
788+
padding_strategy,
789+
true,
790+
Some(max_length),
791+
);
792+
793+
offset = new_offset;
794+
chunks.push((c_ids, codebook));
795+
}
796+
chunks
797+
})
798+
.unzip();
799+
800+
let codebooks_as_lists = codebooks
801+
.par_iter()
802+
.map(|codebook| codebook.to_list(use_padding))
803+
.collect();
804+
805+
(compressed_ids, codebooks_as_lists)
757806
}
758807
}
759808

@@ -846,12 +895,7 @@ impl CodebookManager {
846895
}
847896
}
848897

849-
850-
fn internal_fuzzy_update_codebook(
851-
py: Python<'_>,
852-
state: &mut CodebookState,
853-
ids: &[usize],
854-
) {
898+
fn internal_fuzzy_update_codebook(py: Python<'_>, state: &mut CodebookState, ids: &[usize]) {
855899
log::debug!("Hitting fn internal_fuzzy_update_codebook");
856900
log::debug!("arg ids: {:?}", ids);
857901

@@ -868,7 +912,10 @@ impl CodebookManager {
868912
if config.disabled_ids.contains(&maybe_hid) {
869913
// no need to update the codebook, because buffer_ids_to_merge is already in the codebook
870914
// and buffer_ids_to_merge + maybe_hid is forbidden
871-
log::debug!("Got disabled id: {}, clearing buffer_ids_to_merge", maybe_hid);
915+
log::debug!(
916+
"Got disabled id: {}, clearing buffer_ids_to_merge",
917+
maybe_hid
918+
);
872919
state.buffer_ids_to_merge.clear();
873920
continue;
874921
}
@@ -878,7 +925,8 @@ impl CodebookManager {
878925
current_ids = vec![maybe_hid];
879926
} else if let Some(base_ids) = codebook.get_base_ids(maybe_hid) {
880927
current_ids = base_ids.clone();
881-
} else { // (2) cSc pattern
928+
} else {
929+
// (2) cSc pattern
882930
log::debug!("Unknown id: {}, because of cSc pattern merge", maybe_hid);
883931
current_ids = state.buffer_ids_to_merge.clone();
884932
current_ids.push(state.buffer_ids_to_merge[0]);
@@ -893,7 +941,6 @@ impl CodebookManager {
893941
state.next_id += 1;
894942
state.buffer_ids_to_merge = current_ids.clone();
895943
continue;
896-
897944
}
898945

899946
if state.next_id == config.initial_vocab_size + config.max_codebook_size {
@@ -913,12 +960,16 @@ impl CodebookManager {
913960
state.buffer_ids_to_merge = current_ids.clone();
914961
continue;
915962
} else {
916-
while current_ids.len() > 0 {
963+
while current_ids.len() > 0 {
917964
state.buffer_ids_to_merge.push(current_ids[0]);
918965

919966
if !codebook.contains_key(&state.buffer_ids_to_merge) {
920967
codebook.insert(state.buffer_ids_to_merge.clone(), state.next_id);
921-
log::debug!("inserting: {:?} -> {:?}", state.buffer_ids_to_merge, state.next_id);
968+
log::debug!(
969+
"inserting: {:?} -> {:?}",
970+
state.buffer_ids_to_merge,
971+
state.next_id
972+
);
922973
state.next_id += 1;
923974
state.buffer_ids_to_merge = current_ids.clone();
924975
break;
@@ -1013,37 +1064,36 @@ impl CodebookManager {
10131064
}
10141065

10151066
self.set_codebooks(codebooks);
1016-
}// renormalizing_lzw requires the codebook to be initialized with codebook from the tokenizer
1067+
} // renormalizing_lzw requires the codebook to be initialized with codebook from the tokenizer
10171068

10181069
assert_eq!(ids.len(), self.states.len());
10191070
let max_ids_length = ids.iter().map(|i| i.len()).max().unwrap();
10201071

1021-
10221072
let (mut updates, updates_indices): (Vec<Vec<usize>>, Vec<Vec<usize>>) = self
1023-
.states
1024-
.iter_mut()
1025-
.zip(ids.iter())
1026-
.map(|(state, ids)| {
1027-
// choose the implementation for every call, even the first one
1028-
match self.algorithm.as_str() {
1029-
"fault_tolerant_lzw" => {
1030-
CodebookManager::internal_fuzzy_update_codebook(py, state, ids)
1031-
}
1032-
"renormalizing_lzw" => {
1033-
if !self.first_updates {
1034-
CodebookManager::internal_update_codebook(py, state, ids)
1073+
.states
1074+
.iter_mut()
1075+
.zip(ids.iter())
1076+
.map(|(state, ids)| {
1077+
// choose the implementation for every call, even the first one
1078+
match self.algorithm.as_str() {
1079+
"fault_tolerant_lzw" => {
1080+
CodebookManager::internal_fuzzy_update_codebook(py, state, ids)
10351081
}
1082+
"renormalizing_lzw" => {
1083+
if !self.first_updates {
1084+
CodebookManager::internal_update_codebook(py, state, ids)
1085+
}
1086+
}
1087+
_ => panic!("Invalid algorithm: {}", self.algorithm),
10361088
}
1037-
_ => panic!("Invalid algorithm: {}", self.algorithm),
1038-
}
10391089

1040-
// collect buffered updates from this state's codebook
1041-
state
1042-
.codebook
1043-
.borrow_mut(py)
1044-
.get_updates(self.first_updates) // still tell it whether this was the first call
1045-
})
1046-
.unzip();
1090+
// collect buffered updates from this state's codebook
1091+
state
1092+
.codebook
1093+
.borrow_mut(py)
1094+
.get_updates(self.first_updates) // still tell it whether this was the first call
1095+
})
1096+
.unzip();
10471097

10481098
// If the sequence is only one token (not the first one), we need to
10491099
// pad the updates to the longest sequence in the batch.
@@ -1065,7 +1115,7 @@ impl CodebookManager {
10651115
pub fn get_codebooks(&self) -> Vec<Py<Codebook>> {
10661116
self.states
10671117
.iter()
1068-
.map(|st| st.codebook.clone()) // clone the Py<…> handle
1118+
.map(|st| st.codebook.clone()) // clone the Py<…> handle
10691119
.collect()
10701120
}
10711121

@@ -1087,7 +1137,6 @@ impl CodebookManager {
10871137
}
10881138
}
10891139

1090-
10911140
#[pymodule]
10921141
fn zip2zip_compression(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
10931142
env_logger::init();

zip2zip_compression.pyi

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,30 @@ class LZWCompressor:
130130
"""
131131
...
132132

133+
def continuous_batch_encode(
134+
self,
135+
ids: List[List[int]],
136+
max_length: int,
137+
min_length: Optional[int] = 0,
138+
use_padding: Optional[bool] = True,
139+
) -> Tuple[List[List[int]], List[List[List[int]]]]:
140+
"""
141+
Encode a batch of sequences of tokens in a continuous manner. This method will try to consume
142+
the sequences as much as possible, and return the compressed ids and the codebooks.
143+
144+
Args:
145+
ids: The batch of sequences of tokens to encode.
146+
max_length: The maximum length of the sequences.
147+
min_length: The minimum length of the sequences. This can be used to limit the padding
148+
and discard the sequences that are too short. Default is 0.
149+
use_padding: If the compressed ids are padded to the `max_length` and the codebooks are
150+
padded to the maximum size (max subtokens, and max entries).
151+
152+
Returns:
153+
A tuple containing the compressed ids and the codebooks.
154+
"""
155+
...
156+
133157
class CodebookManager:
134158
"""
135159
Class for managing the codebooks. During generation, the codebook manager will update the codebooks

0 commit comments

Comments
 (0)