|
14 | 14 | from ..core.loggers import get_logger |
15 | 15 | from ..core.parallel import BaseParallelProcessor, QueueType |
16 | 16 | from ..core.paths import get_size, glob_path, join_path, mkdir_p |
| 17 | +from ..core.utils import TYPES_MAP |
17 | 18 | from .data_types import TokenizerOutput # pylint: disable=unused-import |
18 | 19 | from .memmap_writer import MemmapWriter |
19 | | -from .tokenizer import Tokenizer, tokenize_file |
| 20 | +from .tokenizer import make_tokenizer, tokenize_file |
20 | 21 |
|
21 | 22 | TokenizedSeqsQueueType: TypeAlias = "Queue[List[TokenizerOutput]]" |
22 | 23 | PathsQueueType: TypeAlias = "Queue[str]" |
@@ -89,6 +90,18 @@ def process_single(cls, source_path: str, destination_path: str, queue: QueueTyp |
89 | 90 | # whether to split the special tokens into separate tokens, e.g. <s> -> < s > |
90 | 91 | tokenizer_kwargs["encode_special_tokens"] = kwargs.pop("encode_special_tokens", None) or False |
91 | 92 |
|
| 93 | + # name of the text and id fields in the input files |
| 94 | + tokenizer_kwargs["text_field_name"] = kwargs.pop("text_field_name", None) or "text" |
| 95 | + tokenizer_kwargs["id_field_name"] = kwargs.pop("id_field_name", None) |
| 96 | + |
| 97 | + # type of the text and id fields in the input files |
| 98 | + text_field_type_str = kwargs.pop("text_field_type", None) or "str" |
| 99 | + assert text_field_type_str in TYPES_MAP, f"Invalid text field type: {text_field_type_str}" |
| 100 | + tokenizer_kwargs["text_field_type"] = TYPES_MAP[text_field_type_str] |
| 101 | + id_field_type_str = kwargs.pop("id_field_type", None) or "str" |
| 102 | + assert id_field_type_str in TYPES_MAP, f"Invalid id field type: {id_field_type_str}" |
| 103 | + tokenizer_kwargs["id_field_type"] = TYPES_MAP[id_field_type_str] |
| 104 | + |
92 | 105 | # this is useful for making sure the queue does not grow too much |
93 | 106 | cpu_count = multiprocessing.cpu_count() |
94 | 107 |
|
@@ -305,6 +318,10 @@ def tokenize_in_parallel( |
305 | 318 | sample_ring_prop: bool = False, |
306 | 319 | refresh_tokenizer: int = 0, |
307 | 320 | use_fast_tokenizer: bool = True, |
| 321 | + text_field_name: str = "text", |
| 322 | + text_field_type: str = "str", |
| 323 | + id_field_name: Optional[str] = "id", |
| 324 | + id_field_type: str = "str", |
308 | 325 | ): |
309 | 326 | """ |
310 | 327 | Tokenizes the input sources in parallel using multiple writers and readers. |
@@ -334,18 +351,28 @@ def tokenize_in_parallel( |
334 | 351 | refresh_tokenizer (int, optional): Number of batches after which to refresh the tokenizer. |
335 | 352 | Defaults to 0, which means the tokenizer will not be refreshed. |
336 | 353 | use_fast_tokenizer (bool, optional): Whether to use the fast tokenizer. Defaults to True. |
| 354 | + text_field_name (str, optional): Name of the text field in the input files. Defaults to "text". |
| 355 | + text_field_type (str, optional): Type of the text field in the input files. Defaults to "str". |
| 356 | + id_field_name (str, optional): Name of the id field in the input files. Defaults to "id". Set to None if |
| 357 | + the input files do not have an id field. |
| 358 | + id_field_type (str, optional): Type of the id field in the input files. Defaults to "str". |
337 | 359 | """ |
338 | 360 | # variables to avoid issues with parallelism |
339 | 361 | os.environ["TOKENIZERS_PARALLELISM"] = "false" |
340 | 362 |
|
341 | | - # do it once so it gets cached (unless it's local path, so no need) |
342 | | - if not os.path.exists(tokenizer_name_or_path): |
343 | | - Tokenizer.from_pretrained( |
344 | | - identifier=tokenizer_name_or_path, |
345 | | - bos_token_id=bos_token_id, |
346 | | - eos_token_id=eos_token_id, |
347 | | - pad_token_id=pad_token_id, |
348 | | - use_fast=use_fast_tokenizer, |
| 363 | + # do it once so it gets cached, and we can check if dtype is correct |
| 364 | + |
| 365 | + tokenizer = make_tokenizer( |
| 366 | + tokenizer_name_or_path, |
| 367 | + bos_token_id=bos_token_id, |
| 368 | + eos_token_id=eos_token_id, |
| 369 | + pad_token_id=pad_token_id, |
| 370 | + use_fast=use_fast_tokenizer, |
| 371 | + ) |
| 372 | + if tokenizer.dtype != np.dtype(dtype): |
| 373 | + raise TypeError( |
| 374 | + f"Numpy type mismatch: provided dtype '{dtype}' does not match " |
| 375 | + f"inferred dtype '{tokenizer.dtype}' based on vocab size {tokenizer.vocab_size:,}!" |
349 | 376 | ) |
350 | 377 |
|
351 | 378 | # get a run hash |
@@ -380,4 +407,8 @@ def tokenize_in_parallel( |
380 | 407 | sample_ring_prop=sample_ring_prop, |
381 | 408 | use_fast_tokenizer=use_fast_tokenizer, |
382 | 409 | refresh_tokenizer=refresh_tokenizer, |
| 410 | + text_field_name=text_field_name, |
| 411 | + text_field_type=text_field_type, |
| 412 | + id_field_name=id_field_name, |
| 413 | + id_field_type=id_field_type, |
383 | 414 | ) |
0 commit comments