@@ -75,13 +75,13 @@ def bisect_left(a: list, x: int, hi: int) -> int:
7575
7676 return lo
7777
78- def merge_splits (splits : list [str ], chunk_size : int , splitter : str , token_counter : Callable ) -> tuple [int , str ]:
78+ def merge_splits (splits : list [str ], lens : list [ int ], chunk_size : int , splitter : str , token_counter : Callable ) -> tuple [int , str ]:
7979 """Merge splits until a chunk size is reached, returning the index of the last split included in the merged chunk along with the merged chunk itself."""
8080
8181 average = 0.2
8282 low = 0
8383 high = len (splits ) + 1
84- cumulative_lengths = list (accumulate ([ len ( split ) for split in splits ] , initial = 0 ))
84+ cumulative_lengths = list (accumulate (lens , initial = 0 ))
8585 cumulative_lengths .append (cumulative_lengths [- 1 ])
8686
8787 while low < high :
@@ -145,7 +145,8 @@ def chunk(
145145
146146 offsets : list = []
147147 splitter_len = len (splitter )
148- split_starts = accumulate ([0 ] + [len (split ) + splitter_len for split in splits ])
148+ split_lens = [len (split ) for split in splits ]
149+ split_starts = accumulate ([0 ] + [split_len + splitter_len for split_len in split_lens ])
149150 split_starts = [start + _start for start in split_starts ]
150151
151152 chunks = []
@@ -168,7 +169,7 @@ def chunk(
168169 # If the split is equal to or under the chunk size, add it and any subsequent splits to a new chunk until the chunk size is reached.
169170 else :
170171 # Merge the split with subsequent splits until the chunk size is reached.
171- final_split_in_chunk_i , new_chunk = merge_splits (splits [i :], local_chunk_size , splitter , token_counter )
172+ final_split_in_chunk_i , new_chunk = merge_splits (splits [i :], split_lens [ i :], local_chunk_size , splitter , token_counter )
172173
173174 # Mark any splits included in the new chunk for exclusion from future chunks.
174175 skips .update (range (i + 1 , i + final_split_in_chunk_i ))
0 commit comments