@@ -11,12 +11,13 @@
 import datasets
 import numpy as np
 import torch
-from datasets import Dataset, IterableDataset
+from datasets import Dataset, IterableDataset, interleave_datasets, load_dataset
 from datasets.iterable_dataset import ShufflingConfig
 from torch.distributed.checkpoint.stateful import Stateful
 from torchdata.stateful_dataloader import StatefulDataLoader
 from transformers import PreTrainedTokenizer
 
+from torchtitan.tools import utils
 from torchtitan.tools.logging import logger
 
 
@@ -541,6 +542,192 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: |
         super().load_state_dict(pickle.loads(state_dict[f'rank_{self.rank}']))
 
 
+def build_dataset(
+    dataset: str,
+    dataset_name: Optional[str] = None,
+    dataset_split: str = 'train',
+    data_dir: Optional[str] = None,
+    data_files: Optional[str] = None,
+    data_probs: Optional[str] = None,
+    streaming: bool = False,
+    dp_degree: Optional[int] = None,
+    num_workers: int = 32,
+    seed: int = 42,
+) -> IterableDataset:
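+    """Load a single dataset, or a mixture of datasets, as a shuffled `IterableDataset`.
+
+    `dataset`, `dataset_name`, `dataset_split`, `data_dir`, `data_files` and
+    `data_probs` are comma-separated strings with one entry per dataset; empty
+    entries fall back to their defaults. `data_probs` is required when more
+    than one dataset is given. When `dp_degree` is set, any dataset with fewer
+    than `dp_degree * num_workers` shards is reloaded in non-streaming mode and
+    resharded, so that every dataloader worker on every rank owns a shard.
+    """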
+    color = utils.Color
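+    # every data parallel rank runs `num_workers` dataloader workers, and each
+    # worker must own at least one data shard, hence the product below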
+    min_num_shards = dp_degree * num_workers if dp_degree else None
+    if len(dataset.split(',')) == 1:
+        # keep the path string around: `dataset` is rebound to the loaded
+        # dataset object below, but the path is needed again for reloading
+        path = dataset
+        dataset = load_dataset(
+            path=path,
+            name=dataset_name,
+            split=dataset_split,
+            data_dir=data_dir,
+            data_files=data_files,
+            trust_remote_code=True,
+            streaming=streaming,
+            num_proc=num_workers if not streaming else None,
+        )
+
+        logger.info(f"Shuffling the dataset with seed {seed}")
+        if not streaming:
+            # the state of a map-style dataset remains recoverable after shuffling
+            dataset = dataset.shuffle(seed=seed)
+            # a map-style `Dataset` has no `num_shards`; fall back to a single shard
+            dataset = dataset.to_iterable_dataset(num_shards=min_num_shards or 1)
+        else:
+            if min_num_shards is not None and dataset.num_shards < min_num_shards:
+                logger.warning(
+                    f"{color.red}"
+                    f"Dataset {path} has insufficient shards ({dataset.num_shards}). "
+                    f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
+                    f"{num_workers} dataloader workers. "
+                    f"Disabling streaming mode and resharding the dataset to {min_num_shards} shards."
+                    f"{color.reset}"
+                )
+                dataset = load_dataset(
+                    path=path,
+                    name=dataset_name,
+                    split=dataset_split,
+                    data_dir=data_dir,
+                    data_files=data_files,
+                    trust_remote_code=True,
+                    streaming=False,
+                    num_proc=num_workers,
+                )
+                dataset = dataset.shuffle(seed=seed)
+                dataset = dataset.to_iterable_dataset(num_shards=min_num_shards)
+            else:
+                dataset = shuffle(dataset, seed=seed)
+    else:
+        # a local named `datasets` would shadow the imported module, so use
+        # `dataset_paths` for the comma-separated list of dataset paths
+        dataset_paths = dataset.split(",")
+        if dataset_name is not None:
+            dataset_names = [name or None for name in dataset_name.split(",")]
+            assert len(dataset_names) == len(dataset_paths), (
+                "The number of dataset names must match the number of datasets"
+            )
+        else:
+            dataset_names = [None] * len(dataset_paths)
+        if dataset_split is not None:
+            dataset_splits = [split or "train" for split in dataset_split.split(",")]
+            assert len(dataset_splits) == len(dataset_paths), (
+                "The number of dataset splits must match the number of datasets"
+            )
+        else:
+            dataset_splits = ["train"] * len(dataset_paths)
+        if data_dir is not None:
+            data_dirs = [d or None for d in data_dir.split(",")]
+            assert len(data_dirs) == len(dataset_paths), (
+                "The number of data dirs must match the number of datasets"
+            )
+        else:
+            data_dirs = [None] * len(dataset_paths)
+        if data_files is not None:
+            data_files = data_files.split(",")
+            assert len(data_files) == len(dataset_paths), (
+                "The number of data files must match the number of datasets"
+            )
+        else:
+            data_files = [None] * len(dataset_paths)
+        if data_probs is not None:
+            data_probs = [float(p) for p in data_probs.split(",")]
+            assert len(data_probs) == len(dataset_paths), (
+                "The number of data probabilities must match the number of datasets"
+            )
+        else:
+            raise ValueError(
+                "Data sampling probabilities are required if using multiple datasets"
+            )
+
+        subsets = []
+        for i, prob in enumerate(data_probs):
+            subset = load_dataset(
+                path=dataset_paths[i],
+                name=dataset_names[i],
+                split=dataset_splits[i],
+                data_dir=data_dirs[i],
+                data_files=data_files[i],
+                trust_remote_code=True,
+                streaming=streaming,
+                num_proc=num_workers if not streaming else None,
+            )
+            logger.info(
+                f"Subset {color.cyan}{dataset_paths[i]}"
+                + (f":{dataset_names[i]} " if dataset_names[i] else " ")
+                + f"(p = {prob:.3f}){color.reset}:\n"
+                + f"{subset}"
+            )
+
+            logger.info(f"Shuffling the dataset with seed {seed}")
+            if not streaming:
+                # the state of a map-style dataset remains recoverable after shuffling
+                subset = subset.shuffle(seed=seed)
+                # a map-style `Dataset` has no `num_shards`; fall back to a single shard
+                subset = subset.to_iterable_dataset(num_shards=min_num_shards or 1)
+            else:
+                if min_num_shards is not None and subset.num_shards < min_num_shards:
+                    logger.warning(
+                        f"{color.red}"
+                        f"Dataset {dataset_paths[i]} has insufficient shards ({subset.num_shards}). "
+                        f"Need {min_num_shards} shards minimum for {dp_degree} data parallel workers × "
+                        f"{num_workers} dataloader workers. "
+                        f"Resharding the dataset to {min_num_shards} shards and disabling streaming mode."
+                        f"{color.reset}"
+                    )
+                    # again, it's fine to shuffle the map-style dataset directly;
+                    # an error is expected if the map-style dataset still does not
+                    # have enough shards
+                    subset = load_dataset(
+                        path=dataset_paths[i],
+                        name=dataset_names[i],
+                        split=dataset_splits[i],
+                        data_dir=data_dirs[i],
+                        data_files=data_files[i],
+                        trust_remote_code=True,
+                        streaming=False,
+                        num_proc=num_workers,
+                    )
+                    subset = subset.shuffle(seed=seed)
+                    # `min_num_shards` is never None on this branch
+                    subset = subset.to_iterable_dataset(num_shards=min_num_shards)
+                else:
+                    # a relatively small buffer suffices here, as interleaving
+                    # already provides some randomness
+                    subset = shuffle(
+                        subset,
+                        seed=seed,
+                        buffer_size=max(128, 1024 // len(dataset_paths)),
+                    )
+
+            if "text" in subset.column_names:
+                subset = subset.select_columns("text")
+            elif "content" in subset.column_names:
+                subset = subset.select_columns("content")
+            else:
+                raise ValueError(
+                    f"Subset {dataset_paths[i]} has no 'text' or 'content' column"
+                )
+            subsets.append(subset)
+
+        logger.info(
+            f"Interleaving {len(subsets)} datasets with probabilities {data_probs}"
+        )
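+        # "all_exhausted" keeps sampling until every subset has been fully
+        # consumed at least once, revisiting smaller subsets as needed, so no
+        # dataset is dropped under skewed probabilities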
+        dataset = interleave_datasets(
+            datasets=subsets,
+            probabilities=data_probs,
+            stopping_strategy="all_exhausted",
+            seed=seed,
+        )
+    logger.info(f"{dataset}")
+    return dataset
+
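+# A minimal usage sketch; the dataset paths, config names, and mixing weights
+# below are illustrative placeholders, not defaults of this repo:
+#
+#   ds = build_dataset(
+#       dataset="allenai/c4,HuggingFaceFW/fineweb-edu",
+#       dataset_name="en,",
+#       dataset_split="train,train",
+#       data_probs="0.3,0.7",
+#       streaming=True,
+#       dp_degree=8,
+#       num_workers=32,
+#       seed=42,
+#   )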
+
 def build_dataloader(
     dataset: IterableDataset,
     tokenizer: PreTrainedTokenizer,