Skip to content

Commit 55910cb

Browse files
committed
Minor performance improvements.
1 parent cef460d commit 55910cb

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

src/semchunk/semchunk.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,10 @@
88
)
99
"""A tuple of semantically meaningful non-whitespace splitters that may be used to chunk texts, ordered from most desirable to least desirable."""
1010

11-
def _split_text(text: str) -> tuple[str, list[str]]:
11+
def _split_text(text: str) -> tuple[str, bool, list[str]]:
1212
"""Split text using the most semantically meaningful splitter possible."""
13+
14+
splitter_is_whitespace = True
1315

1416
# Try splitting at, in order of most desirable to least desirable:
1517
# - The largest sequence of newlines and/or carriage returns;
@@ -29,14 +31,15 @@ def _split_text(text: str) -> tuple[str, list[str]]:
2931
# Identify the most desirable semantically meaningful non-whitespace splitter present in the text.
3032
for splitter in NON_WHITESPACE_SEMANTIC_SPLITTERS:
3133
if splitter in text:
34+
splitter_is_whitespace = False
3235
break
3336

3437
# If no semantically meaningful splitter is present in the text, return an empty string as the splitter and the text as a list of characters.
3538
else: # NOTE This code block will only be executed if the for loop completes without breaking.
36-
return '', list(text)
39+
return '', splitter_is_whitespace, list(text)
3740

3841
# Return the splitter and the split text.
39-
return splitter, text.split(splitter)
42+
return splitter, splitter_is_whitespace, text.split(splitter)
4043

4144
def chunk(text: str, chunk_size: int, token_counter: callable, _recursion_depth: int = 0) -> list[str]:
4245
"""Split text into semantically meaningful chunks of a specified size as determined by the provided token counter.
@@ -50,13 +53,10 @@ def chunk(text: str, chunk_size: int, token_counter: callable, _recursion_depth:
5053
list[str]: A list of chunks up to `chunk_size`-tokens-long, with any whitespace used to split the text removed."""
5154

5255
# Split the text using the most semantically meaningful splitter possible.
53-
splitter, splits = _split_text(text)
54-
55-
# Flag whether the splitter is whitespace.
56-
splitter_is_whitespace = not splitter.split()
56+
splitter, splitter_is_whitespace, splits = _split_text(text)
5757

5858
chunks = []
59-
skips = []
59+
skips = set()
6060
"""A list of indices of splits to skip because they have already been added to a chunk."""
6161

6262
# Iterate through the splits.
@@ -71,26 +71,23 @@ def chunk(text: str, chunk_size: int, token_counter: callable, _recursion_depth:
7171

7272
# If the split is equal to or under the chunk size, merge it with all subsequent splits until the chunk size is reached.
7373
else:
74-
# Initialise a list of splits to be merged into a new chunk.
75-
new_chunk = [split]
74+
# Initalise the new chunk.
75+
new_chunk = split
7676

7777
# Iterate through each subsequent split until the chunk size is reached.
7878
for j, next_split in enumerate(splits[i+1:], start=i+1):
7979
# Check whether the next split can be added to the chunk without exceeding the chunk size.
80-
if token_counter(splitter.join(new_chunk+[next_split])) <= chunk_size:
81-
# Add the next split to the chunk.
82-
new_chunk.append(next_split)
80+
if token_counter(updated_chunk:=new_chunk+splitter+next_split) <= chunk_size:
81+
# Add the next split to the new chunk.
82+
new_chunk = updated_chunk
8383

8484
# Add the index of the next split to the list of indices to skip.
85-
skips.append(j)
85+
skips.add(j)
8686

8787
# If the next split cannot be added to the chunk without exceeding the chunk size, break.
8888
else:
8989
break
9090

91-
# Join the splits with the splitter.
92-
new_chunk = splitter.join(new_chunk)
93-
9491
# Add the chunk.
9592
chunks.append(new_chunk)
9693

@@ -104,6 +101,6 @@ def chunk(text: str, chunk_size: int, token_counter: callable, _recursion_depth:
104101

105102
# If this is not a recursive call, remove any empty chunks.
106103
if not _recursion_depth:
107-
chunks = [chunk for chunk in chunks if chunk]
104+
chunks = list(filter(None, chunks))
108105

109106
return chunks

0 commit comments

Comments
 (0)