# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
import os
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional, Tuple, Union

from torch.utils.data import DataLoader

from litgpt import Tokenizer
from litgpt.data import DataModule


@dataclass
class LitData(DataModule):
    """Loads data using LitData's StreamingDataset given a path to a folder of preprocessed data (chunks)."""

    data_path: Union[str, Path] = Path("data/")
    """The path to the data directory containing the preprocessed chunks for the streaming dataset.
    The path can also be a remote path (e.g., s3://). See also ``split_names`` if this path contains subfolders
    for the training and validation splits."""
    split_names: Optional[Tuple[str, str]] = None
    """Optional tuple of subfolder names for the training and validation splits under ``data_path``. If not
    provided, all data under ``data_path`` is used for training, and the validation dataloader is identical
    to the train dataloader."""
    seed: int = 42
    """The random seed for shuffling the dataset."""
    num_workers: int = 8
    """How many DataLoader processes to use for loading."""

    batch_size: int = field(init=False, repr=False, default=1)
    seq_length: int = field(init=False, repr=False, default=2048)

    def __post_init__(self) -> None:
        super().__init__()  # initialize the LightningDataModule base class, which the dataclass __init__ skips
        if self.split_names is not None and len(self.split_names) != 2:
            raise ValueError(
                "If provided, `split_names` must be a tuple of two strings, for example: ('train', 'val')."
            )

    def connect(
        self,
        tokenizer: Optional[Tokenizer] = None,
        batch_size: int = 1,
        max_seq_length: Optional[int] = None,
    ) -> None:
        self.batch_size = batch_size
        if max_seq_length is not None:
            # Increase by one because we need the next token as well (targets are the inputs shifted by one).
            self.seq_length = max_seq_length + 1

    def train_dataloader(self) -> DataLoader:
        input_dir = os.path.join(self.data_path, self.split_names[0]) if self.split_names else str(self.data_path)
        return self._dataloader(input_dir=input_dir, train=True)

    def val_dataloader(self) -> DataLoader:
        input_dir = os.path.join(self.data_path, self.split_names[1]) if self.split_names else str(self.data_path)
        return self._dataloader(input_dir=input_dir, train=False)

    def _dataloader(self, input_dir: str, train: bool) -> DataLoader:
        # Imported lazily so litdata is only required when streaming is actually used.
        from litdata.streaming import StreamingDataset, TokensLoader

        dataset = StreamingDataset(
            input_dir=input_dir,
            item_loader=TokensLoader(block_size=self.seq_length),
            shuffle=train,
            seed=self.seed,
            drop_last=True,
        )
        dataloader = DataLoader(
            dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
        )
        return dataloader
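
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module above): assumes a folder
# of litdata-preprocessed token chunks at "data/" with "train" and "val"
# subfolders; the path, split names, and hyperparameters are placeholder
# assumptions.
if __name__ == "__main__":
    data = LitData(data_path="data/", split_names=("train", "val"))
    # connect() sets the streaming block size to max_seq_length + 1, so each
    # item carries both the inputs and their next-token targets.
    data.connect(batch_size=4, max_seq_length=2048)
    batch = next(iter(data.train_dataloader()))  # tensor of shape (4, 2049)
    inputs, targets = batch[:, :-1], batch[:, 1:]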