Feature request
Add sharding support to the streaming IterableDataset, so that users can adjust the number of shards to match their training resources. For example:
dataset = load_dataset(*, streaming=True)
dataset = dataset.shard(num_shards=num_shards, num_samples=num_samples)  # we may need to know the total number of samples (num_samples) in advance
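For context, a hypothetical usage in distributed training could look like the following sketch. Note that the proposed shard() signature does not exist yet, and num_shards, num_samples, rank, world_size and dataloader_num_workers are illustrative names:

# hypothetical: re-shard the stream to match the current number of GPU workers
num_shards = world_size * dataloader_num_workers
dataset = dataset.shard(num_shards=num_shards, num_samples=num_samples)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)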
Motivation
When performing large-scale pre-training in a distributed environment, large datasets may only be loaded in a streaming manner. To improve training efficiency, my current approach is as follows:
file_type="parquet"
dataset_path="./*.parquet"
dataset = load_dataset(file_type,data_files=dataset_path, stream=True)
dataset = split_dataset_by_node(dataset, rank=rank, world_size=world_size)
I split one large file into N = world_size * dataloader_num_workers files and placed them under dataset_path, so that each GPU worker processes a different shard (see the pre-splitting sketch below). However, this approach has a drawback: if the number of GPUs changes for the next training run, I have to re-split the large file so that IterableDataset.num_shards again equals world_size * dataloader_num_workers.
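Roughly, the pre-splitting step looks like this sketch (assuming pyarrow and a single large Parquet file; large_file.parquet and the shard file names are illustrative):

import math
import pyarrow.parquet as pq

N = world_size * dataloader_num_workers  # must be redone whenever the GPU/worker count changes
table = pq.read_table("large_file.parquet")
rows_per_shard = math.ceil(table.num_rows / N)
for i in range(N):
    # each output file becomes one shard of the streaming dataset
    pq.write_table(table.slice(i * rows_per_shard, rows_per_shard), f"./shard_{i:05d}.parquet")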
I'd like to know whether there is a better approach: for example, loading the large dataset directly in streaming mode and then sharding the IterableDataset on the fly based on the number of GPUs and num_workers, similar to Example 1 in https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset (a rough sketch of that pattern is shown below). @lhoestq
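For reference, here is a minimal sketch of that on-the-fly sharding pattern, written as a plain torch IterableDataset wrapper rather than anything from the datasets library (ShardedIterable, rank and world_size are illustrative names):

from torch.utils.data import IterableDataset, get_worker_info

class ShardedIterable(IterableDataset):
    """Round-robin shard a streaming iterable across ranks and dataloader workers."""

    def __init__(self, hf_iterable, rank, world_size):
        self.ds = hf_iterable        # e.g. a streaming datasets.IterableDataset
        self.rank = rank
        self.world_size = world_size

    def __iter__(self):
        info = get_worker_info()
        num_workers = info.num_workers if info is not None else 1
        worker_id = info.id if info is not None else 0
        total = self.world_size * num_workers           # effective number of shards
        shard_id = self.rank * num_workers + worker_id  # this worker's shard index
        for i, sample in enumerate(self.ds):
            if i % total == shard_id:                   # keep every `total`-th sample
                yield sample

The drawback of this pattern is that every worker still iterates over the full stream and discards most samples, which is exactly what a native, num_shards-aware shard() combined with split_dataset_by_node would avoid.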