Skip to content

Commit 5c7abc8

Browse files
committed
Ensured split lengths are not recomputed unnecessarily.
1 parent 52d2baf commit 5c7abc8

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

src/semchunk/semchunk.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,13 @@ def bisect_left(a: list, x: int, hi: int) -> int:
7575

7676
return lo
7777

78-
def merge_splits(splits: list[str], chunk_size: int, splitter: str, token_counter: Callable) -> tuple[int, str]:
78+
def merge_splits(splits: list[str], lens: list[int], chunk_size: int, splitter: str, token_counter: Callable) -> tuple[int, str]:
7979
"""Merge splits until a chunk size is reached, returning the index of the last split included in the merged chunk along with the merged chunk itself."""
8080

8181
average = 0.2
8282
low = 0
8383
high = len(splits) + 1
84-
cumulative_lengths = list(accumulate([len(split) for split in splits], initial=0))
84+
cumulative_lengths = list(accumulate(lens, initial=0))
8585
cumulative_lengths.append(cumulative_lengths[-1])
8686

8787
while low < high:
@@ -145,7 +145,8 @@ def chunk(
145145

146146
offsets: list = []
147147
splitter_len = len(splitter)
148-
split_starts = accumulate([0] + [len(split) + splitter_len for split in splits])
148+
split_lens = [len(split) for split in splits]
149+
split_starts = accumulate([0] + [split_len + splitter_len for split_len in split_lens])
149150
split_starts = [start + _start for start in split_starts]
150151

151152
chunks = []
@@ -168,7 +169,7 @@ def chunk(
168169
# If the split is equal to or under the chunk size, add it and any subsequent splits to a new chunk until the chunk size is reached.
169170
else:
170171
# Merge the split with subsequent splits until the chunk size is reached.
171-
final_split_in_chunk_i, new_chunk = merge_splits(splits[i:], local_chunk_size, splitter, token_counter)
172+
final_split_in_chunk_i, new_chunk = merge_splits(splits[i:], split_lens[i:], local_chunk_size, splitter, token_counter)
172173

173174
# Mark any splits included in the new chunk for exclusion from future chunks.
174175
skips.update(range(i + 1, i + final_split_in_chunk_i))

0 commit comments

Comments
 (0)