11use bumpalo:: { collections:: Vec as BumpVec , Bump } ;
2- use std:: collections:: { HashMap , HashSet , BTreeMap } ;
32use itertools:: Itertools ;
43use pyo3:: prelude:: * ;
54use rayon:: iter:: { IntoParallelRefIterator , ParallelIterator } ;
5+ use std:: collections:: { BTreeMap , HashMap , HashSet } ;
66
77/// This is the config for the compression.
88#[ pyclass]
@@ -73,14 +73,10 @@ pub struct Codebook {
7373}
7474
7575impl Codebook {
76-
7776 pub fn new ( config : CodebookConfig ) -> Self {
7877 Self {
7978 base_ids2hyper_id_map : HashMap :: with_capacity ( config. max_codebook_size ) ,
80- merges : vec ! [
81- usize :: MAX ;
82- config. max_codebook_size * config. max_subtokens
83- ] ,
79+ merges : vec ! [ usize :: MAX ; config. max_codebook_size * config. max_subtokens] ,
8480 active_hyper_ids : HashSet :: with_capacity ( config. max_codebook_size ) ,
8581 updates : HashSet :: with_capacity ( config. max_codebook_size ) ,
8682 buffer_ids_to_merge : Vec :: with_capacity ( config. max_subtokens ) ,
@@ -105,7 +101,8 @@ impl Codebook {
105101 } else {
106102 hyper_id
107103 } ;
108- self . base_ids2hyper_id_map . insert ( base_ids. clone ( ) , hyper_id) ;
104+ self . base_ids2hyper_id_map
105+ . insert ( base_ids. clone ( ) , hyper_id) ;
109106
110107 let index = hyper_id - self . config . initial_vocab_size ;
111108 let start_index = index * self . config . max_subtokens ;
@@ -140,9 +137,9 @@ impl Codebook {
140137
141138 let end_index = start_index + entry_length;
142139
143- let target_range = update_index * self . config . max_subtokens ..update_index * self . config . max_subtokens + entry_length ;
144- updates_vec [ target_range . clone ( ) ]
145- . copy_from_slice ( & self . merges [ start_index..end_index] ) ;
140+ let target_range = update_index * self . config . max_subtokens
141+ ..update_index * self . config . max_subtokens + entry_length ;
142+ updates_vec [ target_range . clone ( ) ] . copy_from_slice ( & self . merges [ start_index..end_index] ) ;
146143 updates_indices. push ( index) ;
147144 }
148145
@@ -197,8 +194,7 @@ impl Codebook {
197194 for i in 0 ..size {
198195 let start_index = i * self . config . max_subtokens ;
199196 let end_index = start_index + self . config . max_subtokens ;
200- let mut entry_vec: Vec < usize > =
201- self . merges [ start_index..end_index] . to_vec ( ) ;
197+ let mut entry_vec: Vec < usize > = self . merges [ start_index..end_index] . to_vec ( ) ;
202198
203199 while entry_vec. last ( ) == Some ( & usize:: MAX ) {
204200 entry_vec. pop ( ) ;
@@ -315,7 +311,10 @@ impl LZWCompressor {
315311 if buffer_ids_to_merge. len ( ) > 0 {
316312 get_and_push ( & mut compressed_ids, & codebook, & buffer_ids_to_merge) ;
317313 buffer_ids_to_merge. clear ( ) ;
318- log:: debug!( "force emitting buffer_ids_to_merge because of disabled id: {}" , id) ;
314+ log:: debug!(
315+ "force emitting buffer_ids_to_merge because of disabled id: {}" ,
316+ id
317+ ) ;
319318 }
320319 get_and_push ( & mut compressed_ids, & codebook, & vec ! [ id] ) ;
321320 log:: debug!( "emitting disabled id: {}" , id) ;
@@ -342,7 +341,10 @@ impl LZWCompressor {
342341
343342 // reach the max number of subtokens, emit the buffer without adding new code
344343 if buffer_ids_to_merge. len ( ) == self . config . max_subtokens {
345- log:: debug!( "force emitting buffer_ids_to_merge because of max_subtokens reached: {:?}" , buffer_ids_to_merge) ;
344+ log:: debug!(
345+ "force emitting buffer_ids_to_merge because of max_subtokens reached: {:?}" ,
346+ buffer_ids_to_merge
347+ ) ;
346348 get_and_push ( & mut compressed_ids, & codebook, & buffer_ids_to_merge) ;
347349 buffer_ids_to_merge. clear ( ) ;
348350 }
@@ -522,14 +524,20 @@ impl LZWCompressor {
522524 // so it must not be a new hyper id but an existing one
523525 // we just clear the buffer and continue
524526 if previous_ids. len ( ) == self . config . max_subtokens {
525- assert ! ( codebook. contains_key( & previous_ids) , "previous_ids: {:?} not in codebook: {:?}" , previous_ids, codebook) ;
526- log:: debug!( "force emitting buffer_ids_to_merge because of max_subtokens reached: {:?}" , previous_ids) ;
527+ assert ! (
528+ codebook. contains_key( & previous_ids) ,
529+ "previous_ids: {:?} not in codebook: {:?}" ,
530+ previous_ids,
531+ codebook
532+ ) ;
533+ log:: debug!(
534+ "force emitting buffer_ids_to_merge because of max_subtokens reached: {:?}" ,
535+ previous_ids
536+ ) ;
527537 previous_ids = decoded_ids. clone ( ) ;
528538 continue ;
529539 } else {
530-
531- while decoded_ids. len ( ) > 0
532- {
540+ while decoded_ids. len ( ) > 0 {
533541 previous_ids. push ( decoded_ids[ 0 ] ) ;
534542
535543 if !codebook. contains_key ( & previous_ids) {
@@ -552,7 +560,10 @@ impl LZWCompressor {
552560 }
553561
554562 //print the codebook
555- log:: debug!( "Final codebook built from fuzzy decode: {:?}" , codebook. to_dict( ) ) ;
563+ log:: debug!(
564+ "Final codebook built from fuzzy decode: {:?}" ,
565+ codebook. to_dict( )
566+ ) ;
556567
557568 ( output_ids, codebook)
558569 }
@@ -624,7 +635,7 @@ impl LZWCompressor {
624635 max_codebook_size,
625636 max_subtokens,
626637 pad_token_id,
627- Some ( disabled_ids)
638+ Some ( disabled_ids) ,
628639 ) ,
629640 }
630641 }
@@ -672,10 +683,7 @@ impl LZWCompressor {
672683 /// compressed ids.
673684 ///
674685 /// Returns a list of ids.
675- pub fn decode (
676- & self ,
677- compressed_ids : Vec < usize >
678- ) -> ( Vec < usize > , Codebook ) {
686+ pub fn decode ( & self , compressed_ids : Vec < usize > ) -> ( Vec < usize > , Codebook ) {
679687 self . internal_fuzzy_decode ( & compressed_ids)
680688 }
681689
@@ -745,15 +753,56 @@ impl LZWCompressor {
745753 /// compressed ids.
746754 ///
747755 /// Returns a list of ids.
748- pub fn batch_decode (
749- & self ,
750- compressed_ids : Vec < Vec < usize > >
751- ) -> Vec < ( Vec < usize > , Codebook ) > {
756+ pub fn batch_decode ( & self , compressed_ids : Vec < Vec < usize > > ) -> Vec < ( Vec < usize > , Codebook ) > {
752757 compressed_ids
753- . par_iter ( )
754- // .map(|ids| self.internal_decode(ids))
755- . map ( |ids| self . internal_fuzzy_decode ( ids) )
756- . collect ( )
758+ . par_iter ( )
759+ // .map(|ids| self.internal_decode(ids))
760+ . map ( |ids| self . internal_fuzzy_decode ( ids) )
761+ . collect ( )
762+ }
763+
764+ pub fn continuous_batch_encode (
765+ & self ,
766+ ids : Vec < Vec < usize > > ,
767+ max_length : usize ,
768+ min_length : Option < usize > ,
769+ use_padding : Option < bool > ,
770+ ) -> ( Vec < Vec < usize > > , Vec < Vec < Vec < usize > > > ) {
771+ let min_length = min_length. unwrap_or ( 0 ) ;
772+ let use_padding = use_padding. unwrap_or ( true ) ;
773+ let padding_strategy = if use_padding {
774+ PaddingStrategy :: MaxLength
775+ } else {
776+ PaddingStrategy :: DoNotPad
777+ } ;
778+
779+ let ( compressed_ids, codebooks) : ( Vec < Vec < usize > > , Vec < Codebook > ) = ids
780+ . par_iter ( )
781+ . flat_map ( |ids| {
782+ let mut offset = 0 ;
783+ let mut chunks = Vec :: new ( ) ;
784+ while min_length < ( ids. len ( ) - offset) {
785+ let ( ( c_ids, codebook) , new_offset) = self . internal_encode (
786+ & ids,
787+ offset,
788+ padding_strategy,
789+ true ,
790+ Some ( max_length) ,
791+ ) ;
792+
793+ offset = new_offset;
794+ chunks. push ( ( c_ids, codebook) ) ;
795+ }
796+ chunks
797+ } )
798+ . unzip ( ) ;
799+
800+ let codebooks_as_lists = codebooks
801+ . par_iter ( )
802+ . map ( |codebook| codebook. to_list ( use_padding) )
803+ . collect ( ) ;
804+
805+ ( compressed_ids, codebooks_as_lists)
757806 }
758807}
759808
@@ -846,12 +895,7 @@ impl CodebookManager {
846895 }
847896 }
848897
849-
850- fn internal_fuzzy_update_codebook (
851- py : Python < ' _ > ,
852- state : & mut CodebookState ,
853- ids : & [ usize ] ,
854- ) {
898+ fn internal_fuzzy_update_codebook ( py : Python < ' _ > , state : & mut CodebookState , ids : & [ usize ] ) {
855899 log:: debug!( "Hitting fn internal_fuzzy_update_codebook" ) ;
856900 log:: debug!( "arg ids: {:?}" , ids) ;
857901
@@ -868,7 +912,10 @@ impl CodebookManager {
868912 if config. disabled_ids . contains ( & maybe_hid) {
869913 // no need to update the codebook, because buffer_ids_to_merge is already in the codebook
870914 // and buffer_ids_to_merge + maybe_hid is forbidden
871- log:: debug!( "Got disabled id: {}, clearing buffer_ids_to_merge" , maybe_hid) ;
915+ log:: debug!(
916+ "Got disabled id: {}, clearing buffer_ids_to_merge" ,
917+ maybe_hid
918+ ) ;
872919 state. buffer_ids_to_merge . clear ( ) ;
873920 continue ;
874921 }
@@ -878,7 +925,8 @@ impl CodebookManager {
878925 current_ids = vec ! [ maybe_hid] ;
879926 } else if let Some ( base_ids) = codebook. get_base_ids ( maybe_hid) {
880927 current_ids = base_ids. clone ( ) ;
881- } else { // (2) cSc pattern
928+ } else {
929+ // (2) cSc pattern
882930 log:: debug!( "Unknown id: {}, because of cSc pattern merge" , maybe_hid) ;
883931 current_ids = state. buffer_ids_to_merge . clone ( ) ;
884932 current_ids. push ( state. buffer_ids_to_merge [ 0 ] ) ;
@@ -893,7 +941,6 @@ impl CodebookManager {
893941 state. next_id += 1 ;
894942 state. buffer_ids_to_merge = current_ids. clone ( ) ;
895943 continue ;
896-
897944 }
898945
899946 if state. next_id == config. initial_vocab_size + config. max_codebook_size {
@@ -913,12 +960,16 @@ impl CodebookManager {
913960 state. buffer_ids_to_merge = current_ids. clone ( ) ;
914961 continue ;
915962 } else {
916- while current_ids. len ( ) > 0 {
963+ while current_ids. len ( ) > 0 {
917964 state. buffer_ids_to_merge . push ( current_ids[ 0 ] ) ;
918965
919966 if !codebook. contains_key ( & state. buffer_ids_to_merge ) {
920967 codebook. insert ( state. buffer_ids_to_merge . clone ( ) , state. next_id ) ;
921- log:: debug!( "inserting: {:?} -> {:?}" , state. buffer_ids_to_merge, state. next_id) ;
968+ log:: debug!(
969+ "inserting: {:?} -> {:?}" ,
970+ state. buffer_ids_to_merge,
971+ state. next_id
972+ ) ;
922973 state. next_id += 1 ;
923974 state. buffer_ids_to_merge = current_ids. clone ( ) ;
924975 break ;
@@ -1013,37 +1064,36 @@ impl CodebookManager {
10131064 }
10141065
10151066 self . set_codebooks ( codebooks) ;
1016- } // renormalizing_lzw requires the codebook to be initialized with codebook from the tokenizer
1067+ } // renormalizing_lzw requires the codebook to be initialized with codebook from the tokenizer
10171068
10181069 assert_eq ! ( ids. len( ) , self . states. len( ) ) ;
10191070 let max_ids_length = ids. iter ( ) . map ( |i| i. len ( ) ) . max ( ) . unwrap ( ) ;
10201071
1021-
10221072 let ( mut updates, updates_indices) : ( Vec < Vec < usize > > , Vec < Vec < usize > > ) = self
1023- . states
1024- . iter_mut ( )
1025- . zip ( ids. iter ( ) )
1026- . map ( |( state, ids) | {
1027- // choose the implementation for every call, even the first one
1028- match self . algorithm . as_str ( ) {
1029- "fault_tolerant_lzw" => {
1030- CodebookManager :: internal_fuzzy_update_codebook ( py, state, ids)
1031- }
1032- "renormalizing_lzw" => {
1033- if !self . first_updates {
1034- CodebookManager :: internal_update_codebook ( py, state, ids)
1073+ . states
1074+ . iter_mut ( )
1075+ . zip ( ids. iter ( ) )
1076+ . map ( |( state, ids) | {
1077+ // choose the implementation for every call, even the first one
1078+ match self . algorithm . as_str ( ) {
1079+ "fault_tolerant_lzw" => {
1080+ CodebookManager :: internal_fuzzy_update_codebook ( py, state, ids)
10351081 }
1082+ "renormalizing_lzw" => {
1083+ if !self . first_updates {
1084+ CodebookManager :: internal_update_codebook ( py, state, ids)
1085+ }
1086+ }
1087+ _ => panic ! ( "Invalid algorithm: {}" , self . algorithm) ,
10361088 }
1037- _ => panic ! ( "Invalid algorithm: {}" , self . algorithm) ,
1038- }
10391089
1040- // collect buffered updates from this state's codebook
1041- state
1042- . codebook
1043- . borrow_mut ( py)
1044- . get_updates ( self . first_updates ) // still tell it whether this was the first call
1045- } )
1046- . unzip ( ) ;
1090+ // collect buffered updates from this state's codebook
1091+ state
1092+ . codebook
1093+ . borrow_mut ( py)
1094+ . get_updates ( self . first_updates ) // still tell it whether this was the first call
1095+ } )
1096+ . unzip ( ) ;
10471097
10481098 // If the sequence is only one token (not the first one), we need to
10491099 // pad the updates to the longest sequence in the batch.
@@ -1065,7 +1115,7 @@ impl CodebookManager {
10651115 pub fn get_codebooks ( & self ) -> Vec < Py < Codebook > > {
10661116 self . states
10671117 . iter ( )
1068- . map ( |st| st. codebook . clone ( ) ) // clone the Py<…> handle
1118+ . map ( |st| st. codebook . clone ( ) ) // clone the Py<…> handle
10691119 . collect ( )
10701120 }
10711121
@@ -1087,7 +1137,6 @@ impl CodebookManager {
10871137 }
10881138}
10891139
1090-
10911140#[ pymodule]
10921141fn zip2zip_compression ( _py : Python < ' _ > , m : & PyModule ) -> PyResult < ( ) > {
10931142 env_logger:: init ( ) ;
0 commit comments