1 change: 1 addition & 0 deletions setup.py
@@ -167,6 +167,7 @@
"elasticsearch>=7.17.12,<8.0.0", # 8.0 asks users to provide hosts or cloud_id when instantiating ElasticSearch(); 7.9.1 has legacy numpy.float_ which was fixed in https://github.com/elastic/elasticsearch-py/pull/2551.
"faiss-cpu>=1.8.0.post1", # Pins numpy < 2
"h5py",
"pylance",
"jax>=0.3.14; sys_platform != 'win32'",
"jaxlib>=0.3.14; sys_platform != 'win32'",
"lz4; python_version < '3.14'", # python 3.14 gives ImportError: cannot import name '_compression' from partially initialized module 'lz4.frame
3 changes: 3 additions & 0 deletions src/datasets/packaged_modules/__init__.py
@@ -12,6 +12,7 @@
from .hdf5 import hdf5
from .imagefolder import imagefolder
from .json import json
from .lance import lance
from .niftifolder import niftifolder
from .pandas import pandas
from .parquet import parquet
@@ -53,6 +54,7 @@ def _hash_python_lines(lines: list[str]) -> str:
"xml": (xml.__name__, _hash_python_lines(inspect.getsource(xml).splitlines())),
"hdf5": (hdf5.__name__, _hash_python_lines(inspect.getsource(hdf5).splitlines())),
"eval": (eval.__name__, _hash_python_lines(inspect.getsource(eval).splitlines())),
"lance": (lance.__name__, _hash_python_lines(inspect.getsource(lance).splitlines())),
}

# get importable module names and hash for caching
@@ -85,6 +87,7 @@ def _hash_python_lines(lines: list[str]) -> str:
".hdf5": ("hdf5", {}),
".h5": ("hdf5", {}),
".eval": ("eval", {}),
".lance": ("lance", {}),
}
_EXTENSION_TO_MODULE.update({ext: ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
_EXTENSION_TO_MODULE.update({ext.upper(): ("imagefolder", {}) for ext in imagefolder.ImageFolder.EXTENSIONS})
Empty file: src/datasets/packaged_modules/lance/__init__.py
202 changes: 202 additions & 0 deletions src/datasets/packaged_modules/lance/lance.py
@@ -0,0 +1,202 @@
import re
import shutil
import tempfile
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional

import pyarrow as pa

import datasets
from datasets.builder import Key
from datasets.table import table_cast


logger = datasets.utils.logging.get_logger(__name__)


@dataclass
class LanceConfig(datasets.BuilderConfig):
"""
BuilderConfig for Lance format.

Args:
        features (`Features`, *optional*):
            Cast the data to `features`.
        columns (`List[str]`, *optional*):
            List of columns to load; the other columns are ignored.
        batch_size (`int`, *optional*):
            Size of the RecordBatches to iterate on. Defaults to 256.
        token (`str`, *optional*):
            Optional HF token used to download or stream the dataset.
"""

features: Optional[datasets.Features] = None
columns: Optional[List[str]] = None
batch_size: Optional[int] = 256
token: Optional[str] = None

def __post_init__(self):
return super().__post_init__()


class _LanceDataset:
"""Reconstruct a Lance dataset from huggingface snapshot or remote files.

TODO: support sharding
"""

def __init__(
self,
paths: List[datasets.utils.track.tracked_str],
streaming: bool = False,
dataset_uri: Optional[str] = None,
token: Optional[str] = None,
version: Optional[int] = None,
):
self._paths = paths
self._dataset_version = version
self._streaming = streaming
self._dataset_uri = dataset_uri
self._hf_token = token

def __repr__(self):
return "_HfLanceDataset(uri={}, streaming={})".format(self._dataset_uri, self._streaming)

def _open_local_dataset(self) -> "lance.LanceDataset":
import lance

# Reconstruct a temporary dataset directory
        # TemporaryDirectory is removed on cleanup by default; the `delete` keyword requires Python 3.12+
        self.temp_root = tempfile.TemporaryDirectory()
self.dataset_uri = Path(self.temp_root.name)
(self.dataset_uri / "data").mkdir(parents=True, exist_ok=True)
(self.dataset_uri / "_versions").mkdir(parents=True, exist_ok=True)
(self.dataset_uri / "_transactions").mkdir(parents=True, exist_ok=True)

# Reconstruct the dataset
for p in self._paths:
original_path = Path(p.get_origin())
parent_dir = original_path.parent.name
if parent_dir == "data":
(self.dataset_uri / "data" / original_path.name).symlink_to(p)
elif parent_dir == "_transactions":
(self.dataset_uri / "_transactions" / original_path.name).symlink_to(p)
elif parent_dir == "_indices":
(self.dataset_uri / "_indices" / original_path.name).symlink_to(p)
elif parent_dir == "_versions":
shutil.copyfile(p, self.dataset_uri / "_versions" / original_path.name)
return lance.dataset(self.dataset_uri.as_posix(), version=self._dataset_version)

def _open_streaming_dataset(self) -> "lance.LanceDataset":
import lance

storage_opts = {"token": self._hf_token} if self._hf_token else {}
return lance.dataset(self._dataset_uri, version=self._dataset_version, storage_options=storage_opts)

    def get_fragments(self) -> Iterable["LanceFragment"]:
if self._streaming:
ds = self._open_streaming_dataset()
else:
ds = self._open_local_dataset()
# TODO: filter fragments based on the provided data files
for fragment in ds.get_fragments():
yield fragment


def _group_by_dataset(files: Iterable[str]) -> Dict[str, List[str]]:
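    # Group files under the Lance dataset root that contains them; illustrative (hypothetical) paths:
    #   ".../train.lance/data/0.lance" and ".../train.lance/_versions/1.manifest"
    #   both end up under the ".../train.lance" key.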
files_per_dataset = defaultdict(list)
for file_path in files:
path = Path(file_path)
if path.parent.name in {"data", "_transactions", "_indices", "_versions"}:
dataset_root = path.parent.parent
files_per_dataset[str(dataset_root)].append(file_path)
return files_per_dataset


def _normalize_hf_uri(uri: str) -> str:
    # Strip the revision tag from the hf:// URI
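    # Illustrative example (hypothetical URI):
    #   "hf://datasets/user/repo@0123abcd/train.lance" -> "hf://datasets/user/repo/train.lance"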
if "@" in uri:
matched = re.match(r"(hf://.+?)(@[0-9a-f]+)(/.*)", uri)
if matched:
uri = matched.group(1) + matched.group(3)
return uri


class Lance(datasets.ArrowBasedBuilder):
BUILDER_CONFIG_CLASS = LanceConfig

def _info(self):
return datasets.DatasetInfo(features=self.config.features)

def _get_features(self, ds: _LanceDataset) -> datasets.Features:
if self.info.features is None:
pa_schema = ds._open_local_dataset().schema if not ds._streaming else ds._open_streaming_dataset().schema
if self.config.columns:
fields = [
pa_schema.field(name) for name in self.config.columns if pa_schema.get_field_index(name) != -1
]
pa_schema = pa.schema(fields)
return datasets.Features.from_arrow_schema(pa_schema)
return self.info.features

def _split_generators(self, dl_manager):
dl_manager.download_config.extract_on_the_fly = True

splits = []
for split, files in self.config.data_files.items():
dataset_paths = _group_by_dataset(files)

is_streaming = getattr(dl_manager, "is_streaming", False)

datasets_per_split = []
if is_streaming:
# STREAMING MODE
for dataset_root, file_list in dataset_paths.items():
data_files = [f for f in file_list if "/data/" in f]
# TODO: support revision
if "@" in dataset_root:
# temporarily remove the revision from the dataset root
dataset_root = _normalize_hf_uri(dataset_root)

streaming_ds = _LanceDataset(
data_files,
dataset_uri=dataset_root,
streaming=True,
token=self.config.token,
)
if self.info.features is None:
self.info.features = self._get_features(streaming_ds)
datasets_per_split.append(streaming_ds)
else:
# NON-STREAMING MODE: Download files (existing behavior)
all_files_to_download = list(dl_manager.iter_files(list(dataset_paths.keys())))
local_dataset_paths = _group_by_dataset(
dl_manager.iter_files(dl_manager.download(all_files_to_download))
)
for paths in local_dataset_paths.values():
ds = _LanceDataset(paths, streaming=False, token=self.config.token)
datasets_per_split.append(ds)
splits.append(
datasets.SplitGenerator(
name=split,
gen_kwargs={"datasets": datasets_per_split},
)
)
return splits

def _cast_table(self, pa_table: pa.Table) -> pa.Table:
if self.info.features is not None:
# more expensive cast to support nested features with keys in a different order
# allows str <-> int/float or str to Audio for example
pa_table = table_cast(pa_table, self.info.features.arrow_schema)
return pa_table

def _generate_tables(self, datasets: List[_LanceDataset]):
for ds in datasets:
for frag_idx, fragment in enumerate(ds.get_fragments()):
for batch_idx, batch in enumerate(
fragment.to_batches(columns=self.config.columns, batch_size=self.config.batch_size)
):
table = pa.Table.from_batches([batch])
yield Key(frag_idx, batch_idx), self._cast_table(table)
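
For context, a minimal usage sketch of the new builder (not part of the diff): the paths and column names are hypothetical, and the calls mirror the config options above and the tests below.

from datasets import load_dataset

# A directory produced by lance.write_dataset() (hypothetical path); columns and batch_size
# are forwarded to LanceConfig, so only the requested columns are read.
ds = load_dataset("path/to/my_dataset.lance", split="train", columns=["id", "text"], batch_size=512)
print(ds.column_names)

# With streaming=True the same repo-style layout (data/train.lance inside a repo directory,
# as in the test fixture) yields an IterableDataset that reads fragments batch by batch.
streamed = load_dataset("path/to/repo_dir", split="train", streaming=True)
first_row = next(iter(streamed))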
106 changes: 106 additions & 0 deletions tests/packaged_modules/test_lance.py
@@ -0,0 +1,106 @@
import lance
import numpy as np
import pyarrow as pa
import pytest

from datasets import load_dataset


@pytest.fixture
def lance_dataset(tmp_path) -> str:
data = pa.table(
{
"id": pa.array([1, 2, 3, 4]),
"value": pa.array([10.0, 20.0, 30.0, 40.0]),
"text": pa.array(["a", "b", "c", "d"]),
"vector": pa.FixedSizeListArray.from_arrays(pa.array([0.1] * 16, pa.float32()), list_size=4),
}
)
dataset_path = tmp_path / "test_dataset.lance"
lance.write_dataset(data, dataset_path)
return str(dataset_path)


@pytest.fixture
def lance_hf_dataset(tmp_path) -> str:
data = pa.table(
{
"id": pa.array([1, 2, 3, 4]),
"value": pa.array([10.0, 20.0, 30.0, 40.0]),
"text": pa.array(["a", "b", "c", "d"]),
"vector": pa.FixedSizeListArray.from_arrays(pa.array([0.1] * 16, pa.float32()), list_size=4),
}
)
dataset_dir = tmp_path / "data" / "train.lance"
dataset_dir.parent.mkdir(parents=True, exist_ok=True)
lance.write_dataset(data, dataset_dir)
lance.write_dataset(data[:2], tmp_path / "data" / "test.lance")

with open(tmp_path / "README.md", "w") as f:
f.write("""---
size_categories:
- 1M<n<10M
source_datasets:
- lance_test
---
# Test Lance Dataset\n\n
# My Markdown is fancier\n
""")

return str(tmp_path)


def test_load_lance_dataset(lance_dataset):
dataset_dict = load_dataset(lance_dataset)
assert "train" in dataset_dict.keys()

dataset = dataset_dict["train"]
assert "id" in dataset.column_names
assert "value" in dataset.column_names
assert "text" in dataset.column_names
assert "vector" in dataset.column_names
ids = dataset["id"]
assert ids == [1, 2, 3, 4]


@pytest.mark.parametrize("streaming", [False, True])
def test_load_hf_dataset(lance_hf_dataset, streaming):
dataset_dict = load_dataset(lance_hf_dataset, columns=["id", "text"], streaming=streaming)
assert "train" in dataset_dict.keys()
assert "test" in dataset_dict.keys()
dataset = dataset_dict["train"]

assert "id" in dataset.column_names
assert "text" in dataset.column_names
assert "value" not in dataset.column_names
assert "vector" not in dataset.column_names
ids = list(dataset["id"])
assert ids == [1, 2, 3, 4]
text = list(dataset["text"])
assert text == ["a", "b", "c", "d"]
assert "value" not in dataset.column_names


def test_load_vectors(lance_hf_dataset):
dataset_dict = load_dataset(lance_hf_dataset, columns=["vector"])
assert "train" in dataset_dict.keys()
dataset = dataset_dict["train"]

assert "vector" in dataset.column_names
vectors = dataset.data["vector"].combine_chunks().values.to_numpy(zero_copy_only=False)
assert np.allclose(vectors, np.full(16, 0.1))


@pytest.mark.parametrize("streaming", [False, True])
def test_load_lance_streaming_modes(lance_hf_dataset, streaming):
"""Test loading Lance dataset in both streaming and non-streaming modes."""
from datasets import IterableDataset

ds = load_dataset(lance_hf_dataset, split="train", streaming=streaming)
    if streaming:
        assert isinstance(ds, IterableDataset)
    items = list(ds)
assert len(items) == 4
assert all("id" in item for item in items)