 
 """A module containing the 'Tokenizer', 'TextSplitter', 'NoopTextSplitter' and 'TokenTextSplitter' models."""
 
-import json
 import logging
 from abc import ABC, abstractmethod
 from collections.abc import Callable, Collection, Iterable
 from dataclasses import dataclass
-from enum import Enum
 from typing import Any, Literal, cast
 
 import pandas as pd
 import tiktoken
 
 import graphrag.config.defaults as defs
-from graphrag.index.utils.tokens import num_tokens_from_string
+from graphrag.index.operations.chunk_text.typing import TextChunk
+from graphrag.logger.progress import ProgressTicker
 
 EncodedText = list[int]
 DecodeFn = Callable[[EncodedText], str]
@@ -123,10 +122,10 @@ def num_tokens(self, text: str) -> int:
 
     def split_text(self, text: str | list[str]) -> list[str]:
         """Split text method."""
-        if cast("bool", pd.isna(text)) or text == "":
-            return []
         if isinstance(text, list):
             text = " ".join(text)
+        elif cast("bool", pd.isna(text)) or text == "":
+            return []
         if not isinstance(text, str):
             msg = f"Attempting to split a non-string value, actual is {type(text)}"
             raise TypeError(msg)
@@ -138,108 +137,57 @@ def split_text(self, text: str | list[str]) -> list[str]:
             encode=lambda text: self.encode(text),
         )
 
-        return split_text_on_tokens(text=text, tokenizer=tokenizer)
-
-
-class TextListSplitterType(str, Enum):
-    """Enum for the type of the TextListSplitter."""
-
-    DELIMITED_STRING = "delimited_string"
-    JSON = "json"
-
-
-class TextListSplitter(TextSplitter):
-    """Text list splitter class definition."""
-
-    def __init__(
-        self,
-        chunk_size: int,
-        splitter_type: TextListSplitterType = TextListSplitterType.JSON,
-        input_delimiter: str | None = None,
-        output_delimiter: str | None = None,
-        model_name: str | None = None,
-        encoding_name: str | None = None,
-    ):
-        """Initialize the TextListSplitter with a chunk size."""
-        # Set the chunk overlap to 0 as we use full strings
-        super().__init__(chunk_size, chunk_overlap=0)
-        self._type = splitter_type
-        self._input_delimiter = input_delimiter
-        self._output_delimiter = output_delimiter or "\n"
-        self._length_function = lambda x: num_tokens_from_string(
-            x, model=model_name, encoding_name=encoding_name
-        )
-
-    def split_text(self, text: str | list[str]) -> Iterable[str]:
-        """Split a string list into a list of strings for a given chunk size."""
-        if not text:
-            return []
-
-        result: list[str] = []
-        current_chunk: list[str] = []
-
-        # Add the brackets
-        current_length: int = self._length_function("[]")
+        return split_single_text_on_tokens(text=text, tokenizer=tokenizer)
 
-        # Input should be a string list joined by a delimiter
-        string_list = self._load_text_list(text)
 
-        if len(string_list) == 1:
-            return string_list
-
-        for item in string_list:
-            # Count the length of the item and add comma
-            item_length = self._length_function(f"{item},")
+def split_single_text_on_tokens(text: str, tokenizer: Tokenizer) -> list[str]:
+    """Split a single text and return chunks using the tokenizer."""
+    result = []
+    input_ids = tokenizer.encode(text)
 
-            if current_length + item_length > self._chunk_size:
-                if current_chunk and len(current_chunk) > 0:
-                    # Add the current chunk to the result
-                    self._append_to_result(result, current_chunk)
+    start_idx = 0
+    cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+    chunk_ids = input_ids[start_idx:cur_idx]
 
-                # Start a new chunk
-                current_chunk = [item]
-                # Add 2 for the brackets
-                current_length = item_length
-            else:
-                # Add the item to the current chunk
-                current_chunk.append(item)
-                # Add 1 for the comma
-                current_length += item_length
+    while start_idx < len(input_ids):
+        chunk_text = tokenizer.decode(list(chunk_ids))
+        result.append(chunk_text)  # Append chunked text as string
+        start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
+        cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
+        chunk_ids = input_ids[start_idx:cur_idx]
 
-        # Add the last chunk to the result
-        self._append_to_result(result, current_chunk)
+    return result
 
-        return result
 
-    def _load_text_list(self, text: str | list[str]):
-        """Load the text list based on the type."""
-        if isinstance(text, list):
-            string_list = text
-        elif self._type == TextListSplitterType.JSON:
-            string_list = json.loads(text)
-        else:
-            string_list = text.split(self._input_delimiter)
-        return string_list
+# Adapted from - https://github.com/langchain-ai/langchain/blob/77b359edf5df0d37ef0d539f678cf64f5557cb54/libs/langchain/langchain/text_splitter.py#L471
+# So we could have better control over the chunking process
+def split_multiple_texts_on_tokens(
+    texts: list[str], tokenizer: Tokenizer, tick: ProgressTicker
+) -> list[TextChunk]:
+    """Split multiple texts and return chunks with metadata using the tokenizer."""
+    result = []
+    mapped_ids = []
 
-    def _append_to_result(self, chunk_list: list[str], new_chunk: list[str]):
-        """Append the current chunk to the result."""
-        if new_chunk and len(new_chunk) > 0:
-            if self._type == TextListSplitterType.JSON:
-                chunk_list.append(json.dumps(new_chunk, ensure_ascii=False))
-            else:
-                chunk_list.append(self._output_delimiter.join(new_chunk))
+    for source_doc_idx, text in enumerate(texts):
+        encoded = tokenizer.encode(text)
+        if tick:
+            tick(1)  # Track progress if tick callback is provided
+        mapped_ids.append((source_doc_idx, encoded))
 
+    input_ids = [
+        (source_doc_idx, id) for source_doc_idx, ids in mapped_ids for id in ids
+    ]
 
-def split_text_on_tokens(*, text: str, tokenizer: Tokenizer) -> list[str]:
-    """Split incoming text and return chunks using tokenizer."""
-    splits: list[str] = []
-    input_ids = tokenizer.encode(text)
     start_idx = 0
     cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
     chunk_ids = input_ids[start_idx:cur_idx]
+
     while start_idx < len(input_ids):
-        splits.append(tokenizer.decode(chunk_ids))
+        chunk_text = tokenizer.decode([id for _, id in chunk_ids])
+        doc_indices = list({doc_idx for doc_idx, _ in chunk_ids})
+        result.append(TextChunk(chunk_text, doc_indices, len(chunk_ids)))
         start_idx += tokenizer.tokens_per_chunk - tokenizer.chunk_overlap
         cur_idx = min(start_idx + tokenizer.tokens_per_chunk, len(input_ids))
         chunk_ids = input_ids[start_idx:cur_idx]
-    return splits
+
+    return result
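
For reference, a minimal sketch (not part of the change) of how the new single-text helper might be exercised directly. The Tokenizer fields (chunk_overlap, tokens_per_chunk, decode, encode) follow how the dataclass is used in this file; the tiktoken encoding name and chunk sizes below are illustrative assumptions, not values from this diff.

# Illustrative usage sketch; encoding name and sizes are assumed, not from the change.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
tokenizer = Tokenizer(
    chunk_overlap=32,
    tokens_per_chunk=256,
    decode=enc.decode,
    encode=lambda text: enc.encode(text),
)

# Produces overlapping windows of at most 256 tokens, advancing 256 - 32 tokens per step.
chunks = split_single_text_on_tokens(text="some long input document ...", tokenizer=tokenizer)

split_multiple_texts_on_tokens follows the same windowing but pools the token ids of several texts, tags each TextChunk with the source document indices it covers, and reports per-document progress through the tick callback.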