Advanced Batching Logic with CombinedStreamingDataset #434

@schopra8

🚀 Feature

When we use CombinedStreamingDataset, each batch mixes samples drawn from all of the combined datasets. Other packages, such as MosaicML's StreamingDataset, allow you to specify more advanced batching strategies.

The current implementation corresponds to the random option from MosaicML. It would be great if we could support other batching methods, especially Per Stream, in which each batch is taken from a single dataset.

Motivation

If you're training over K datasets where each dataset yields samples of a different shape, we currently have to split each random batch into microbatches, since samples from different datasets cannot be stacked together. This leads to a 10-20% slowdown in training, because we're not fully utilizing the GPU on every forward and backward pass.

Pitch

Allow additional batching options, e.g. "per stream", that dictate how batches are yielded when using CombinedStreamingDataset.
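
To make the request concrete, here is a rough sketch of the "per stream" behavior in plain Python. It does not use or reflect any existing litdata API: `per_stream_batches`, its parameters, and the weighted random choice of stream are all assumptions made for illustration only.

```python
import random
from itertools import islice
from typing import Any, Iterable, Iterator, List, Optional, Sequence


def per_stream_batches(
    streams: Sequence[Iterable[Any]],
    batch_size: int,
    weights: Optional[Sequence[float]] = None,
    seed: int = 42,
) -> Iterator[List[Any]]:
    """Yield batches where every sample in a batch comes from one stream.

    Hypothetical helper, not part of litdata. A stream is picked at random
    (optionally weighted) for each batch, and a stream is dropped once it
    is exhausted.
    """
    rng = random.Random(seed)
    iterators = [iter(s) for s in streams]
    w = list(weights) if weights is not None else [1.0] * len(iterators)
    active = list(range(len(iterators)))  # indices of non-exhausted streams

    while active:
        # Pick which dataset this whole batch will be drawn from.
        idx = rng.choices(active, weights=[w[i] for i in active], k=1)[0]
        batch = list(islice(iterators[idx], batch_size))
        if len(batch) < batch_size:
            # Fewer samples than requested means the stream ran out.
            active.remove(idx)
        if batch:
            yield batch


# Two toy "datasets" whose samples have different shapes.
images = [[0.0] * 4 for _ in range(5)]   # length-4 samples
tokens = [[0] * 7 for _ in range(4)]     # length-7 samples

for batch in per_stream_batches([images, tokens], batch_size=2):
    # Every sample in a batch has the same length, so it could be stacked.
    assert len({len(sample) for sample in batch}) == 1
```

With something like this, each batch can be stacked directly without splitting into microbatches. A partial final batch per stream is still yielded here; a real implementation would probably want a `drop_last`-style switch and a plan for distributed sharding, which this sketch ignores.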

Alternatives

Additional context
