Commit f66cd3f

Tokenizer over custom fields and w/o IDs; BOS/EOS tokens. (#266)

Authored by soldni and Copilot
* pass type and name
* new tests
* adding tests
* more PRers
* tests
* Refactor tokenizer functions to improve type annotations and enhance tokenization output. Updated `make_spec_from_fields` and `recursively_make_struct` to return `type[msgspec.Struct]`. Modified `tokenize_file` to yield `TokenizerOutput` with dtype parameter.
* Refactor tokenizer initialization to use `make_tokenizer` for improved dtype validation. Added a new test case to check for dtype mismatch errors during tokenization.
* documentation.
* Update tests/python/test_tokenizer.py
* Update CI workflow to use `uv` for environment management and command execution. Refactor type annotations in tokenizer-related files to use `Optional` for nullable fields. Enhance S3 utility functions to improve type safety.
* Add `uv venv` command to CI workflow for environment setup
* Update dependencies in pyproject.toml, enhance CI workflow with UV logging format, and modify record_info.py to handle optional fastwarc import with error handling.
* Enhance CI workflow by adding a step to install the latest version of the toolkit, ensuring up-to-date dependencies are used during the build process.
* removed unnecessary deps + better rust caching. Refactor dependency management in pyproject.toml by removing the unnecessary cached-path entry and updating PII detection comments. Enhance CI workflow to cache Rust targets alongside the virtual environment for improved build efficiency. Update imports in various modules to use the new cached_path location.
* one final thing whatever
* Refactor type annotations in test files for improved clarity. Updated type annotations in `test_tokenizer.py` to specify the type of `extracted_sequences` as `list[list[int]]`. Removed unnecessary type ignore comment in `test_nested_struct.py` for better code readability.
* sorting
* typo
* style
* mypy madness
* Disable test for CodeProseCompositionClassifier until path issue is resolved
* sorting
* Update Python version in CI workflow from 3.9 to 3.10
* Refactor error handling in tokenize_file function to improve logging and maintainability. Moved try-except blocks to streamline error management and added logging for line processing errors.
* Remove unused import of 'exception' from the logging module in tokenizer.py to clean up the code.
* 3.10 doesn't delete=false
* removed older langid

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 492fcfb commit f66cd3f

File tree

18 files changed: +619 -70 lines changed

.github/workflows/CI.yml (1 addition, 1 deletion)

@@ -100,7 +100,7 @@ jobs:
       if: steps.cache-venv.outputs.cache-hit != 'true'
       uses: actions/setup-python@v4
       with:
-        python-version: "3.9"
+        python-version: "3.10"
         architecture: "x64"

     - name: Create a new Python environment & install maturin

docs/tokenize.md (4 additions, 0 deletions)

@@ -44,3 +44,7 @@ The following parameters are supported either via CLI (e.g. `dolma tokens --para
 |`work_dir.output`|No| Path to a local scratch directory where temporary output files can be placed. If not provided, Dolma will make one for you and delete it upon completion. |
 |`dryrun`|No| If true, only print the configuration and exit without running the tokenizer. |
 |`seed`|No| Seed for random number generation. |
+|`fields.text_field_name`|No| Name of the text field in the input files. Can be a nested field (e.g. "text.nested"). Defaults to "text". |
+|`fields.text_field_type`|No| Type of the text field in the input files. Defaults to "str". |
+|`fields.id_field_name`|No| Name of the id field in the input files. Can be a nested field (e.g. "id.nested.more"). Can be set to null to disable the id field. Defaults to "id". |
+|`fields.id_field_type`|No| Type of the id field in the input files. Defaults to "str". |
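The dotted names above select nested keys in each JSON line. A minimal sketch of that lookup (the `get_nested_field` helper is hypothetical, not Dolma's actual implementation):

```python
import json


def get_nested_field(doc: dict, dotted_name: str):
    """Walk a dict following a dotted path such as 'text.nested'."""
    value = doc
    for key in dotted_name.split("."):
        value = value[key]
    return value


line = '{"text": {"nested": "hello world"}, "id": "doc-1"}'
doc = json.loads(line)
print(get_nested_field(doc, "text.nested"))  # hello world
print(get_nested_field(doc, "id"))           # doc-1
```

A path like "text.nested" thus reads `doc["text"]["nested"]`, while a plain name like "id" behaves as before.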

pyproject.toml (15 additions, 10 deletions)

@@ -1,15 +1,16 @@
 [project]
 name = "dolma"
-version = "1.1.2"
-description = "Data filters"
+version = "1.2.0"
+description = "Toolkit for pre-processing LLM training data."
 license = { text = "Apache-2.0" }
 readme = "README.md"
-requires-python = ">=3.9"
+requires-python = ">=3.10,<3.13"
 dependencies = [
     "anyascii>=0.3.2",
     "blingfire==0.1.8",
-    "boto3>=1.28",
-    "cached-path>=1.5.1",
+    # "boto3>=1.28",
+    "boto3",
+    # "cached-path>=1.5.1", # no longer needed
     # "fasttext==0.9.2", # broken with new version of setuptools; using fasttext-wheel instead
     "fasttext-wheel==0.9.2",
     "fsspec>=2023.6.0",
@@ -26,7 +27,7 @@ dependencies = [
     "requests",
     "rich",
     "s3fs==2023.6.0",
-    "smart-open",
+    "smart-open>=7.0.4",
     "tokenizers>=0.15.0,<=0.19.1",
     "tqdm",
     "uniseg",
@@ -118,14 +119,18 @@ dev = [
 # extension to process code
 code = ["detect-secrets==1.4.0", "beautifulsoup4>=4", "pygments", "regex"]
 # extension to detect PIIs using presidio
-pii = ["presidio_analyzer==2.2.32", "regex"]
+pii = [
+    # "presidio_analyzer==2.2.32", # presidio causes too many issues with installation, asking users to install it manually
+    "regex",
+]

 # language detection; by default, we use fastttext, everything else is optional
 lang = [
     "fasttext-wheel==0.9.2",
-    "LTpycld2==0.42", # fork of pycld2 that works on Apple Silicon
+    # "LTpycld2==0.42", # LTpycld2/pycld2 all so buggy; recommending user install them on their own
+    "pycld2",
     "lingua-language-detector>=2.0.0",
-    "langdetect>=1.0.9",
+    # "langdetect>=1.0.9",
 ]

 # extension to parse warc files
@@ -227,7 +232,7 @@ recursive = true
 aggressive = 3

 [tool.mypy]
-python_version = "3.9"
+python_version = "3.10"
 ignore_missing_imports = true
 no_site_packages = true
 allow_redefinition = false

python/dolma/cli/resolvers.py (1 addition, 2 deletions)

@@ -1,11 +1,10 @@
 import multiprocessing
 from typing import List, TypeVar

-from cached_path import cached_path
 from omegaconf.omegaconf import OmegaConf as om
 from omegaconf.omegaconf import Resolver

-from ..core.paths import glob_path
+from ..core.paths import cached_path, glob_path

 __all__ = ["cache", "glob", "processes"]
python/dolma/cli/tokenizer.py (13 additions, 0 deletions)

@@ -94,6 +94,14 @@ def deprecated_init(cls, tokenizer_name_or_path: str) -> "TokenizerConfig":
     )


+@dataclass
+class FieldsConfig:
+    text_field_name: str = field(default="text", help="Name of the text field in the input files.")
+    text_field_type: str = field(default="str", help="Type of the text field in the input files.")
+    id_field_name: Optional[str] = field(default="id", help="Name of the id field in the input files.")
+    id_field_type: str = field(default="str", help="Type of the id field in the input files.")
+
+
 @dataclass
 class TokenizationConfig:
     documents: List[str] = field(
@@ -131,6 +139,7 @@ class TokenizationConfig:
         help="Number of sequences to tokenize before writing to disk.",
     )
     ring_size: int = field(default=8, help="Number of files to open in parallel for tokenization.")
+    fields: FieldsConfig = field(default=FieldsConfig(), help="Configuration for the fields in the input files.")
     sample_ring_prop: bool = field(
         default=False,
         help="Whether to sample the ring proportionally to the number of documents in each source.",
@@ -221,4 +230,8 @@ def run(cls, parsed_config: TokenizationConfig):
         sample_ring_prop=parsed_config.sample_ring_prop,
         use_fast_tokenizer=parsed_config.tokenizer.fast,
         refresh_tokenizer=parsed_config.tokenizer.refresh,
+        text_field_name=parsed_config.fields.text_field_name,
+        text_field_type=parsed_config.fields.text_field_type,
+        id_field_name=parsed_config.fields.id_field_name,
+        id_field_type=parsed_config.fields.id_field_type,
     )
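The semantics of `FieldsConfig` can be sketched with a stdlib-only approximation (Dolma's real implementation decodes via dynamically built `msgspec.Struct` types; the `decode_doc` helper below is hypothetical):

```python
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class FieldsConfig:
    text_field_name: str = "text"
    text_field_type: type = str
    id_field_name: Optional[str] = "id"  # None disables the id field entirely
    id_field_type: type = str


def decode_doc(line: str, cfg: FieldsConfig):
    """Pull the configured text (and optionally id) field out of a JSON line."""
    raw = json.loads(line)
    text = raw[cfg.text_field_name]
    assert isinstance(text, cfg.text_field_type), "text field has unexpected type"
    doc_id = None
    if cfg.id_field_name is not None:
        doc_id = raw[cfg.id_field_name]
        assert isinstance(doc_id, cfg.id_field_type), "id field has unexpected type"
    return text, doc_id


print(decode_doc('{"text": "hi", "id": "1"}', FieldsConfig()))
print(decode_doc('{"text": "hi"}', FieldsConfig(id_field_name=None)))
```

Setting `id_field_name` to `None` is what enables tokenizing files that have no id field at all, one of the headline changes of this commit.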

python/dolma/core/ft_tagger.py (1 addition, 1 deletion)

@@ -11,11 +11,11 @@
 from typing import Iterable, Literal, NamedTuple, Optional

 import smart_open
-from cached_path import cached_path
 from fasttext import train_supervised
 from fasttext.FastText import _FastText

 from .data_types import DocResult, Document, Span, TextSlice
+from .paths import cached_path
 from .taggers import BaseTagger
 from .utils import split_paragraphs, split_sentences
python/dolma/core/utils.py (18 additions, 0 deletions)

@@ -34,6 +34,24 @@
 logger = get_logger(__name__)


+TYPES_MAP = {
+    "object": dict,
+    "dict": dict,
+    "array": list,
+    "list": list,
+    "string": str,
+    "str": str,
+    "number": float,
+    "float": float,
+    "integer": int,
+    "int": int,
+    "boolean": bool,
+    "bool": bool,
+    "null": type(None),
+    "None": type(None),
+}
+
+
 def make_variable_name(name: str, remove_multiple_underscores: bool = False) -> str:
     # use underscores for any non-valid characters in variable name
     name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
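`TYPES_MAP` is what turns the string type names from the CLI/config (e.g. `fields.text_field_type: "str"`) into actual Python types. A self-contained sketch of that lookup, with the map re-declared locally and a hypothetical `resolve_field_type` wrapper standing in for the executor's inline assertions:

```python
# Same mapping the commit adds to python/dolma/core/utils.py: both JSON-style
# names ("string", "integer") and Python-style names ("str", "int") are accepted.
TYPES_MAP = {
    "object": dict, "dict": dict,
    "array": list, "list": list,
    "string": str, "str": str,
    "number": float, "float": float,
    "integer": int, "int": int,
    "boolean": bool, "bool": bool,
    "null": type(None), "None": type(None),
}


def resolve_field_type(name: str) -> type:
    """Map a config string to a Python type, rejecting unknown names."""
    if name not in TYPES_MAP:
        raise ValueError(f"Invalid field type: {name}")
    return TYPES_MAP[name]


print(resolve_field_type("str"))      # <class 'str'>
print(resolve_field_type("integer"))  # <class 'int'>
```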

python/dolma/taggers/language.py (4 additions, 2 deletions)

@@ -24,7 +24,6 @@
 if CLD2_AVAILABLE or TYPE_CHECKING:
     import pycld2 as cld2  # pyright:ignore pylint:disable=import-error

-
 with necessary.necessary("langdetect", soft=True) as LANGDETECT_AVAILABLE:
     if LANGDETECT_AVAILABLE or TYPE_CHECKING:
         from langdetect import PROFILES_DIRECTORY, DetectorFactory, LangDetectException
@@ -98,7 +97,10 @@ class Cld2LanguageTagger(BaseLanguageTagger):
     def __init__(self) -> None:
         super().__init__()
         if not CLD2_AVAILABLE:
-            raise ImportError("pycld2 is not installed, please run `pip install dolma[lang]`.")
+            raise ImportError(
+                "pycld2 is not available, please run `pip install pycld2` "
+                "or `pip install LTpycld2` (whichever works)."
+            )

     def _sanitize_input(self, text: str) -> str:
         return self.RE_BAD_CHARS.sub("", text)

python/dolma/taggers/pii.py (1 addition, 1 deletion)

@@ -68,7 +68,7 @@ def __init__(
         # presidio
         if self.method == self.PRESIDIO:
             if not PRESIDIO_AVAILABLE:
-                raise RuntimeError("Presidio is not available; please run `pip install dolma[pii]`")
+                raise RuntimeError("Presidio is not available; please run `pip install presidio-analyzer`")
             self.analyzer = AnalyzerEngine()

     def predict(self, doc: Document) -> DocResult:

python/dolma/tokenizer/executor.py (40 additions, 9 deletions)

@@ -14,9 +14,10 @@
 from ..core.loggers import get_logger
 from ..core.parallel import BaseParallelProcessor, QueueType
 from ..core.paths import get_size, glob_path, join_path, mkdir_p
+from ..core.utils import TYPES_MAP
 from .data_types import TokenizerOutput  # pylint: disable=unused-import
 from .memmap_writer import MemmapWriter
-from .tokenizer import Tokenizer, tokenize_file
+from .tokenizer import make_tokenizer, tokenize_file

 TokenizedSeqsQueueType: TypeAlias = "Queue[List[TokenizerOutput]]"
 PathsQueueType: TypeAlias = "Queue[str]"
@@ -89,6 +90,18 @@ def process_single(cls, source_path: str, destination_path: str, queue: QueueTyp
         # whether to split the special tokens into separate tokens, e.g. <s> -> < s >
         tokenizer_kwargs["encode_special_tokens"] = kwargs.pop("encode_special_tokens", None) or False

+        # name of the text and id fields in the input files
+        tokenizer_kwargs["text_field_name"] = kwargs.pop("text_field_name", None) or "text"
+        tokenizer_kwargs["id_field_name"] = kwargs.pop("id_field_name", None)
+
+        # type of the text and id fields in the input files
+        text_field_type_str = kwargs.pop("text_field_type", None) or "str"
+        assert text_field_type_str in TYPES_MAP, f"Invalid text field type: {text_field_type_str}"
+        tokenizer_kwargs["text_field_type"] = TYPES_MAP[text_field_type_str]
+        id_field_type_str = kwargs.pop("id_field_type", None) or "str"
+        assert id_field_type_str in TYPES_MAP, f"Invalid id field type: {id_field_type_str}"
+        tokenizer_kwargs["id_field_type"] = TYPES_MAP[id_field_type_str]
+
         # this is useful for making sure the queue does not grows too much
         cpu_count = multiprocessing.cpu_count()

@@ -305,6 +318,10 @@ def tokenize_in_parallel(
     sample_ring_prop: bool = False,
     refresh_tokenizer: int = 0,
     use_fast_tokenizer: bool = True,
+    text_field_name: str = "text",
+    text_field_type: str = "str",
+    id_field_name: Optional[str] = "id",
+    id_field_type: str = "str",
 ):
     """
     Tokenizes the input sources in parallel using multiple writers and readers.
@@ -334,18 +351,28 @@ def tokenize_in_parallel(
         refresh_tokenizer (int, optional): Number of batches after which to refresh the tokenizer.
             Defaults to 0, which means the tokenizer will not be refreshed.
         use_fast_tokenizer (bool, optional): Whether to use the fast tokenizer. Defaults to True.
+        text_field_name (str, optional): Name of the text field in the input files. Defaults to "text".
+        text_field_type (str, optional): Type of the text field in the input files. Defaults to "str".
+        id_field_name (str, optional): Name of the id field in the input files. Defaults to "id". Set to None if
+            the input files do not have an id field.
+        id_field_type (str, optional): Type of the id field in the input files. Defaults to "str".
     """
     # variables to avoid issues with parallelism
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

-    # do it once so it gets cached (unless it's local path, so no need)
-    if not os.path.exists(tokenizer_name_or_path):
-        Tokenizer.from_pretrained(
-            identifier=tokenizer_name_or_path,
-            bos_token_id=bos_token_id,
-            eos_token_id=eos_token_id,
-            pad_token_id=pad_token_id,
-            use_fast=use_fast_tokenizer,
+    # do it once so it gets cached, and we can check if dtype is correct
+
+    tokenizer = make_tokenizer(
+        tokenizer_name_or_path,
+        bos_token_id=bos_token_id,
+        eos_token_id=eos_token_id,
+        pad_token_id=pad_token_id,
+        use_fast=use_fast_tokenizer,
+    )
+    if tokenizer.dtype != np.dtype(dtype):
+        raise TypeError(
+            f"Numpy type mismatch: provided dtype '{dtype}' does not match "
+            f"inferred dtype '{tokenizer.dtype}' based on vocab size {tokenizer.vocab_size:,}!"
         )

     # get a run hash
@@ -380,4 +407,8 @@ def tokenize_in_parallel(
         sample_ring_prop=sample_ring_prop,
         use_fast_tokenizer=use_fast_tokenizer,
         refresh_tokenizer=refresh_tokenizer,
+        text_field_name=text_field_name,
+        text_field_type=text_field_type,
+        id_field_name=id_field_name,
+        id_field_type=id_field_type,
     )
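The new dtype check above compares the user-supplied dtype against the one the tokenizer infers from its vocab size. The inference logic is not shown in this diff; a plausible sketch (helper names `infer_dtype`/`check_dtype` are hypothetical, and the exact rule Dolma uses may differ):

```python
import numpy as np


def infer_dtype(vocab_size: int) -> np.dtype:
    """Smallest unsigned integer dtype that can hold ids in [0, vocab_size)."""
    for dtype in (np.uint8, np.uint16, np.uint32, np.uint64):
        if vocab_size - 1 <= np.iinfo(dtype).max:
            return np.dtype(dtype)
    raise ValueError(f"Vocab size {vocab_size:,} too large for any integer dtype")


def check_dtype(provided: str, vocab_size: int) -> None:
    """Raise, as tokenize_in_parallel now does, on a provided/inferred mismatch."""
    inferred = infer_dtype(vocab_size)
    if np.dtype(provided) != inferred:
        raise TypeError(
            f"Numpy type mismatch: provided dtype '{provided}' does not match "
            f"inferred dtype '{inferred}' based on vocab size {vocab_size:,}!"
        )


check_dtype("uint16", 50_254)  # ok: 50,254 ids fit in uint16
```

Catching this up front avoids silently writing truncated token ids to the memmap output when the vocab is larger than the chosen dtype can represent.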
