Skip to content

Commit 6e4a409

Browse files
Feat: Update indexing of parquet dataset and also add streaming support to huggingface datasets (#505)
* moved to constants * moved constants * update hf downloader * updated writer if file obj * updated num workers * add existence check for chunk file before loading in ParquetLoader * add close method to ParquetLoader for memory management * fix closing of parquet chunks * refactor: replace shutil.copy2 with shutil.copyfile * update preload * upd documentation for default_cache_dir function * added test case for hf downloader * update test cases for parquet * update index hf dataset * updated parquet writer * added test case for index hf dataset * validate item_loader type for hf datasets and improve error handling * add support for ParquetLoader in StreamingDataset tests * simplified the parquet indexing process from different file services * update num workers * cleanup * updaet the order * add validation for low memory mode with ParquetLoader in StreamingDataset * update params * update item loader for low memory usage * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update naming conventions * fix type error * fix type errors * fix patch * add read count --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a2b2570 commit 6e4a409

File tree

13 files changed

+536
-294
lines changed

13 files changed

+536
-294
lines changed

src/litdata/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
_TQDM_AVAILABLE = RequirementCache("tqdm")
4040
_LIGHTNING_SDK_AVAILABLE = RequirementCache("lightning_sdk")
4141
_HF_HUB_AVAILABLE = RequirementCache("huggingface_hub")
42+
_PYARROW_AVAILABLE = RequirementCache("pyarrow")
4243
_POLARS_AVAILABLE = RequirementCache("polars>1.0.0")
4344
_DEBUG = bool(int(os.getenv("DEBUG_LITDATA", "0")))
4445

src/litdata/processing/readers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,10 @@
1616
from abc import ABC, abstractmethod
1717
from typing import Any, List
1818

19-
from lightning_utilities.core.imports import RequirementCache
20-
19+
from litdata.constants import _PYARROW_AVAILABLE
2120
from litdata.streaming.dataloader import StreamingDataLoader
2221
from litdata.utilities.format import _get_tqdm_iterator_if_available
2322

24-
_PYARROW_AVAILABLE = RequirementCache("pyarrow")
25-
2623

2724
class BaseReader(ABC):
2825
"""The `BaseReader` interface defines how to read and preprocess data

src/litdata/streaming/dataset.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,25 @@ def __init__(
100100

101101
if input_dir.url is not None and input_dir.url.startswith("hf://"):
102102
if index_path is None:
103-
# no index path provide, load from cache, or try indexing on the go.
103+
# No index_path was provided. Attempt to load it from cache or generate it dynamically on the fly.
104104
index_path = index_hf_dataset(input_dir.url)
105105
cache_dir.path = index_path
106106
input_dir.path = index_path
107-
item_loader = ParquetLoader()
107+
108+
if item_loader is not None and not isinstance(item_loader, ParquetLoader):
109+
raise ValueError(
110+
"Invalid item_loader for hf://datasets. "
111+
"The item_loader must be an instance of ParquetLoader. "
112+
"Please provide a valid ParquetLoader instance."
113+
)
114+
115+
if item_loader is not None and item_loader._low_memory and shuffle:
116+
raise ValueError(
117+
"You have enabled shuffling when using low memory with ParquetLoader. "
118+
"This configuration may lead to performance issues during the training process. "
119+
"Consider disabling shuffling or using a ParquetLoader without low memory mode."
120+
)
121+
item_loader = item_loader or ParquetLoader()
108122

109123
self.input_dir = input_dir
110124
self.cache_dir = cache_dir
@@ -548,9 +562,7 @@ def _validate_state_dict(self) -> None:
548562
"The provided `item_loader` state doesn't match the current one. "
549563
f"Found `{self.item_loader.state_dict()}` instead of `{state['item_loader']}`."
550564
)
551-
logger.warning(
552-
f"Overriding state item_loader {state['item_loader']} " f"to {self.item_loader.state_dict()}."
553-
)
565+
logger.warning(f"Overriding state item_loader {state['item_loader']} to {self.item_loader.state_dict()}.")
554566
state["item_loader"] = self.item_loader.state_dict()
555567

556568
if state["drop_last"] != self.drop_last:

src/litdata/streaming/downloader.py

Lines changed: 35 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import os
1616
import shutil
1717
import subprocess
18+
import tempfile
1819
from abc import ABC
1920
from contextlib import suppress
2021
from typing import Any, Dict, List, Optional, Type
@@ -58,9 +59,9 @@ def download_chunk_from_index(self, chunk_index: int) -> None:
5859
local_chunkpath = os.path.join(self._cache_dir, chunk_filename)
5960
remote_chunkpath = os.path.join(self._remote_dir, chunk_filename)
6061

61-
self.download_file(remote_chunkpath, local_chunkpath, chunk_filename)
62+
self.download_file(remote_chunkpath, local_chunkpath)
6263

63-
def download_file(self, remote_chunkpath: str, local_chunkpath: str, remote_chunk_filename: str = "") -> None:
64+
def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
6465
pass
6566

6667

@@ -74,7 +75,7 @@ def __init__(
7475
if not self._s5cmd_available or _DISABLE_S5CMD:
7576
self._client = S3Client(storage_options=self._storage_options)
7677

77-
def download_file(self, remote_filepath: str, local_filepath: str, remote_chunk_filename: str = "") -> None:
78+
def download_file(self, remote_filepath: str, local_filepath: str) -> None:
7879
obj = parse.urlparse(remote_filepath)
7980

8081
if obj.scheme != "s3":
@@ -158,7 +159,7 @@ def __init__(
158159

159160
super().__init__(remote_dir, cache_dir, chunks, storage_options)
160161

161-
def download_file(self, remote_filepath: str, local_filepath: str, remote_chunk_filename: str = "") -> None:
162+
def download_file(self, remote_filepath: str, local_filepath: str) -> None:
162163
from google.cloud import storage
163164

164165
obj = parse.urlparse(remote_filepath)
@@ -193,7 +194,7 @@ def __init__(
193194

194195
super().__init__(remote_dir, cache_dir, chunks, storage_options)
195196

196-
def download_file(self, remote_filepath: str, local_filepath: str, remote_chunk_filename: str = "") -> None:
197+
def download_file(self, remote_filepath: str, local_filepath: str) -> None:
197198
from azure.storage.blob import BlobServiceClient
198199

199200
obj = parse.urlparse(remote_filepath)
@@ -220,7 +221,7 @@ def download_file(self, remote_filepath: str, local_filepath: str, remote_chunk_
220221

221222

222223
class LocalDownloader(Downloader):
223-
def download_file(self, remote_filepath: str, local_filepath: str, remote_chunk_filename: str = "") -> None:
224+
def download_file(self, remote_filepath: str, local_filepath: str) -> None:
224225
if not os.path.exists(remote_filepath):
225226
raise FileNotFoundError(f"The provided remote_path doesn't exist: {remote_filepath}")
226227

@@ -248,32 +249,42 @@ def __init__(
248249
)
249250

250251
super().__init__(remote_dir, cache_dir, chunks, storage_options)
251-
from huggingface_hub import HfFileSystem
252252

253-
self.fs = HfFileSystem()
253+
def download_file(self, remote_filepath: str, local_filepath: str) -> None:
254+
"""Download a file from the Hugging Face Hub.
255+
The remote_filepath should be in the format `hf://<repo_type>/<repo_org>/<repo_name>/path`. For more
256+
information, see
257+
https://huggingface.co/docs/huggingface_hub/en/guides/hf_file_system#integrations.
258+
"""
259+
from huggingface_hub import hf_hub_download
254260

255-
def download_file(self, remote_filepath: str, local_filepath: str, remote_chunk_filename: str = "") -> None:
256-
# for HF dataset downloading, we don't need remote_filepath, but remote_chunk_filename
257-
with suppress(Timeout), FileLock(local_filepath + ".lock", timeout=0):
258-
temp_path = local_filepath + ".tmp" # Avoid partial writes
259-
try:
260-
with self.fs.open(remote_chunk_filename, "rb") as cloud_file, open(temp_path, "wb") as local_file:
261-
for chunk in iter(lambda: cloud_file.read(4096), b""): # Stream in 4KB chunks local_file.
262-
local_file.write(chunk)
261+
obj = parse.urlparse(remote_filepath)
262+
263+
if obj.scheme != "hf":
264+
raise ValueError(f"Expected obj.scheme to be `hf`, instead, got {obj.scheme} for remote={remote_filepath}")
263265

264-
os.rename(temp_path, local_filepath) # Atomic move after successful write
266+
if os.path.exists(local_filepath):
267+
return
265268

266-
except Exception as e:
267-
print(f"Error processing {remote_chunk_filename}: {e}")
269+
with suppress(Timeout), FileLock(local_filepath + ".lock", timeout=0), tempfile.TemporaryDirectory() as tmpdir:
270+
_, _, _, repo_org, repo_name, path = remote_filepath.split("/", 5)
271+
repo_id = f"{repo_org}/{repo_name}"
268272

269-
finally:
270-
# Ensure cleanup of temp file if an error occurs
271-
if os.path.exists(temp_path):
272-
os.remove(temp_path)
273+
downloaded_path = hf_hub_download(
274+
repo_id,
275+
path,
276+
cache_dir=tmpdir,
277+
repo_type="dataset",
278+
**self._storage_options,
279+
)
280+
if downloaded_path != local_filepath and os.path.exists(downloaded_path):
281+
temp_file_path = local_filepath + ".tmp"
282+
shutil.copyfile(downloaded_path, temp_file_path)
283+
os.rename(temp_file_path, local_filepath)
273284

274285

275286
class LocalDownloaderWithCache(LocalDownloader):
276-
def download_file(self, remote_filepath: str, local_filepath: str, remote_chunk_filename: str = "") -> None:
287+
def download_file(self, remote_filepath: str, local_filepath: str) -> None:
277288
remote_filepath = remote_filepath.replace("local:", "")
278289
super().download_file(remote_filepath, local_filepath)
279290

src/litdata/streaming/item_loader.py

Lines changed: 125 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1111
# See the License for the specific language governing permissions and
1212
# limitations under the License.
13-
1413
import functools
14+
import logging
1515
import os
1616
from abc import ABC, abstractmethod
1717
from collections import defaultdict, namedtuple
@@ -29,6 +29,7 @@
2929
_MAX_WAIT_TIME,
3030
_NUMPY_DTYPES_MAPPING,
3131
_POLARS_AVAILABLE,
32+
_PYARROW_AVAILABLE,
3233
_TORCH_DTYPES_MAPPING,
3334
)
3435
from litdata.streaming.serializers import Serializer
@@ -37,6 +38,8 @@
3738

3839
Interval = namedtuple("Interval", ["chunk_start", "roi_start_idx", "roi_end_idx", "chunk_end"])
3940

41+
logger = logging.getLogger(__name__)
42+
4043

4144
class BaseItemLoader(ABC):
4245
"""The base item loader is responsible to decide how the items within a chunk are loaded."""
@@ -527,13 +530,25 @@ def encode_data(cls, data: List[bytes], _: List[int], flattened: List[Any]) -> T
527530

528531

529532
class ParquetLoader(BaseItemLoader):
530-
def __init__(self) -> None:
533+
def __init__(self, pre_load_chunk: bool = False, low_memory: bool = True) -> None:
531534
if not _POLARS_AVAILABLE:
532535
raise ModuleNotFoundError(
533536
"You are using the Parquet item loader, which depends on `Polars > 1.0.0`.",
534537
"Please, run: `pip install polars>1.0.0`",
535538
)
539+
if not _PYARROW_AVAILABLE:
540+
raise ModuleNotFoundError("Please, run: `pip install pyarrow`")
541+
536542
self._chunk_filepaths: Dict[str, bool] = {}
543+
self._pre_load_chunk = pre_load_chunk
544+
self._low_memory = low_memory
545+
546+
if not self._low_memory:
547+
logger.warning(
548+
"You have set low_memory=False in ParquetLoader. "
549+
"This may result in high memory usage when processing large Parquet chunk files. "
550+
"Consider setting low_memory=True to reduce memory consumption."
551+
)
537552

538553
def setup(
539554
self,
@@ -548,7 +563,9 @@ def setup(
548563
self._data_format = self._config["data_format"]
549564
self._shift_idx = len(self._data_format) * 4
550565
self.region_of_interest = region_of_interest
551-
self._df: Dict[str, Any] = {}
566+
self._df: Dict[int, Any] = {}
567+
self._chunk_row_groups: Dict[int, Any] = {}
568+
self._chunk_row_group_item_read_count: Dict[int, Any] = {}
552569

553570
def generate_intervals(self) -> List[Interval]:
554571
intervals = []
@@ -566,11 +583,14 @@ def generate_intervals(self) -> List[Interval]:
566583
return intervals
567584

568585
def pre_load_chunk(self, chunk_index: int, chunk_filepath: str) -> None:
569-
"""Logic to load the chunk in background to gain some time."""
586+
"""Preload the chunk in the background to gain some time."""
587+
if not self._pre_load_chunk or self._low_memory:
588+
return
589+
570590
import polars as pl
571591

572-
if chunk_filepath not in self._df:
573-
self._df[chunk_filepath] = pl.scan_parquet(chunk_filepath).collect()
592+
if chunk_index not in self._df and os.path.exists(chunk_filepath):
593+
self._df[chunk_index] = pl.scan_parquet(chunk_filepath, low_memory=True).collect()
574594

575595
def load_item_from_chunk(
576596
self,
@@ -580,7 +600,7 @@ def load_item_from_chunk(
580600
begin: int,
581601
filesize_bytes: int,
582602
) -> Any:
583-
"""Returns an item loaded from a chunk."""
603+
"""Returns an item loaded from a parquet chunk."""
584604
if chunk_filepath in self._chunk_filepaths and not os.path.isfile(chunk_filepath):
585605
del self._chunk_filepaths[chunk_filepath]
586606

@@ -593,21 +613,112 @@ def load_item_from_chunk(
593613

594614
self._chunk_filepaths[chunk_filepath] = True
595615

596-
return self.get_df(chunk_filepath).row(index - begin)
616+
# relative index of the desired row within the chunk.
617+
relative_index = index - begin
618+
if self._low_memory:
619+
return self._get_item_with_low_memory(chunk_index, chunk_filepath, relative_index)
597620

598-
def get_df(self, chunk_filepath: str) -> Any:
621+
return self._get_item(chunk_index, chunk_filepath, relative_index)
622+
623+
def _get_item_with_low_memory(self, chunk_index: int, chunk_filepath: str, row_index: int) -> Any:
624+
"""Retrieve a dataframe row from a parquet chunk in low memory mode.
625+
626+
This method reads only the necessary row group from the parquet file using PyArrow and Polars,
627+
which helps in reducing memory usage.
628+
629+
Args:
630+
chunk_index (int): The index of the chunk to be accessed.
631+
chunk_filepath (str): The file path of the parquet chunk.
632+
row_index (int): The relative row index within the loaded chunk.
633+
634+
Returns:
635+
Any: The dataframe row corresponding to the specified index.
636+
"""
599637
import polars as pl
638+
import pyarrow.parquet as pq
639+
640+
# Load the Parquet file metadata if not already loaded
641+
if chunk_index not in self._df:
642+
self._df[chunk_index] = pq.ParquetFile(chunk_filepath)
643+
644+
# Determine the row group and the row index within the row group
645+
parquet_file = self._df[chunk_index]
646+
num_rows_per_row_group = parquet_file.metadata.row_group(0).num_rows
647+
row_group_index = row_index // num_rows_per_row_group
648+
row_index_within_group = row_index % num_rows_per_row_group
649+
650+
# Check if the row group is already loaded
651+
if chunk_index in self._chunk_row_groups and row_group_index in self._chunk_row_groups[chunk_index]:
652+
# Use the cached row group
653+
row_group_df = self._chunk_row_groups[chunk_index][row_group_index]
654+
# update read count
655+
self._chunk_row_group_item_read_count[chunk_index][row_group_index] += 1
656+
else:
657+
# Load the row group and convert it to a Polars DataFrame
658+
row_group = self._df[chunk_index].read_row_group(row_group_index)
659+
row_group_df = pl.from_arrow(row_group)
660+
661+
# Cache the loaded row group
662+
if chunk_index not in self._chunk_row_groups:
663+
self._chunk_row_groups[chunk_index] = {}
664+
self._chunk_row_group_item_read_count[chunk_index] = {}
600665

601-
if chunk_filepath not in self._df:
602-
self._df[chunk_filepath] = pl.scan_parquet(chunk_filepath).collect()
603-
return self._df[chunk_filepath]
666+
self._chunk_row_groups[chunk_index][row_group_index] = row_group_df
667+
self._chunk_row_group_item_read_count[chunk_index][row_group_index] = 1
668+
669+
# Check if the row group has been fully read and release memory if necessary
670+
read_count = self._chunk_row_group_item_read_count[chunk_index][row_group_index]
671+
if read_count >= num_rows_per_row_group:
672+
# Release memory for the fully read row group
673+
del self._chunk_row_groups[chunk_index][row_group_index]
674+
del self._chunk_row_group_item_read_count[chunk_index][row_group_index]
675+
676+
# Return the specific row from the dataframe
677+
return row_group_df.row(row_index_within_group) # type: ignore
678+
679+
def _get_item(self, chunk_index: int, chunk_filepath: str, index: int) -> Any:
680+
"""Retrieve a dataframe row from a parquet chunk by loading the entire chunk into memory.
681+
682+
Note:
683+
This method reads the complete parquet file using Polars. Exercise caution with large files as it
684+
may significantly increase memory usage.
685+
686+
Args:
687+
chunk_index (int): The index of the chunk to be accessed.
688+
chunk_filepath (str): The file path of the parquet chunk.
689+
index (int): The relative row index within the loaded chunk.
690+
691+
Returns:
692+
Any: The dataframe row corresponding to the specified index.
693+
"""
694+
import polars as pl
695+
696+
if chunk_index not in self._df:
697+
self._df[chunk_index] = pl.scan_parquet(chunk_filepath, low_memory=True).collect()
698+
return self._df[chunk_index].row(index)
604699

605700
def delete(self, chunk_index: int, chunk_filepath: str) -> None:
606701
"""Delete a chunk from the local filesystem."""
702+
if chunk_index in self._df:
703+
del self._df[chunk_index]
704+
if chunk_index in self._chunk_row_groups:
705+
del self._chunk_row_groups[chunk_index]
706+
707+
if chunk_index in self._chunk_row_group_item_read_count:
708+
del self._chunk_row_group_item_read_count[chunk_index]
607709
if os.path.exists(chunk_filepath):
608710
os.remove(chunk_filepath)
609-
if chunk_filepath in self._df:
610-
del self._df[chunk_filepath]
711+
712+
def close(self, chunk_index: int) -> None:
713+
"""Release the memory-mapped file for a specific chunk index."""
714+
if chunk_index in self._df:
715+
del self._df[chunk_index]
716+
717+
if chunk_index in self._chunk_row_groups:
718+
del self._chunk_row_groups[chunk_index]
719+
720+
if chunk_index in self._chunk_row_group_item_read_count:
721+
del self._chunk_row_group_item_read_count[chunk_index]
611722

612723
def encode_data(self, data: List[bytes], sizes: List[int], flattened: List[Any]) -> Any:
613724
pass

0 commit comments

Comments
 (0)