Conversation

Contributor

@justinvyu justinvyu commented Jan 6, 2026

Summary

Introduces token-based chunking for microbatches, enforcing a max_tokens_per_microbatch limit on the total number of real tokens in a microbatch.

Replaces the even, sample-based chunking BatchIterator with a TokenBasedBatchIterator that respects the microbatch token limit.


After this PR, global training batch sharding and minibatch formation still use even, sample-based chunking. The new functionality adds token-based chunking at the bottom microbatching layer, enforcing a maximum total token count per microbatch.

Problem

Previously, microbatches were created by chunking sequences evenly so that each microbatch had a fixed batch size (micro_train_batch_size_per_gpu), which could lead to:

  • Sequence-packed microbatches exceeding GPU memory limits when sequence lengths vary, forcing a conservatively small micro batch size.
  • Inefficient GPU utilization due to uneven token distribution across microbatches; token imbalance across DP workers leads to stragglers.

Solution

Introduce balanced_binpacking and TokenBasedBatchIterator utilities that:

  • Enforce max_tokens_per_microbatch: no microbatch exceeds the token limit.
  • Roughly balance token counts across microbatches to avoid straggler microbatches (see the sketch after this list).
  • Ensure that every worker still receives the same number of microbatches by adding padding batches, satisfying the FSDP requirement that every DP worker performs the same number of forward/backward passes.
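
As a rough illustration (not the exact implementation in this PR), the capacity-constrained balanced packing can be done with a sorted worst-fit heuristic: place each sequence, largest first, into the lightest microbatch that still has room, otherwise open a new one. The function name below is made up; the sketch reproduces the balanced_binpacking docstring example quoted later in the review:

from typing import List

def balanced_binpacking_sketch(token_counts: List[int], max_tokens_per_microbatch: int) -> List[List[int]]:
    # Sort sequences by token count (largest first), then place each one into the
    # currently lightest microbatch that can still fit it; otherwise open a new
    # microbatch. No microbatch ever exceeds the token limit.
    bins: List[List[int]] = []   # sequence indices per microbatch
    loads: List[int] = []        # total token count per microbatch
    for idx in sorted(range(len(token_counts)), key=lambda i: -token_counts[i]):
        feasible = [b for b, load in enumerate(loads)
                    if load + token_counts[idx] <= max_tokens_per_microbatch]
        if feasible:
            b = min(feasible, key=lambda b: loads[b])
            bins[b].append(idx)
            loads[b] += token_counts[idx]
        else:
            bins.append([idx])
            loads.append(token_counts[idx])
    return bins

# balanced_binpacking_sketch([8, 3, 5, 6, 2, 7], 11) -> [[0, 4], [5, 1], [3, 2]]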

API changes

  • External: adds a config.trainer.max_tokens_per_microbatch configuration. Defaults to the old behavior (sample-based microbatching); see the example config after this list.
  • Adds TokenBasedBatchIterator and renames BatchIterator -> SampleBasedBatchIterator.
  • Unifies Worker.forward to also use the BatchIterator interface rather than calling TrainingBatch.chunk directly.
  • Uses TokenBasedBatchIterator when max_tokens_per_microbatch is set, for the forward, Critic.ppo_train, and Policy.ppo_train methods.
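
For illustration, the trainer config might look like the following (a hypothetical snippet based on the config excerpt discussed later in this review; the exact nesting and surrounding keys may differ):

trainer:
  micro_train_batch_size_per_gpu: 1
  micro_forward_batch_size_per_gpu: 1
  # -1 (the default) keeps the old sample-based microbatching;
  # a positive value enables token-based chunking with this cap per microbatch.
  max_tokens_per_microbatch: 4096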

Example

Example ppo_train call from the unit test:

minibatch_sequences=tensor([[88, 96,  3, 23, 44, 65, 55,  0, 66, 49],
        [25, 85, 10, 86, 31, 55, 31, 51,  4, 56],
        [28, 38, 76,  8,  0,  0,  0,  0,  0,  0],
        [21, 97, 46, 75, 73,  0,  0,  0,  0,  0]])
minibatch_attention_mask=tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]])
Policy Train epoch [1/1]:   0%|          | 0/2 [00:00<?, ?it/s]
(DeepSpeedPolicyWorkerBase pid=182679) micro_batches_per_mini_batch=2
(DeepSpeedPolicyWorkerBase pid=182679) experience.sequences=tensor([[88, 96,  3, 23, 44, 65, 55,  0, 66, 49],
(DeepSpeedPolicyWorkerBase pid=182679)         [28, 38, 76,  8,  0,  0,  0,  0,  0,  0]])
Policy Train epoch [1/1]:  50%|█████     | 1/2 [00:01<00:01,  1.28s/it, pg=-0.203, glen=4, policy_lr=1e-6, ent=3.08]
(DeepSpeedPolicyWorkerBase pid=182679) experience.sequences=tensor([[25, 85, 10, 86, 31, 55, 31, 51,  4, 56],
(DeepSpeedPolicyWorkerBase pid=182679)         [21, 97, 46, 75, 73,  0,  0,  0,  0,  0]])

Follow-up

This PR only does balanced microbatching at the individual worker level. A follow-up could also balance the minibatches sent to each worker at the global level, so that every DP worker has roughly the same amount of work.

There are also other algorithms we could plug into the microbatch chunking step.

We should also update the Megatron codepath as a follow-up.

Signed-off-by: Justin Yu <[email protected]>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces token-based microbatch chunking, which is a great improvement for GPU efficiency and memory management. The implementation is solid, with a new TokenBasedBatchIterator and a balanced_binpacking algorithm. I've identified a few critical and high-severity issues related to edge cases that could lead to crashes or out-of-memory errors. Specifically, the padding microbatch creation has a metadata bug, the bin packing algorithm doesn't handle sequences larger than the token limit, and the batch reordering logic can crash when a worker has no data. My review comments include suggestions to fix these issues to make the implementation more robust.

Comment on lines 203 to 219
def _create_padding_microbatch(self) -> TrainingInputBatch:
    """Create a padding microbatch."""
    data = TrainingInputBatch(
        {
            "sequences": torch.ones((1, 1), dtype=int, device="cpu"),
            "attention_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
            "action_log_probs": torch.zeros((1, 1), device="cpu"),
            "base_action_log_probs": torch.zeros((1, 1), device="cpu"),
            "values": torch.zeros((1, 1), device="cpu"),
            "returns": torch.zeros((1, 1), device="cpu"),
            "advantages": torch.zeros((1, 1), device="cpu"),
            "loss_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
            "response_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
        }
    )
    data.metadata = self._data.metadata
    return data
Contributor

critical

The _create_padding_microbatch method incorrectly copies metadata from the original data batch (self._data.metadata). This is problematic because the padding batch will inherit properties like response_length from a real data batch, which will cause indexing errors and crashes during the forward_backward pass when it tries to operate on the small padding tensors (e.g., shape (1,1)).

A padding microbatch should have its own minimal, safe metadata.

Additionally, the keys for the padding batch are hardcoded. This could be brittle if the TrainingInputBatch schema changes. A more robust solution would be to dynamically create dummy tensors for all keys present in self._data.

    def _create_padding_microbatch(self) -> TrainingInputBatch:
        """Create a padding microbatch."""
        padding_data = {}
        # Create dummy tensors for all keys present in the original data
        for key, value in self._data.items():
            if value is not None:
                # Create a (1, 1) tensor with the same dtype.
                padding_data[key] = torch.zeros((1, 1), dtype=value.dtype, device="cpu")
            else:
                padding_data[key] = None

        # Ensure masks are correctly typed and zeroed out.
        if "attention_mask" in padding_data:
            padding_data["attention_mask"] = torch.zeros((1, 1), dtype=torch.long, device="cpu")
        if "loss_mask" in padding_data:
            padding_data["loss_mask"] = torch.zeros((1, 1), dtype=torch.long, device="cpu")
        if "response_mask" in padding_data:
            padding_data["response_mask"] = torch.zeros((1, 1), dtype=torch.long, device="cpu")
        if "sequences" in padding_data:
            padding_data["sequences"] = torch.ones((1, 1), dtype=torch.long, device="cpu")

        data = TrainingInputBatch(padding_data)
        data.metadata = self._data.metadata.copy()
        data.metadata["response_length"] = 0
        return data

>>> balanced_binpacking([8, 3, 5, 6, 2, 7], 11)
[[0, 4], [5, 1], [3, 2]]
"""
# TODO: Handle max(token_counts) > max_tokens_per_microbatch
Contributor

high

The TODO on this line points out a critical unhandled edge case: if a single sequence's token count exceeds max_tokens_per_microbatch, the current implementation will still place it in a microbatch by itself, violating the token limit. This could lead to out-of-memory errors, which this PR aims to prevent. You should add a check for this condition and raise a ValueError to alert the user to adjust their configuration.

Suggested change
# TODO: Handle max(token_counts) > max_tokens_per_microbatch
if token_counts and max(token_counts) > max_tokens_per_microbatch:
    raise ValueError(
        f"A sequence has {max(token_counts)} tokens, which is greater than max_tokens_per_microbatch ({max_tokens_per_microbatch}). "
        "Please increase max_tokens_per_microbatch or filter out long sequences."
    )

Comment on lines +259 to +261
if not non_padding_microbatches:
    # TODO: Can this happen?
    raise ValueError("Cannot reorder an empty list of microbatches.")
Contributor

high

The code raises a ValueError if non_padding_microbatches is empty. This can legitimately happen if a data parallel worker receives a minibatch that results in zero non-padding microbatches (e.g., it contains only very short sequences that get packed by other workers), while other workers do have microbatches. In this scenario, the worker with no data will be padded with dummy microbatches to match the count of other workers, leading to an empty non_padding_microbatches list. Instead of crashing, the function should gracefully handle this by returning an empty TensorBatch.

Suggested change
if not non_padding_microbatches:
    # TODO: Can this happen?
    raise ValueError("Cannot reorder an empty list of microbatches.")
if not non_padding_microbatches:
    # This can happen if a worker has no data and only padding batches.
    # We must return an empty batch with the correct structure.
    if not batches:
        # This should not be reached if there are padding batches, but as a safeguard:
        return TensorBatch({})
    ref_batch = batches[0]
    empty_data = {
        key: torch.empty((0, *value.shape[1:]), dtype=value.dtype, device=value.device)
        for key, value in ref_batch.items()
    }
    empty_batch = type(ref_batch)(empty_data)
    empty_batch.metadata = ref_batch.metadata
    return empty_batch

Comment on lines +99 to +102
def balanced_binpacking(token_counts: List[int], max_tokens_per_microbatch: int) -> List[List[int]]:
    """Chunk a list of token counts into microbatches so that each
    microbatch's total token count does not exceed `max_tokens_per_microbatch`,
    and the microbatches are roughly balanced.
Contributor Author

@justinvyu justinvyu Jan 7, 2026

Verl uses a Karmarkar-Karp bin balancing algorithm instead, which minimizes the differences across K bins and estimates K as total_token_count / max_tokens_per_microbatch. However, it doesn't guarantee that every bin's token count actually stays below the configured maximum. The upside is that the bin count is defined up front, so no dummy padding batches need to be injected.

Balanced binpacking with a capacity constraint, as done here, can guarantee the max token constraint, but it requires adding dummy padding batches since each worker can end up with a different number of bins.
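
For comparison, the fixed-K idea can be sketched as a greedy assignment into K = ceil(total_tokens / max_tokens_per_microbatch) bins (a simplification for illustration, not verl's actual Karmarkar-Karp implementation): every worker ends up with the same bin count and no padding batches, but a bin may exceed the cap.

import heapq
import math
from typing import List

def fixed_k_balance_sketch(token_counts: List[int], max_tokens_per_microbatch: int) -> List[List[int]]:
    # Estimate the bin count up front, then greedily assign the largest remaining
    # sequence to the currently lightest bin. Unlike capacity-constrained
    # binpacking, a single bin may end up exceeding max_tokens_per_microbatch.
    k = max(1, math.ceil(sum(token_counts) / max_tokens_per_microbatch))
    heap = [(0, bin_id, []) for bin_id in range(k)]  # (load, bin_id, indices)
    heapq.heapify(heap)
    for idx in sorted(range(len(token_counts)), key=lambda i: -token_counts[i]):
        load, bin_id, members = heapq.heappop(heap)
        members.append(idx)
        heapq.heappush(heap, (load + token_counts[idx], bin_id, members))
    return [members for _, _, members in sorted(heap, key=lambda x: x[1])]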

Comment on lines +279 to +285
# Expect:
# - dp=0: 3 microbatches with [10], [10], [11]
# - dp=1: 3 microbatches with [10, 5], [10, 5], [<dummy padding batch>]
train_data.metadata["global_step"] = 0

results = ray.get(actor_group.async_run_ray_method("mesh", "forward", train_data))
results = ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", train_data))
Contributor Author

I tested this with some print statements and it works as expected, but any ideas on how to assert these expectations in the test?

Collaborator

hmm yeah good question - is there a way you could mock the actual model forward and make some assertion on model length that way?

Contributor Author

lemme try

critic_mini_batch_size: 256
micro_train_batch_size_per_gpu: 1
micro_forward_batch_size_per_gpu: 1
max_tokens_per_microbatch: -1 # TODO: Maybe split this between forward and train; -1 means no token-based chunking
Collaborator

do we need some config validation to make sure that this is larger than the max response length? Maybe in utils/utils.py?
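
A hypothetical sketch of what such a check could look like (the function name and arguments below are made up for illustration; the actual config keys and location would need to match the codebase):

def validate_max_tokens_per_microbatch(max_tokens_per_microbatch: int, max_response_length: int) -> None:
    # -1 disables token-based chunking, so there is nothing to validate.
    if max_tokens_per_microbatch == -1:
        return
    if max_tokens_per_microbatch < max_response_length:
        raise ValueError(
            f"max_tokens_per_microbatch ({max_tokens_per_microbatch}) is smaller than the max "
            f"response length ({max_response_length}), so a single response may not fit in any microbatch."
        )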

for microbatch_idx, microbatch in enumerate(microbatch_iterator):
    microbatch_experience = BatchIterator.batch_to_experience(microbatch)
    microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)
    microbatch_weight = len(microbatch) / len(minibatch)
Contributor Author

This is a bit confusing for cases where there are padding batches, because the padding batch cannot be empty and must have seqlen >= 2.

Example: with 2 real microbatches of length 15 and then 1 padding batch, the weights would be: 15/30, 15/30, 2/30.

  • This relies on the dummy microbatch having an all-zero loss mask, which zeros out this weight's contribution.
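
A minimal sketch of why the all-zero loss mask makes the padding weight harmless (hypothetical tensors; the trainer's actual loss aggregation may differ):

import torch

per_token_loss = torch.randn(1, 2)   # padding microbatch, seqlen 2
loss_mask = torch.zeros(1, 2)        # fully masked out for the dummy batch
masked_loss = (per_token_loss * loss_mask).sum() / loss_mask.sum().clamp(min=1)
microbatch_weight = 2 / 30           # the 2/30 weight from the example above
assert (microbatch_weight * masked_loss).item() == 0.0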

Collaborator

why does seqlen have to be >= 2?

