@@ -86,46 +86,51 @@ def _split_text(text: str) -> tuple[str, bool, list[str]]:
8686 return splitter , splitter_is_whitespace , text .split (splitter )
8787
8888
89- def bisect_left (a : list , x : int , hi : int ) -> int :
90- lo = 0
91-
92- while lo < hi :
93- mid = (lo + hi ) // 2
89+ def bisect_left (sorted : list , target : int , low : int , high : int ) -> int :
90+ while low < high :
91+ mid = (low + high ) // 2
9492
95- if a [mid ] < x :
96- lo = mid + 1
93+ if sorted [mid ] < target :
94+ low = mid + 1
9795
9896 else :
99- hi = mid
97+ high = mid
10098
101- return lo
99+ return low
102100
103101
104102def merge_splits (
105- splits : list [str ], lens : list [int ], chunk_size : int , splitter : str , token_counter : Callable
103+ splits : list [str ], cum_lens : list [int ], chunk_size : int , splitter : str , token_counter : Callable , start : int
106104) -> tuple [int , str ]:
107105 """Merge splits until a chunk size is reached, returning the index of the last split included in the merged chunk along with the merged chunk itself."""
108106
109107 average = 0.2
110- low = 0
108+ low = start
111109 high = len (splits ) + 1
112- cumulative_lengths = list (accumulate (lens , initial = 0 ))
113- cumulative_lengths .append (cumulative_lengths [- 1 ])
110+
111+ offset = cum_lens [start ]
112+ target = offset + (chunk_size * average )
114113
115114 while low < high :
116- i = bisect_left (cumulative_lengths [low : high + 1 ], chunk_size * average , hi = (high - low ) + 1 )
117- midpoint = min (i + low , high - 1 )
115+ i = bisect_left (cum_lens , target , low = low , high = high )
116+ midpoint = min (i , high - 1 )
117+
118+ tokens = token_counter (splitter .join (splits [start :midpoint ]))
118119
119- tokens = token_counter ( splitter . join ( splits [: midpoint ]))
120+ local_cum = cum_lens [ midpoint ] - offset
120121
121- average = cumulative_lengths [midpoint ] / tokens if cumulative_lengths [midpoint ] and tokens > 0 else average
122+ if local_cum and tokens > 0 :
123+ average = local_cum / tokens
124+ target = offset + (chunk_size * average )
122125
123126 if tokens > chunk_size :
124127 high = midpoint
128+
125129 else :
126130 low = midpoint + 1
127131
128- return low - 1 , splitter .join (splits [: low - 1 ])
132+ end = low - 1
133+ return end , splitter .join (splits [start :end ])
129134
130135
131136def chunk (
@@ -175,6 +180,7 @@ def chunk(
175180 offsets : list = []
176181 splitter_len = len (splitter )
177182 split_lens = [len (split ) for split in splits ]
183+ cum_lens = list (accumulate (split_lens , initial = 0 ))
178184 split_starts = accumulate ([0 ] + [split_len + splitter_len for split_len in split_lens ])
179185 split_starts = [start + _start for start in split_starts ]
180186
@@ -206,17 +212,22 @@ def chunk(
206212 else :
207213 # Merge the split with subsequent splits until the chunk size is reached.
208214 final_split_in_chunk_i , new_chunk = merge_splits (
209- splits [i :], split_lens [i :], local_chunk_size , splitter , token_counter
215+ splits = splits ,
216+ cum_lens = cum_lens ,
217+ chunk_size = local_chunk_size ,
218+ splitter = splitter ,
219+ token_counter = token_counter ,
220+ start = i ,
210221 )
211222
212223 # Mark any splits included in the new chunk for exclusion from future chunks.
213- skips .update (range (i + 1 , i + final_split_in_chunk_i ))
224+ skips .update (range (i + 1 , final_split_in_chunk_i ))
214225
215226 # Add the chunk.
216227 chunks .append (new_chunk )
217228
218229 # Add the chunk's offsets.
219- split_end = split_starts [i + final_split_in_chunk_i ] - splitter_len
230+ split_end = split_starts [final_split_in_chunk_i ] - splitter_len
220231 offsets .append ((split_start , split_end ))
221232
222233 # If the splitter is not whitespace and the split is not the last split, add the splitter to the end of the latest chunk if doing so would not cause it to exceed the chunk size otherwise add the splitter as a new chunk.
0 commit comments