Description
🚀 Feature
When we use CombinedStreamingDataset, samples are drawn across all datasets within a single batch. Other packages, like MosaicML's StreamingDataset, allow you to specify more advanced batching strategies.
The current implementation corresponds to the "random" option in MosaicML. It would be great if we could support other batching methods, especially "per stream", where each batch is drawn from a single dataset.
Motivation
If you're training over K datasets, where each dataset yields samples of a different shape, we currently have to split each random batch into microbatches, because samples from different datasets cannot be stacked together. This leads to a 10-20% slowdown in training, since we're not fully utilizing the GPU on every forward and backward pass.
Pitch
Allow for additional batching options, e.g. "per stream", which dictate how batches are yielded when using CombinedStreamingDataset.
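To make the "per stream" option concrete, here is a minimal pure-Python sketch of the batching logic, not a proposed implementation. `per_stream_batches` and its `weights` and `seed` parameters are hypothetical and are not part of litdata's or MosaicML's API. The sketch assumes that each dataset is a plain iterable, that one stream is picked per batch (weighted at random), and that a stream is dropped from the rotation once it is exhausted, with its final partial batch still yielded.

```python
import random
from typing import Iterable, Iterator, List, Optional, Sequence, TypeVar

T = TypeVar("T")


def per_stream_batches(
    datasets: Sequence[Iterable[T]],
    batch_size: int,
    weights: Optional[Sequence[float]] = None,
    seed: int = 0,
) -> Iterator[List[T]]:
    """Yield batches in which every sample comes from the same dataset.

    Hypothetical helper for illustration only; not a litdata API.
    """
    rng = random.Random(seed)
    iterators = [iter(ds) for ds in datasets]
    w = list(weights) if weights is not None else [1.0] * len(datasets)
    active = list(range(len(datasets)))  # streams that may still have samples

    while active:
        # Choose a single stream for the whole batch, weighted by its sampling weight.
        idx = rng.choices(active, weights=[w[i] for i in active], k=1)[0]
        batch: List[T] = []
        for _ in range(batch_size):
            try:
                batch.append(next(iterators[idx]))
            except StopIteration:
                break
        if len(batch) < batch_size:
            # The chosen stream ran out: remove it so it is never picked again.
            active.remove(idx)
        if batch:
            # All samples share a source (and, in the motivating case, a shape),
            # so this batch could be stacked into one tensor without microbatching.
            yield batch
```

For example, with two toy streams, every yielded batch is homogeneous in its source:

```python
a = [("a", i) for i in range(5)]
b = [("b", i) for i in range(3)]
for batch in per_stream_batches([a, b], batch_size=2, seed=1):
    print(batch)
```

By contrast, the current "random" behavior would interleave `("a", ...)` and `("b", ...)` samples within the same batch.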