@@ -9,6 +9,8 @@
 from itertools import accumulate
 from contextlib import suppress
 
+import mpire
+
 from tqdm import tqdm
 
 if TYPE_CHECKING:
@@ -168,7 +170,11 @@ def chunkerify( |
         memoize (bool, optional): Whether to memoize the token counter. Defaults to `True`.
 
     Returns:
-        Callable[[str | Sequence[str], bool], list[str] | list[list[str]]]: A function that takes either a single text or a sequence of texts and returns, if a single text has been provided, a list of chunks up to `chunk_size`-tokens-long with any whitespace used to split the text removed, or, if multiple texts have been provided, a list of lists of chunks, with each inner list corresponding to the chunks of one of the provided input texts. The function can also be passed a `progress` argument which if set to `True` and multiple texts are passed, will display a progress bar."""
+        Callable[[str | Sequence[str], int, bool], list[str] | list[list[str]]]: A function that takes either a single text or a sequence of texts and returns, if a single text has been provided, a list of chunks up to `chunk_size`-tokens-long with any whitespace used to split the text removed, or, if multiple texts have been provided, a list of lists of chunks, with each inner list corresponding to the chunks of one of the provided input texts.
+
+        The resulting chunker function can also be passed a `processes` argument that specifies the number of processes to be used when chunking multiple texts.
+
+        It is also possible to pass a `progress` argument which, if set to `True` and multiple texts are passed, will display a progress bar."""
 
     # If the provided tokenizer is a string, try to load it with either `tiktoken` or `transformers` or raise an error if neither is available.
     if isinstance(tokenizer_or_token_counter, str):
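For readers skimming the diff, here is a brief usage sketch of the API this hunk documents. The import name `semchunk` and the tokenizer name are illustrative assumptions; the behaviour shown follows the docstrings above.

```python
import semchunk  # assumed import name for the module being patched

# Build a chunker from a tokenizer name (resolved via `tiktoken` or
# `transformers`) and a maximum chunk size in tokens.
chunker = semchunk.chunkerify("gpt-4", chunk_size = 128)

# A single text yields a flat list of chunks.
chunks = chunker("The quick brown fox jumps over the lazy dog.")

# A sequence of texts yields one list of chunks per text; the new
# `processes` argument fans the work out across processes, and
# `progress` displays a progress bar when multiple texts are passed.
chunks_per_text = chunker(["first text", "second text"], processes = 2, progress = True)
```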
@@ -251,24 +257,36 @@ def faster_token_counter(text: str) -> int: |
     if memoize:
         token_counter = _memoized_token_counters.setdefault(token_counter, cache(token_counter))
 
+    # Construct a chunking function that passes the chunk size and token counter to `chunk()`.
+    def chunking_function(text: str) -> list[str]:
+        return chunk(text, chunk_size, token_counter, memoize = False)
+
     # Construct and return the chunker.
-    def chunker(text_or_texts: str | Sequence[str], progress: bool = False) -> list[str] | list[list[str]]:
+    def chunker(
+        text_or_texts: str | Sequence[str],
+        processes: int = 1,
+        progress: bool = False,
+    ) -> list[str] | list[list[str]]:
         """Split text or texts into semantically meaningful chunks of a specified size as determined by the provided tokenizer or token counter.
 
         Args:
             text_or_texts (str | Sequence[str]): The text or texts to be chunked.
+            processes (int, optional): The number of processes to use when chunking multiple texts. Defaults to `1`, in which case chunking will occur in the main process.
             progress (bool, optional): Whether to display a progress bar when chunking multiple texts. Defaults to `False`.
 
         Returns:
             list[str] | list[list[str]]: If a single text has been provided, a list of chunks up to `chunk_size`-tokens-long, with any whitespace used to split the text removed, or, if multiple texts have been provided, a list of lists of chunks, with each inner list corresponding to the chunks of one of the provided input texts."""
 
         if isinstance(text_or_texts, str):
-            return chunk(text_or_texts, chunk_size, token_counter, memoize = False)
+            return chunking_function(text_or_texts)
 
-        if progress:
-            return [chunk(text, chunk_size, token_counter, memoize = False) for text in tqdm(text_or_texts)]
+        if progress and processes == 1:
+            text_or_texts = tqdm(text_or_texts)
 
-        else:
-            return [chunk(text, chunk_size, token_counter, memoize = False) for text in text_or_texts]
+        if processes == 1:
+            return [chunking_function(text) for text in text_or_texts]
+
+        with mpire.WorkerPool(processes, use_dill = True) as pool:
+            return pool.map(chunking_function, text_or_texts, progress_bar = progress)
 
     return chunker
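One design note on the new multiprocessing path: `mpire.WorkerPool` is created with `use_dill = True` because the work function, `chunking_function`, is a closure over `chunk_size` and `token_counter` defined inside `chunkerify`, and the standard-library pickler used for ordinary multiprocessing cannot serialise nested functions, whereas dill can. A minimal sketch of the same pattern, with hypothetical names:

```python
import mpire

def make_adder(offset: int):
    # A nested function closing over `offset`, analogous to
    # `chunking_function` closing over `chunk_size` and `token_counter`.
    def add(x: int) -> int:
        return x + offset
    return add

add_ten = make_adder(10)

# `use_dill = True` lets mpire serialise the closure; the default
# pickler would fail on a function defined inside another function.
with mpire.WorkerPool(2, use_dill = True) as pool:
    results = pool.map(add_ten, [1, 2, 3], progress_bar = False)

print(results)  # [11, 12, 13]
```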