Commit 2c49713

Move dataset building to flame.data

1 parent d834161 commit 2c49713

File tree

3 files changed: +202 additions, -185 deletions


3rdparty/flash-linear-attention

flame/data.py

Lines changed: 188 additions & 1 deletion
@@ -11,12 +11,13 @@
 import datasets
 import numpy as np
 import torch
-from datasets import Dataset, IterableDataset
+from datasets import Dataset, IterableDataset, interleave_datasets, load_dataset
 from datasets.iterable_dataset import ShufflingConfig
 from torch.distributed.checkpoint.stateful import Stateful
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import PreTrainedTokenizer
 
+from torchtitan.tools import utils
 from torchtitan.tools.logging import logger
 
 

@@ -541,6 +542,192 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
         super().load_state_dict(pickle.loads(state_dict[f'rank_{self.rank}']))
 
 
+def build_dataset(
+    dataset: str,
+    dataset_name: str = None,
+    dataset_split: str = 'train',
+    data_dir: str = None,
+    data_files: str = None,
+    data_probs: str = None,
+    streaming: bool = False,
+    dp_degree: Optional[int] = None,
+    num_workers: int = 32,
+    seed: int = 42,
+) -> IterableDataset:
+    color = utils.Color
+    min_num_shards = dp_degree * num_workers if dp_degree else None
+    if len(dataset.split(',')) == 1:
+        # keep the raw path around in case we need to reload without streaming
+        path = dataset
+        dataset = load_dataset(
+            path=path,
+            name=dataset_name,
+            split=dataset_split,
+            data_dir=data_dir,
+            data_files=data_files,
+            trust_remote_code=True,
+            streaming=streaming,
+            num_proc=(
+                num_workers
+                if not streaming
+                else None
+            ),
+        )
+
+        logger.info(f"Shuffling the dataset with seed {seed}")
+        if not streaming:
+            # the state of a map-style dataset is recoverable after shuffling
+            dataset = dataset.shuffle(seed=seed)
+            dataset = dataset.to_iterable_dataset(num_shards=min_num_shards or dataset.num_shards)
+        else:
+            if min_num_shards is not None and dataset.num_shards < min_num_shards:
+                logger.warning(
+                    f"{color.red}"
+                    f"Dataset {path} has insufficient shards ({dataset.num_shards}). "
+                    f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
+                    f"{num_workers} dataloader workers. "
+                    f"Disabling streaming mode and resharding the dataset to {min_num_shards} shards."
+                    f"{color.reset}"
+                )
+                dataset = load_dataset(
+                    path=path,
+                    name=dataset_name,
+                    split=dataset_split,
+                    data_dir=data_dir,
+                    data_files=data_files,
+                    trust_remote_code=True,
+                    streaming=False,
+                    num_proc=num_workers,
+                )
+                dataset = dataset.shuffle(seed=seed)
+                dataset = dataset.to_iterable_dataset(num_shards=min_num_shards)
+            else:
+                dataset = shuffle(dataset, seed=seed)
+    else:
+        datasets = dataset.split(",")
+        if dataset_name is not None:
+            dataset_names = [
+                name or None for name in dataset_name.split(",")
+            ]
+            assert len(dataset_names) == len(datasets), (
+                "The number of dataset names must match the number of datasets"
+            )
+        else:
+            dataset_names = [None] * len(datasets)
+        if dataset_split is not None:
+            dataset_splits = [split or "train" for split in dataset_split.split(",")]
+            assert len(dataset_splits) == len(datasets), (
+                "The number of dataset splits must match the number of datasets"
+            )
+        else:
+            dataset_splits = ["train"] * len(datasets)
+        if data_dir is not None:
+            data_dirs = [
+                data_dir or None for data_dir in data_dir.split(",")
+            ]
+            assert len(data_dirs) == len(datasets), (
+                "The number of data dirs must match the number of datasets"
+            )
+        else:
+            data_dirs = [None] * len(datasets)
+        if data_files is not None:
+            data_files = data_files.split(",")
+            assert len(data_files) == len(datasets), (
+                "The number of data files must match the number of datasets"
+            )
+        else:
+            data_files = [None] * len(datasets)
+        if data_probs is not None:
+            data_probs = [float(p) for p in data_probs.split(",")]
+            assert len(data_probs) == len(datasets), (
+                "The number of data probabilities must match the number of datasets"
+            )
+        else:
+            raise ValueError(
+                "Data sampling probabilities are required if using multiple datasets"
+            )
+
+        subsets = []
+        for i, prob in enumerate(data_probs):
+            subset = load_dataset(
+                path=datasets[i],
+                name=dataset_names[i],
+                split=dataset_splits[i],
+                data_dir=data_dirs[i],
+                data_files=data_files[i],
+                trust_remote_code=True,
+                streaming=streaming,
+                num_proc=(
+                    num_workers
+                    if not streaming
+                    else None
+                ),
+            )
+            logger.info(
+                f"Subset {color.cyan}{datasets[i]}"
+                + (f":{dataset_names[i]} " if dataset_names[i] else " ")
+                + f"(p = {prob:.3f}){color.reset}:\n"
+                + f"{subset}"
+            )
+
+            logger.info(f"Shuffling the dataset with seed {seed}")
+            if not streaming:
+                # the state of a map-style dataset is recoverable after shuffling
+                subset = subset.shuffle(seed=seed)
+                subset = subset.to_iterable_dataset(num_shards=min_num_shards or subset.num_shards)
+            else:
+                if min_num_shards is not None and subset.num_shards < min_num_shards:
+                    logger.warning(
+                        f"{color.red}"
+                        f"Dataset {datasets[i]} has insufficient shards ({subset.num_shards}). "
+                        f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
+                        f"{num_workers} dataloader workers. "
+                        f"Resharding the dataset to {min_num_shards} shards and disabling streaming mode."
+                        f"{color.reset}"
+                    )
+                    # again, it's ok to directly shuffle the map-style dataset;
+                    # we expect an error to be raised if the map-style dataset still lacks enough shards
+                    subset = load_dataset(
+                        path=datasets[i],
+                        name=dataset_names[i],
+                        split=dataset_splits[i],
+                        data_dir=data_dirs[i],
+                        data_files=data_files[i],
+                        trust_remote_code=True,
+                        streaming=False,
+                        num_proc=num_workers,
+                    )
+                    subset = subset.shuffle(seed=seed)
+                    subset = subset.to_iterable_dataset(num_shards=min_num_shards or subset.num_shards)
+                else:
+                    # a relatively small buffer suffices here, as interleaving already provides some randomness
+                    subset = shuffle(
+                        subset,
+                        seed=seed,
+                        buffer_size=max(128, 1024 // len(datasets)),
+                    )
+
+            if "text" in subset.column_names:
+                subset = subset.select_columns("text")
+            elif "content" in subset.column_names:
+                subset = subset.select_columns("content")
+            else:
+                raise ValueError(
+                    f"Subset {datasets[i]} has no 'text' or 'content' column"
+                )
+            subsets.append(subset)
+
+        logger.info(
+            f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
+        )
+        dataset = interleave_datasets(
+            datasets=subsets,
+            probabilities=data_probs,
+            stopping_strategy="all_exhausted",
+            seed=seed,
+        )
+        logger.info(f"{dataset}")
+    return dataset
+
+
 def build_dataloader(
     dataset: IterableDataset,
     tokenizer: PreTrainedTokenizer,
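
For reference, a minimal usage sketch of the new helper follows. It assumes `build_dataset` is importable as `flame.data.build_dataset`; the dataset paths, sampling probabilities, and worker counts below are illustrative placeholders, not values taken from this commit.

# Hypothetical call: interleave two corpora with 60/40 sampling in streaming
# mode, sized so each subset needs dp_degree * num_workers = 4 * 8 = 32 shards.
from flame.data import build_dataset

dataset = build_dataset(
    dataset="org/web-corpus,org/code-corpus",  # comma-separated HF dataset paths (placeholders)
    dataset_split="train,train",               # one split per dataset
    data_probs="0.6,0.4",                      # required whenever multiple datasets are given
    streaming=True,
    dp_degree=4,
    num_workers=8,
    seed=42,
)

Each subset is shuffled, reduced to its 'text' (or 'content') column, and interleaved by the given probabilities; the returned IterableDataset can then be handed to build_dataloader.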
