
Commit 67b5bd9

remove unused keys dedupe
1 parent: 1f33908

6 files changed: 7 additions, 225 deletions

src/datasets/arrow_writer.py

Lines changed: 3 additions & 47 deletions
@@ -16,7 +16,7 @@
 import json
 import sys
 from collections.abc import Iterable
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
 import fsspec
 import numpy as np
@@ -40,7 +40,6 @@
 )
 from .filesystems import is_remote_filesystem
 from .info import DatasetInfo
-from .keyhash import DuplicatedKeysError, KeyHasher
 from .table import array_cast, cast_array_to_feature, embed_table_storage, table_cast
 from .utils import logging
 from .utils.py_utils import asdict, convert_file_size_to_int, first_non_null_non_empty_value
@@ -414,8 +413,6 @@ def __init__(
         stream: Optional[pa.NativeFile] = None,
         fingerprint: Optional[str] = None,
         writer_batch_size: Optional[int] = None,
-        hash_salt: Optional[str] = None,
-        check_duplicates: Optional[bool] = False,
         disable_nullable: bool = False,
         update_features: bool = False,
         with_metadata: bool = True,
@@ -435,13 +432,6 @@ def __init__(
         self._features = None
         self._schema = None
 
-        if hash_salt is not None:
-            # Create KeyHasher instance using split name as hash salt
-            self._hasher = KeyHasher(hash_salt)
-        else:
-            self._hasher = KeyHasher("")
-
-        self._check_duplicates = check_duplicates
         self._disable_nullable = disable_nullable
 
         if stream is None:
@@ -592,51 +582,21 @@ def write_rows_on_file(self):
     def write(
         self,
         example: dict[str, Any],
-        key: Optional[Union[str, int, bytes]] = None,
         writer_batch_size: Optional[int] = None,
     ):
         """Add a given (Example,Key) pair to the write-pool of examples which is written to file.
 
         Args:
             example: the Example to add.
-            key: Optional, a unique identifier(str, int or bytes) associated with each example
         """
-        # Utilize the keys and duplicate checking when `self._check_duplicates` is passed True
-        if self._check_duplicates:
-            # Create unique hash from key and store as (key, example) pairs
-            hash = self._hasher.hash(key)
-            self.current_examples.append((example, hash))
-            # Maintain record of keys and their respective hashes for checking duplicates
-            self.hkey_record.append((hash, key))
-        else:
-            # Store example as a tuple so as to keep the structure of `self.current_examples` uniform
-            self.current_examples.append((example, ""))
+        # Store example as a tuple so as to keep the structure of `self.current_examples` uniform
+        self.current_examples.append((example, ""))
 
         if writer_batch_size is None:
             writer_batch_size = self.writer_batch_size
         if writer_batch_size is not None and len(self.current_examples) >= writer_batch_size:
-            if self._check_duplicates:
-                self.check_duplicate_keys()
-                # Re-initializing to empty list for next batch
-                self.hkey_record = []
-
             self.write_examples_on_file()
 
-    def check_duplicate_keys(self):
-        """Raises error if duplicates found in a batch"""
-        tmp_record = set()
-        for hash, key in self.hkey_record:
-            if hash in tmp_record:
-                duplicate_key_indices = [
-                    str(self._num_examples + index)
-                    for index, (duplicate_hash, _) in enumerate(self.hkey_record)
-                    if duplicate_hash == hash
-                ]
-
-                raise DuplicatedKeysError(key, duplicate_key_indices)
-            else:
-                tmp_record.add(hash)
-
     def write_row(self, row: pa.Table, writer_batch_size: Optional[int] = None):
         """Add a given single-row Table to the write-pool of rows which is written to file.
 
@@ -721,10 +681,6 @@ def write_table(self, pa_table: pa.Table, writer_batch_size: Optional[int] = Non
     def finalize(self, close_stream=True):
         self.write_rows_on_file()
         # In case current_examples < writer_batch_size, but user uses finalize()
-        if self._check_duplicates:
-            self.check_duplicate_keys()
-            # Re-initializing to empty list for next batch
-            self.hkey_record = []
        self.write_examples_on_file()
         # If schema is known, infer features even if no examples were written
         if self.pa_writer is None and self.schema:
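After this change, `ArrowWriter.write` takes only the example (plus an optional `writer_batch_size`); the `key`, `hash_salt`, and `check_duplicates` arguments are gone. A minimal usage sketch, mirroring the updated test in tests/test_arrow_writer.py further down this diff:

    import pyarrow as pa
    from datasets.arrow_writer import ArrowWriter

    output = pa.BufferOutputStream()
    with ArrowWriter(stream=output, writer_batch_size=2) as writer:
        # No key / hash_salt / check_duplicates arguments any more: each example
        # is appended to the write pool and flushed once writer_batch_size is reached.
        writer.write({"col_1": "foo", "col_2": 1})
        writer.write({"col_1": "bar", "col_2": 2})
        num_examples, num_bytes = writer.finalize()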

src/datasets/builder.py

Lines changed: 1 addition & 18 deletions
@@ -59,7 +59,6 @@
 from .fingerprint import Hasher
 from .info import DatasetInfo, PostProcessedInfo
 from .iterable_dataset import ArrowExamplesIterable, ExamplesIterable, IterableDataset
-from .keyhash import DuplicatedKeysError
 from .naming import INVALID_WINDOWS_CHARACTERS_IN_PATH, camelcase_to_snakecase
 from .splits import Split, SplitDict, SplitGenerator, SplitInfo
 from .streaming import extend_dataset_builder_for_streaming
@@ -979,13 +978,6 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_split_k
                        + "\nOriginal error:\n"
                        + str(e)
                    ) from None
-                # If check_duplicates is set to True , then except DuplicatedKeysError
-                except DuplicatedKeysError as e:
-                    raise DuplicatedKeysError(
-                        e.key,
-                        e.duplicate_key_indices,
-                        fix_msg=f"To avoid duplicate keys, please fix the dataset splits for {self.name}",
-                    ) from None
         dl_manager.manage_extracted_files()
 
         if verification_mode == VerificationMode.BASIC_CHECKS or verification_mode == VerificationMode.ALL_CHECKS:
@@ -1400,7 +1392,6 @@ def _generate_examples(self, **kwargs) -> Iterable[tuple[int, int], dict[str, An
     def _prepare_split(
         self,
         split_generator: SplitGenerator,
-        check_duplicate_keys: bool,
         file_format="arrow",
         num_proc: Optional[int] = None,
         max_shard_size: Optional[Union[int, str]] = None,
@@ -1440,7 +1431,6 @@ def _prepare_split(
             "file_format": file_format,
             "max_shard_size": max_shard_size,
             "split_info": split_info,
-            "check_duplicate_keys": check_duplicate_keys,
         }
 
         if num_proc is None or num_proc == 1:
@@ -1558,7 +1548,6 @@ def _prepare_split_single(
         file_format: str,
         max_shard_size: int,
         split_info: SplitInfo,
-        check_duplicate_keys: bool,
         job_id: int,
     ) -> Iterable[tuple[int, bool, tuple[int, int, Features, int, int, int]]]:
         generator = self._generate_examples(**gen_kwargs)
@@ -1575,8 +1564,6 @@ def _prepare_split_single(
                 features=self.info.features,
                 path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
                 writer_batch_size=self._writer_batch_size,
-                hash_salt=split_info.name,
-                check_duplicates=check_duplicate_keys,
                 storage_options=self._fs.storage_options,
                 embed_local_files=embed_local_files,
             )
@@ -1594,13 +1581,11 @@ def _prepare_split_single(
                        features=writer._features,
                        path=fpath.replace("SSSSS", f"{shard_id:05d}").replace("JJJJJ", f"{job_id:05d}"),
                        writer_batch_size=self._writer_batch_size,
-                        hash_salt=split_info.name,
-                        check_duplicates=check_duplicate_keys,
                        storage_options=self._fs.storage_options,
                        embed_local_files=embed_local_files,
                    )
                example = self.info.features.encode_example(record) if self.info.features is not None else record
-                writer.write(example, (input_shard_idx, example_idx))
+                writer.write(example)
                if len(input_shard_lengths) == input_shard_idx:
                    input_shard_lengths.append(1)
                else:
@@ -1634,8 +1619,6 @@ def _download_and_prepare(self, dl_manager, verification_mode, **prepare_splits_
         super()._download_and_prepare(
             dl_manager,
             verification_mode,
-            check_duplicate_keys=verification_mode == VerificationMode.BASIC_CHECKS
-            or verification_mode == VerificationMode.ALL_CHECKS,
             **prepare_splits_kwargs,
         )
 
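One consequence visible in the `_prepare_split_single` hunk above: `_generate_examples` still yields `(key, example)` pairs, but the key is no longer forwarded to the writer, so duplicate keys are no longer detected during `download_and_prepare`. A hypothetical builder illustrating the unchanged yield contract (the class name and data are made up for illustration):

    import datasets

    class MyDataset(datasets.GeneratorBasedBuilder):
        def _info(self):
            return datasets.DatasetInfo(
                features=datasets.Features({"text": datasets.Value("string")})
            )

        def _split_generators(self, dl_manager):
            return [datasets.SplitGenerator(name=datasets.Split.TRAIN)]

        def _generate_examples(self):
            # Keys remain part of the yield signature, but after this commit
            # ArrowWriter ignores them (no hashing, no duplicate check).
            yield 0, {"text": "first example"}
            yield 0, {"text": "same key again, now accepted silently"}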

src/datasets/keyhash.py

Lines changed: 0 additions & 104 deletions
This file was deleted.
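Judging from the imports removed elsewhere in this commit, the deleted module exposed `KeyHasher`, `DuplicatedKeysError`, and `InvalidKeyError`. A rough sketch of how it was used before this change (not the exact deleted implementation):

    # Pre-commit usage; datasets.keyhash no longer exists after this change.
    from datasets.keyhash import KeyHasher

    hasher = KeyHasher("train")          # the split name served as the hash salt
    key_hash = hasher.hash("example-0")  # hash used by ArrowWriter to detect duplicate keys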

src/datasets/utils/info_utils.py

Lines changed: 1 addition & 2 deletions
@@ -29,8 +29,7 @@ class VerificationMode(enum.Enum):
 
     | | Verification checks |
     |---------------------------|------------------------------------------------------------------------------ |
-    | `ALL_CHECKS` | Split checks, uniqueness of the keys yielded in case of the GeneratorBuilder |
-    | | and the validity (number of files, checksums, etc.) of downloaded files |
+    | `ALL_CHECKS` | Split checks and validity (number of files, checksums) of downloaded files |
     | `BASIC_CHECKS` (default) | Same as `ALL_CHECKS` but without checking downloaded files |
     | `NO_CHECKS` | None |
 
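With the key-uniqueness check removed, the verification modes now differ only in split checks and downloaded-file validation. A small usage sketch (the dataset name below is a placeholder):

    from datasets import load_dataset
    from datasets.utils.info_utils import VerificationMode

    # ALL_CHECKS still verifies split sizes and downloaded files (counts, checksums),
    # but no longer checks uniqueness of the keys yielded by a generator-based builder.
    ds = load_dataset("username/placeholder_dataset", verification_mode=VerificationMode.ALL_CHECKS)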

tests/test_arrow_writer.py

Lines changed: 2 additions & 34 deletions
@@ -14,7 +14,6 @@
 from datasets.arrow_writer import ArrowWriter, OptimizedTypedSequence, ParquetWriter, TypedSequence
 from datasets.features import Array2D, ClassLabel, Features, Image, Value
 from datasets.features.features import Array2DExtensionType, cast_to_python_objects
-from datasets.keyhash import DuplicatedKeysError, InvalidKeyError
 
 from .utils import require_pil
 
@@ -133,46 +132,15 @@ def test_write_with_features():
     assert features == Features.from_arrow_schema(schema)
 
 
-@pytest.mark.parametrize("writer_batch_size", [None, 1, 10])
-def test_key_datatype(writer_batch_size):
-    output = pa.BufferOutputStream()
-    with ArrowWriter(
-        stream=output,
-        writer_batch_size=writer_batch_size,
-        hash_salt="split_name",
-        check_duplicates=True,
-    ) as writer:
-        with pytest.raises(InvalidKeyError):
-            writer.write({"col_1": "foo", "col_2": 1}, key=[1, 2])
-            num_examples, num_bytes = writer.finalize()
-
-
-@pytest.mark.parametrize("writer_batch_size", [None, 2, 10])
-def test_duplicate_keys(writer_batch_size):
-    output = pa.BufferOutputStream()
-    with ArrowWriter(
-        stream=output,
-        writer_batch_size=writer_batch_size,
-        hash_salt="split_name",
-        check_duplicates=True,
-    ) as writer:
-        with pytest.raises(DuplicatedKeysError):
-            writer.write({"col_1": "foo", "col_2": 1}, key=10)
-            writer.write({"col_1": "bar", "col_2": 2}, key=10)
-            num_examples, num_bytes = writer.finalize()
-
-
 @pytest.mark.parametrize("writer_batch_size", [None, 2, 10])
 def test_write_with_keys(writer_batch_size):
     output = pa.BufferOutputStream()
     with ArrowWriter(
         stream=output,
         writer_batch_size=writer_batch_size,
-        hash_salt="split_name",
-        check_duplicates=True,
     ) as writer:
-        writer.write({"col_1": "foo", "col_2": 1}, key=1)
-        writer.write({"col_1": "bar", "col_2": 2}, key=2)
+        writer.write({"col_1": "foo", "col_2": 1})
+        writer.write({"col_1": "bar", "col_2": 2})
         num_examples, num_bytes = writer.finalize()
     assert num_examples == 2
     assert num_bytes > 0

tests/test_builder.py

Lines changed: 0 additions & 20 deletions
@@ -666,26 +666,6 @@ def test_generator_based_download_and_prepare(self):
                os.path.exists(os.path.join(tmp_dir, builder.dataset_name, "default", "0.0.0", "dataset_info.json"))
            )
 
-        # Test that duplicated keys are ignored if verification_mode is "no_checks"
-        with tempfile.TemporaryDirectory() as tmp_dir:
-            builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir)
-            with patch("datasets.builder.ArrowWriter", side_effect=ArrowWriter) as mock_arrow_writer:
-                builder.download_and_prepare(
-                    download_mode=DownloadMode.FORCE_REDOWNLOAD, verification_mode=VerificationMode.NO_CHECKS
-                )
-                mock_arrow_writer.assert_called_once()
-                args, kwargs = mock_arrow_writer.call_args_list[0]
-                self.assertFalse(kwargs["check_duplicates"])
-
-                mock_arrow_writer.reset_mock()
-
-                builder.download_and_prepare(
-                    download_mode=DownloadMode.FORCE_REDOWNLOAD, verification_mode=VerificationMode.BASIC_CHECKS
-                )
-                mock_arrow_writer.assert_called_once()
-                args, kwargs = mock_arrow_writer.call_args_list[0]
-                self.assertTrue(kwargs["check_duplicates"])
-
     def test_cache_dir_no_args(self):
         with tempfile.TemporaryDirectory() as tmp_dir:
             builder = DummyGeneratorBasedBuilder(cache_dir=tmp_dir, data_dir=None, data_files=None)
