diff --git a/setup.py b/setup.py
index 41c8a53e367..647aad5a067 100644
--- a/setup.py
+++ b/setup.py
@@ -167,6 +167,7 @@
     "elasticsearch>=7.17.12,<8.0.0",  # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch(); 7.9.1 has legacy numpy.float_ which was fixed in https://github.com/elastic/elasticsearch-py/pull/2551.
     "faiss-cpu>=1.8.0.post1",  # Pins numpy < 2
     "h5py",
+    "pylance",
     "jax>=0.3.14; sys_platform != 'win32'",
     "jaxlib>=0.3.14; sys_platform != 'win32'",
     "lz4; python_version < '3.14'",  # python 3.14 gives ImportError: cannot import name '_compression' from partially initialized module 'lz4.frame
diff --git a/src/datasets/load.py b/src/datasets/load.py
index 1218262a856..0824cc56177 100644
--- a/src/datasets/load.py
+++ b/src/datasets/load.py
@@ -66,8 +66,10 @@
 from .iterable_dataset import IterableDataset
 from .naming import camelcase_to_snakecase, snakecase_to_camelcase
 from .packaged_modules import (
+    _ALL_ALLOWED_EXTENSIONS,
     _EXTENSION_TO_MODULE,
     _MODULE_TO_EXTENSIONS,
+    _MODULE_TO_METADATA_EXTENSIONS,
     _MODULE_TO_METADATA_FILE_NAMES,
     _PACKAGED_DATASETS_MODULES,
 )
@@ -91,8 +93,6 @@
 logger = get_logger(__name__)

-ALL_ALLOWED_EXTENSIONS = list(_EXTENSION_TO_MODULE.keys()) + [".zip"]
-


 class _InitializeConfiguredDatasetBuilder:
     """
@@ -328,7 +328,7 @@ def create_builder_configs_from_metadata_configs(
             )
             config_data_files_dict = DataFilesPatternsDict.from_patterns(
                 config_patterns,
-                allowed_extensions=ALL_ALLOWED_EXTENSIONS,
+                allowed_extensions=_ALL_ALLOWED_EXTENSIONS,
             )
         except EmptyDatasetError as e:
             raise EmptyDatasetError(
@@ -436,14 +436,15 @@ def get_module(self) -> DatasetModule:
         data_files = DataFilesDict.from_patterns(
             patterns,
             base_path=base_path,
-            allowed_extensions=ALL_ALLOWED_EXTENSIONS,
+            allowed_extensions=_ALL_ALLOWED_EXTENSIONS,
         )
         module_name, default_builder_kwargs = infer_module_for_data_files(
             data_files=data_files,
             path=self.path,
         )
         data_files = data_files.filter(
-            extensions=_MODULE_TO_EXTENSIONS[module_name], file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name]
+            extensions=_MODULE_TO_EXTENSIONS[module_name] + _MODULE_TO_METADATA_EXTENSIONS[module_name],
+            file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name],
         )
         module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]
         if metadata_configs:
@@ -633,7 +634,7 @@ def get_module(self) -> DatasetModule:
         data_files = DataFilesDict.from_patterns(
             patterns,
             base_path=base_path,
-            allowed_extensions=ALL_ALLOWED_EXTENSIONS,
+            allowed_extensions=_ALL_ALLOWED_EXTENSIONS,
             download_config=self.download_config,
         )
         module_name, default_builder_kwargs = infer_module_for_data_files(
@@ -642,7 +643,8 @@ def get_module(self) -> DatasetModule:
             download_config=self.download_config,
         )
         data_files = data_files.filter(
-            extensions=_MODULE_TO_EXTENSIONS[module_name], file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name]
+            extensions=_MODULE_TO_EXTENSIONS[module_name] + _MODULE_TO_METADATA_EXTENSIONS[module_name],
+            file_names=_MODULE_TO_METADATA_FILE_NAMES[module_name],
         )
         module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]
         if metadata_configs:
diff --git a/src/datasets/packaged_modules/__init__.py b/src/datasets/packaged_modules/__init__.py
index c9a32ff71f0..15500d3cc54 100644
--- a/src/datasets/packaged_modules/__init__.py
+++ b/src/datasets/packaged_modules/__init__.py
@@ -12,6 +12,7 @@
 from .hdf5 import hdf5
 from .imagefolder import imagefolder
 from .json import json
+from .lance import lance
 from .niftifolder import niftifolder
 from .pandas import pandas
 from .parquet import parquet
@@ -53,6 +54,7 @@ def _hash_python_lines(lines: list[str]) -> str:
     "xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
     "hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())),
     "eval": (eval.__name__, _hash_python_lines(inspect.getsource(eval).splitlines())),
+    "lance": (lance.__name__, _hash_python_lines(inspect.getsource(lance).splitlines())),
 }

 # get importable module names and hash for caching
@@ -85,6 +87,7 @@ def _hash_python_lines(lines: list[str]) -> str:
     ".hdf5": ("hdf5", {}),
     ".h5": ("hdf5", {}),
     ".eval": ("eval", {}),
+    ".lance": ("lance", {}),
 }
 _EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
 _EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
@@ -114,3 +117,14 @@ def _hash_python_lines(lines: list[str]) -> str:
 _MODULE_TO_METADATA_FILE_NAMES["videofolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
 _MODULE_TO_METADATA_FILE_NAMES["pdffolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
 _MODULE_TO_METADATA_FILE_NAMES["niftifolder"] = imagefolder.ImageFolder.METADATA_FILENAMES
+
+_MODULE_TO_METADATA_EXTENSIONS: Dict[str, List[str]] = {}
+for _module in _MODULE_TO_EXTENSIONS:
+    _MODULE_TO_METADATA_EXTENSIONS[_module] = []
+_MODULE_TO_METADATA_EXTENSIONS["lance"] = lance.Lance.METADATA_EXTENSIONS
+
+# Total
+
+_ALL_EXTENSIONS = list(_EXTENSION_TO_MODULE.keys()) + [".zip"]
+_ALL_METADATA_EXTENSIONS = list({_ext for _exts in _MODULE_TO_METADATA_EXTENSIONS.values() for _ext in _exts})
+_ALL_ALLOWED_EXTENSIONS = _ALL_EXTENSIONS + _ALL_METADATA_EXTENSIONS
diff --git a/src/datasets/packaged_modules/lance/__init__.py b/src/datasets/packaged_modules/lance/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/src/datasets/packaged_modules/lance/lance.py b/src/datasets/packaged_modules/lance/lance.py
new file mode 100644
index 00000000000..88f5fcc009a
--- /dev/null
+++ b/src/datasets/packaged_modules/lance/lance.py
@@ -0,0 +1,166 @@
+import re
+from dataclasses import dataclass
+from pathlib import Path
+from typing import TYPE_CHECKING, List, Optional
+
+import pyarrow as pa
+from huggingface_hub import HfApi
+
+import datasets
+from datasets.builder import Key
+from datasets.table import table_cast
+from datasets.utils.file_utils import is_local_path
+
+
+if TYPE_CHECKING:
+    import lance
+    import lance.file
+
+logger = datasets.utils.logging.get_logger(__name__)
+
+
+@dataclass
+class LanceConfig(datasets.BuilderConfig):
+    """
+    BuilderConfig for Lance format.
+
+    Args:
+        features: (`Features`, *optional*):
+            Cast the data to `features`.
+        columns: (`List[str]`, *optional*):
+            List of columns to load, the other ones are ignored.
+        batch_size: (`int`, *optional*):
+            Size of the RecordBatches to iterate on. Defaults to 256.
+        token: (`str`, *optional*):
+            Optional HF token to use to download datasets.
+    """
+
+    features: Optional[datasets.Features] = None
+    columns: Optional[List[str]] = None
+    batch_size: Optional[int] = 256
+    token: Optional[str] = None
+
+
+def resolve_dataset_uris(files: List[str]) -> List[str]:
+    dataset_uris = set()
+    for file_path in files:
+        path = Path(file_path)
+        if path.parent.name in {"_transactions", "_indices", "_versions"}:
+            dataset_root = path.parent.parent
+            dataset_uris.add(str(dataset_root))
+    return list(dataset_uris)
+
+
+def _fix_hf_uri(uri: str) -> str:
+    # replace the revision tag from hf uri
+    if "@" in uri:
+        matched = re.match(r"(hf://.+?)(@[0-9a-f]+)(/.*)", uri)
+        if matched:
+            uri = matched.group(1) + matched.group(3)
+    return uri
+
+
+def _fix_local_version_file(uri: str) -> str:
+    # replace symlinks with real files for _version
+    if "/_versions/" in uri and is_local_path(uri):
+        path = Path(uri)
+        if path.is_symlink():
+            data = path.read_bytes()
+            path.unlink()
+            path.write_bytes(data)
+    return uri
+
+
+class Lance(datasets.ArrowBasedBuilder):
+    BUILDER_CONFIG_CLASS = LanceConfig
+    METADATA_EXTENSIONS = [".idx", ".txn", ".manifest"]
+
+    def _info(self):
+        return datasets.DatasetInfo(features=self.config.features)
+
+    def _split_generators(self, dl_manager):
+        import lance
+        import lance.file
+
+        if not self.config.data_files:
+            raise ValueError(f"At least one data file must be specified, but got data_files={self.config.data_files}")
+        if self.repo_id:
+            api = HfApi(**dl_manager.download_config.storage_options["hf"])
+            dataset_sha = api.dataset_info(self.repo_id).sha
+            if dataset_sha != self.hash:
+                raise NotImplementedError(
+                    f"lance doesn't support loading other revisions than 'main' yet, but got {self.hash}"
+                )
+        data_files = dl_manager.download(self.config.data_files)
+
+        # TODO: remove once Lance supports HF links with revisions
+        data_files = {split: [_fix_hf_uri(file) for file in files] for split, files in data_files.items()}
+        # TODO: remove once Lance supports symlinks for _version files
+        data_files = {split: [_fix_local_version_file(file) for file in files] for split, files in data_files.items()}
+
+        splits = []
+        for split_name, files in data_files.items():
+            storage_options = dl_manager.download_config.storage_options.get(files[0].split("://", 1)[0] + "://")
+
+            lance_dataset_uris = resolve_dataset_uris(files)
+            if lance_dataset_uris:
+                fragments = [
+                    frag
+                    for uri in lance_dataset_uris
+                    for frag in lance.dataset(uri, storage_options=storage_options).get_fragments()
+                ]
+                if self.info.features is None:
+                    pa_schema = fragments[0]._ds.schema
+                splits.append(
+                    datasets.SplitGenerator(
+                        name=split_name,
+                        gen_kwargs={"fragments": fragments, "lance_files": None},
+                    )
+                )
+            else:
+                lance_files = [
+                    lance.file.LanceFileReader(file, storage_options=storage_options, columns=self.config.columns)
+                    for file in files
+                ]
+                if self.info.features is None:
+                    pa_schema = lance_files[0].metadata().schema
+                splits.append(
+                    datasets.SplitGenerator(
+                        name=split_name,
+                        gen_kwargs={"fragments": None, "lance_files": lance_files},
+                    )
+                )
+        if self.info.features is None:
+            if self.config.columns:
+                fields = [
+                    pa_schema.field(name) for name in self.config.columns if pa_schema.get_field_index(name) != -1
+                ]
+                pa_schema = pa.schema(fields)
+            self.info.features = datasets.Features.from_arrow_schema(pa_schema)
+
+        return splits
+
+    def _cast_table(self, pa_table: pa.Table) -> pa.Table:
+        if self.info.features is not None:
+            # more expensive cast to support nested features with keys in a different order
+            # allows str <-> int/float or str to Audio for example
+            pa_table = table_cast(pa_table, self.info.features.arrow_schema)
+        return pa_table
+
+    def _generate_tables(
+        self,
+        fragments: Optional[List["lance.LanceFragment"]],
+        lance_files: Optional[List["lance.file.LanceFileReader"]],
+    ):
+        if fragments:
+            for frag_idx, fragment in enumerate(fragments):
+                for batch_idx, batch in enumerate(
+                    fragment.to_batches(columns=self.config.columns, batch_size=self.config.batch_size)
+                ):
+                    table = pa.Table.from_batches([batch])
+                    yield Key(frag_idx, batch_idx), self._cast_table(table)
+        else:
+            for file_idx, lance_file in enumerate(lance_files):
+                for batch_idx, batch in enumerate(lance_file.read_all(batch_size=self.config.batch_size).to_batches()):
+                    table = pa.Table.from_batches([batch])
+                    yield Key(file_idx, batch_idx), self._cast_table(table)
diff --git a/tests/packaged_modules/test_lance.py b/tests/packaged_modules/test_lance.py
new file mode 100644
index 00000000000..823909e0258
--- /dev/null
+++ b/tests/packaged_modules/test_lance.py
@@ -0,0 +1,106 @@
+import lance
+import numpy as np
+import pyarrow as pa
+import pytest
+
+from datasets import load_dataset
+
+
+@pytest.fixture
+def lance_dataset(tmp_path) -> str:
+    data = pa.table(
+        {
+            "id": pa.array([1, 2, 3, 4]),
+            "value": pa.array([10.0, 20.0, 30.0, 40.0]),
+            "text": pa.array(["a", "b", "c", "d"]),
+            "vector": pa.FixedSizeListArray.from_arrays(pa.array([0.1] * 16, pa.float32()), list_size=4),
+        }
+    )
+    dataset_path = tmp_path / "test_dataset.lance"
+    lance.write_dataset(data, dataset_path)
+    return str(dataset_path)
+
+
+@pytest.fixture
+def lance_hf_dataset(tmp_path) -> str:
+    data = pa.table(
+        {
+            "id": pa.array([1, 2, 3, 4]),
+            "value": pa.array([10.0, 20.0, 30.0, 40.0]),
+            "text": pa.array(["a", "b", "c", "d"]),
+            "vector": pa.FixedSizeListArray.from_arrays(pa.array([0.1] * 16, pa.float32()), list_size=4),
+        }
+    )
+    dataset_dir = tmp_path / "data" / "train.lance"
+    dataset_dir.parent.mkdir(parents=True, exist_ok=True)
+    lance.write_dataset(data, dataset_dir)
+    lance.write_dataset(data[:2], tmp_path / "data" / "test.lance")
+
+    with open(tmp_path / "README.md", "w") as f:
+        f.write("""---
+size_categories:
+- 1M
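End-to-end usage sketch (not part of the patch): the snippet below writes a tiny Lance dataset in the same `data/train.lance` layout as the `lance_hf_dataset` fixture and loads it back through `load_dataset`. It assumes the directory-based factory routes `.lance` files to the new `lance` module and that extra builder kwargs such as `columns` are forwarded to `LanceConfig`; the truncated test file may exercise the API differently.

```python
from pathlib import Path

import lance
import pyarrow as pa

from datasets import load_dataset

# Write a small Lance dataset under data/train.lance (layout mirrors the test fixture above).
root = Path("my_lance_repo")
(root / "data").mkdir(parents=True, exist_ok=True)
data = pa.table({"id": pa.array([1, 2, 3, 4]), "text": pa.array(["a", "b", "c", "d"])})
lance.write_dataset(data, root / "data" / "train.lance")

# The ".lance" extension maps to the "lance" packaged module via _EXTENSION_TO_MODULE,
# so plain directory loading should pick the new builder; passing `columns` here
# assumes load_dataset forwards it to LanceConfig.columns as a builder kwarg.
ds = load_dataset(str(root), columns=["id"], split="train")
print(ds.features)
print(ds[0])
```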