88)
99"""A tuple of semantically meaningful non-whitespace splitters that may be used to chunk texts, ordered from most desirable to least desirable."""
1010
11- def _split_text (text : str ) -> tuple [str , list [str ]]:
11+ def _split_text (text : str ) -> tuple [str , bool , list [str ]]:
1212 """Split text using the most semantically meaningful splitter possible."""
13+
14+ splitter_is_whitespace = True
1315
1416 # Try splitting at, in order of most desirable to least desirable:
1517 # - The largest sequence of newlines and/or carriage returns;
@@ -29,14 +31,15 @@ def _split_text(text: str) -> tuple[str, list[str]]:
2931 # Identify the most desirable semantically meaningful non-whitespace splitter present in the text.
3032 for splitter in NON_WHITESPACE_SEMANTIC_SPLITTERS :
3133 if splitter in text :
34+ splitter_is_whitespace = False
3235 break
3336
3437 # If no semantically meaningful splitter is present in the text, return an empty string as the splitter and the text as a list of characters.
3538 else : # NOTE This code block will only be executed if the for loop completes without breaking.
36- return '' , list (text )
39+ return '' , splitter_is_whitespace , list (text )
3740
3841 # Return the splitter and the split text.
39- return splitter , text .split (splitter )
42+ return splitter , splitter_is_whitespace , text .split (splitter )
4043
4144def chunk (text : str , chunk_size : int , token_counter : callable , _recursion_depth : int = 0 ) -> list [str ]:
4245 """Split text into semantically meaningful chunks of a specified size as determined by the provided token counter.
@@ -50,13 +53,10 @@ def chunk(text: str, chunk_size: int, token_counter: callable, _recursion_depth:
5053 list[str]: A list of chunks up to `chunk_size`-tokens-long, with any whitespace used to split the text removed."""
5154
5255 # Split the text using the most semantically meaningful splitter possible.
53- splitter , splits = _split_text (text )
54-
55- # Flag whether the splitter is whitespace.
56- splitter_is_whitespace = not splitter .split ()
56+ splitter , splitter_is_whitespace , splits = _split_text (text )
5757
5858 chunks = []
59- skips = []
59+ skips = set ()
6060 """A list of indices of splits to skip because they have already been added to a chunk."""
6161
6262 # Iterate through the splits.
@@ -71,26 +71,23 @@ def chunk(text: str, chunk_size: int, token_counter: callable, _recursion_depth:
7171
7272 # If the split is equal to or under the chunk size, merge it with all subsequent splits until the chunk size is reached.
7373 else :
74- # Initialise a list of splits to be merged into a new chunk.
75- new_chunk = [ split ]
74+ # Initalise the new chunk.
75+ new_chunk = split
7676
7777 # Iterate through each subsequent split until the chunk size is reached.
7878 for j , next_split in enumerate (splits [i + 1 :], start = i + 1 ):
7979 # Check whether the next split can be added to the chunk without exceeding the chunk size.
80- if token_counter (splitter . join ( new_chunk + [ next_split ]) ) <= chunk_size :
81- # Add the next split to the chunk.
82- new_chunk . append ( next_split )
80+ if token_counter (updated_chunk := new_chunk + splitter + next_split ) <= chunk_size :
81+ # Add the next split to the new chunk.
82+ new_chunk = updated_chunk
8383
8484 # Add the index of the next split to the list of indices to skip.
85- skips .append (j )
85+ skips .add (j )
8686
8787 # If the next split cannot be added to the chunk without exceeding the chunk size, break.
8888 else :
8989 break
9090
91- # Join the splits with the splitter.
92- new_chunk = splitter .join (new_chunk )
93-
9491 # Add the chunk.
9592 chunks .append (new_chunk )
9693
@@ -104,6 +101,6 @@ def chunk(text: str, chunk_size: int, token_counter: callable, _recursion_depth:
104101
105102 # If this is not a recursive call, remove any empty chunks.
106103 if not _recursion_depth :
107- chunks = [ chunk for chunk in chunks if chunk ]
104+ chunks = list ( filter ( None , chunks ))
108105
109106 return chunks
0 commit comments