
Commit 0546dcd

awaelchli authored and rasbt committed
Add general purpose LitData streaming data module (#1118)
1 parent e56765e · commit 0546dcd

19 files changed: +157 −51 lines changed

litgpt/data/__init__.py

Lines changed: 4 additions & 2 deletions
@@ -1,6 +1,6 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
-from litgpt.data.base import LitDataModule, SFTDataset, get_sft_collate_fn
+from litgpt.data.base import DataModule, SFTDataset, get_sft_collate_fn
 from litgpt.data.alpaca import Alpaca
 from litgpt.data.alpaca_2k import Alpaca2k
 from litgpt.data.alpaca_gpt4 import AlpacaGPT4
@@ -9,6 +9,7 @@
 from litgpt.data.dolly import Dolly
 from litgpt.data.flan import FLAN
 from litgpt.data.lima import LIMA
+from litgpt.data.lit_data import LitData
 from litgpt.data.longform import LongForm
 from litgpt.data.tinyllama import TinyLlama
 from litgpt.data.tinystories import TinyStories
@@ -24,7 +25,8 @@
     "FLAN",
     "JSON",
     "LIMA",
-    "LitDataModule",
+    "LitData",
+    "DataModule",
     "LongForm",
     "OpenWebText",
     "SFTDataset",

litgpt/data/alpaca.py

Lines changed: 2 additions & 2 deletions
@@ -9,15 +9,15 @@
 import torch
 from torch.utils.data import random_split, DataLoader
 from lightning_utilities.core.imports import RequirementCache
-from litgpt.data import SFTDataset, get_sft_collate_fn, LitDataModule
+from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
 from litgpt.prompts import PromptStyle
 from litgpt.tokenizer import Tokenizer
 
 _URL = "https://raw.githubusercontent.com/tloen/alpaca-lora/main/alpaca_data_cleaned_archive.json"
 
 
 @dataclass
-class Alpaca(LitDataModule):
+class Alpaca(DataModule):
     """Alpaca data module for supervised finetuning."""
 
     mask_prompt: bool = False

litgpt/data/base.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from litgpt.prompts import PromptStyle
 
 
-class LitDataModule(LightningDataModule):
+class DataModule(LightningDataModule):
     """Base class for all data modules in LitGPT."""
 
     @abstractmethod
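
For context, the renamed base class defines the contract the concrete modules in this commit follow: implement connect() plus the standard dataloader hooks. A minimal hypothetical sketch (the RandomTokensModule name and its random dataset are invented for illustration and are not part of LitGPT):

# Hypothetical sketch of a custom DataModule subclass; not part of this commit.
from dataclasses import dataclass, field
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset

from litgpt import Tokenizer
from litgpt.data import DataModule


class _RandomTokens(Dataset):
    """Illustrative dataset that yields random token blocks."""

    def __init__(self, seq_length: int, vocab_size: int = 50254, length: int = 128) -> None:
        self.seq_length, self.vocab_size, self.length = seq_length, vocab_size, length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.randint(0, self.vocab_size, (self.seq_length,))


@dataclass
class RandomTokensModule(DataModule):
    batch_size: int = field(init=False, repr=False, default=1)
    seq_length: int = field(init=False, repr=False, default=2048)

    def connect(
        self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
    ) -> None:
        # Same pattern as the concrete modules in this commit.
        self.batch_size = batch_size
        self.seq_length = (max_seq_length or 2048) + 1  # +1 for the next-token target

    def train_dataloader(self) -> DataLoader:
        return DataLoader(_RandomTokens(self.seq_length), batch_size=self.batch_size)

    def val_dataloader(self) -> DataLoader:
        return self.train_dataloader()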

litgpt/data/deita.py

Lines changed: 2 additions & 2 deletions
@@ -9,12 +9,12 @@
 from torch.utils.data import DataLoader
 
 from litgpt import PromptStyle
-from litgpt.data import LitDataModule, SFTDataset, get_sft_collate_fn
+from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
 from litgpt.tokenizer import Tokenizer
 
 
 @dataclass
-class Deita(LitDataModule):
+class Deita(DataModule):
     """Deita data module for supervised finetuning."""
 
     mask_prompt: bool = False

litgpt/data/flan.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from torch.utils.data import DataLoader
 
 from litgpt import PromptStyle
-from litgpt.data import SFTDataset, get_sft_collate_fn, LitDataModule
+from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
 from litgpt.data.alpaca import download_if_missing
 from litgpt.tokenizer import Tokenizer
 
@@ -19,7 +19,7 @@
 # TODO: Including all subsets, FLAN is too large to be loaded in memory. Switch the implementation to cache
 # on disk or use Lightning Data
 @dataclass
-class FLAN(LitDataModule):
+class FLAN(DataModule):
     """FLAN data module for supervised finetuning."""
 
     mask_prompt: bool = False

litgpt/data/json_data.py

Lines changed: 2 additions & 2 deletions
@@ -9,12 +9,12 @@
 from torch.utils.data import random_split, DataLoader
 
 from litgpt import PromptStyle
-from litgpt.data import SFTDataset, get_sft_collate_fn, LitDataModule
+from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
 from litgpt.tokenizer import Tokenizer
 
 
 @dataclass
-class JSON(LitDataModule):
+class JSON(DataModule):
     """Loads JSON or JSONL data for supervised finetuning."""
 
     json_path: Path

litgpt/data/lima.py

Lines changed: 2 additions & 2 deletions
@@ -9,12 +9,12 @@
 from torch.utils.data import random_split, DataLoader
 
 from litgpt import PromptStyle
-from litgpt.data import LitDataModule, SFTDataset, get_sft_collate_fn
+from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn
 from litgpt.tokenizer import Tokenizer
 
 
 @dataclass
-class LIMA(LitDataModule):
+class LIMA(DataModule):
     """LIMA data module for supervised finetuning."""
 
     mask_prompt: bool = False

litgpt/data/lit_data.py

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
+# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
+import os
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Union, Optional, Tuple
+
+from torch.utils.data import DataLoader
+
+from litgpt import Tokenizer
+from litgpt.data import DataModule
+
+
+@dataclass
+class LitData(DataModule):
+    """Loads data using LitData's StreamingDataset given a path to a folder of preprocessed data (chunks)."""
+
+    data_path: Union[str, Path] = Path("data/")
+    """The path to the data directory containing the preprocessed chunks for the streaming dataset.
+    The path can also be a remote path (e.g., s3://). See also ``split_names`` if this path contains subfolders
+    for training and validation splits."""
+    split_names: Optional[Tuple[str, str]] = None
+    """Optional tuple for names of subfolders for training and validation under ``data_path``. If not provided,
+    all data under ``data_path`` will be used for training, and the validation dataloader will be identical to
+    the train dataloader."""
+    seed: int = 42
+    """The random seed for shuffling the dataset."""
+    num_workers: int = 8
+    """How many DataLoader processes to use for loading."""
+
+    batch_size: int = field(init=False, repr=False, default=1)
+    seq_length: int = field(init=False, repr=False, default=2048)
+
+    def __post_init__(self) -> None:
+        if self.split_names is not None and len(self.split_names) != 2:
+            raise ValueError(
+                "If provided, `split_names` must be a tuple of two strings, for example: ('train', 'val')."
+            )
+
+    def connect(
+        self,
+        tokenizer: Optional[Tokenizer] = None,
+        batch_size: int = 1,
+        max_seq_length: Optional[int] = None
+    ) -> None:
+        self.batch_size = batch_size
+        self.seq_length = max_seq_length + 1  # Increase by one because we need the next token as well
+
+    def train_dataloader(self) -> DataLoader:
+        input_dir = os.path.join(self.data_path, self.split_names[0]) if self.split_names else str(self.data_path)
+        return self._dataloader(input_dir=input_dir, train=True)
+
+    def val_dataloader(self) -> DataLoader:
+        input_dir = os.path.join(self.data_path, self.split_names[1]) if self.split_names else str(self.data_path)
+        return self._dataloader(input_dir=input_dir, train=False)
+
+    def _dataloader(self, input_dir: str, train: bool) -> DataLoader:
+        from litdata.streaming import StreamingDataset, TokensLoader
+
+        dataset = StreamingDataset(
+            input_dir=input_dir,
+            item_loader=TokensLoader(block_size=self.seq_length),
+            shuffle=train,
+            drop_last=True,
+        )
+        dataloader = DataLoader(
+            dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
+        )
+        return dataloader
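
For orientation, here is a minimal usage sketch of the new module (not part of the diff). The data/custom path, split names, and sizes are placeholders, and the directory is assumed to already hold chunks written in litdata's TokensLoader format:

# Hypothetical usage sketch; the path and parameter values are placeholders.
from litgpt.data import LitData

data = LitData(
    data_path="data/custom",       # local folder or remote path (e.g., s3://...) with preprocessed chunks
    split_names=("train", "val"),  # expects data/custom/train and data/custom/val subfolders
    num_workers=4,
)
# connect() stores the batch size and adds 1 to max_seq_length for the
# next-token target, so a concrete int must be passed here (None would fail).
data.connect(batch_size=8, max_seq_length=2048)

train_loader = data.train_dataloader()  # yields token blocks of shape (8, 2049)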

litgpt/data/longform.py

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@
 from torch.utils.data import DataLoader
 
 from litgpt import PromptStyle
-from litgpt.data import SFTDataset, get_sft_collate_fn, LitDataModule
+from litgpt.data import SFTDataset, get_sft_collate_fn, DataModule
 from litgpt.data.alpaca import download_if_missing
 from litgpt.tokenizer import Tokenizer
 
@@ -18,7 +18,7 @@
 
 
 @dataclass
-class LongForm(LitDataModule):
+class LongForm(DataModule):
     """LongForm data module for supervised finetuning."""
 
     mask_prompt: bool = False

litgpt/data/openwebtext.py

Lines changed: 2 additions & 2 deletions
@@ -8,11 +8,11 @@
 from torch.utils.data import DataLoader
 
 from litgpt import Tokenizer
-from litgpt.data import LitDataModule
+from litgpt.data import DataModule
 
 
 @dataclass
-class OpenWebText(LitDataModule):
+class OpenWebText(DataModule):
     """The OpenWebText data module for pretraining."""
 
     data_path: Union[str, Path] = Path("data/openwebtext")
