[skyrl-train] Add max_tokens_per_microbatch configuration
#847
Conversation
Signed-off-by: Justin Yu <[email protected]>
Code Review
This pull request introduces token-based microbatch chunking, which is a great improvement for GPU efficiency and memory management. The implementation is solid, with a new TokenBasedBatchIterator and a balanced_binpacking algorithm. I've identified a few critical and high-severity issues related to edge cases that could lead to crashes or out-of-memory errors. Specifically, the padding microbatch creation has a metadata bug, the bin packing algorithm doesn't handle sequences larger than the token limit, and the batch reordering logic can crash when a worker has no data. My review comments include suggestions to fix these issues to make the implementation more robust.
def _create_padding_microbatch(self) -> TrainingInputBatch:
    """Create a padding microbatch."""
    data = TrainingInputBatch(
        {
            "sequences": torch.ones((1, 1), dtype=int, device="cpu"),
            "attention_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
            "action_log_probs": torch.zeros((1, 1), device="cpu"),
            "base_action_log_probs": torch.zeros((1, 1), device="cpu"),
            "values": torch.zeros((1, 1), device="cpu"),
            "returns": torch.zeros((1, 1), device="cpu"),
            "advantages": torch.zeros((1, 1), device="cpu"),
            "loss_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
            "response_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
        }
    )
    data.metadata = self._data.metadata
    return data
The _create_padding_microbatch method incorrectly copies metadata from the original data batch (self._data.metadata). This is problematic because the padding batch will inherit properties like response_length from a real data batch, which will cause indexing errors and crashes during the forward_backward pass when it tries to operate on the small padding tensors (e.g., shape (1,1)).
A padding microbatch should have its own minimal, safe metadata.
Additionally, the keys for the padding batch are hardcoded. This could be brittle if the TrainingInputBatch schema changes. A more robust solution would be to dynamically create dummy tensors for all keys present in self._data.
def _create_padding_microbatch(self) -> TrainingInputBatch:
    """Create a padding microbatch."""
    padding_data = {}
    # Create dummy tensors for all keys present in the original data
    for key, value in self._data.items():
        if value is not None:
            # Create a (1, 1) tensor with the same dtype.
            padding_data[key] = torch.zeros((1, 1), dtype=value.dtype, device="cpu")
        else:
            padding_data[key] = None
    # Ensure masks are correctly typed and zeroed out.
    if "attention_mask" in padding_data:
        padding_data["attention_mask"] = torch.zeros((1, 1), dtype=torch.long, device="cpu")
    if "loss_mask" in padding_data:
        padding_data["loss_mask"] = torch.zeros((1, 1), dtype=torch.long, device="cpu")
    if "response_mask" in padding_data:
        padding_data["response_mask"] = torch.zeros((1, 1), dtype=torch.long, device="cpu")
    if "sequences" in padding_data:
        padding_data["sequences"] = torch.ones((1, 1), dtype=torch.long, device="cpu")
    data = TrainingInputBatch(padding_data)
    data.metadata = self._data.metadata.copy()
    data.metadata["response_length"] = 0
    return data

    >>> balanced_binpacking([8, 3, 5, 6, 2, 7], 11)
    [[0, 4], [5, 1], [3, 2]]
    """
    # TODO: Handle max(token_counts) > max_tokens_per_microbatch
The TODO on this line points out a critical unhandled edge case: if a single sequence's token count exceeds max_tokens_per_microbatch, the current implementation will still place it in a microbatch by itself, violating the token limit. This could lead to out-of-memory errors, which this PR aims to prevent. You should add a check for this condition and raise a ValueError to alert the user to adjust their configuration.
if token_counts and max(token_counts) > max_tokens_per_microbatch:
    raise ValueError(
        f"A sequence has {max(token_counts)} tokens, which is greater than max_tokens_per_microbatch ({max_tokens_per_microbatch}). "
        "Please increase max_tokens_per_microbatch or filter out long sequences."
    )
if not non_padding_microbatches:
    # TODO: Can this happen?
    raise ValueError("Cannot reorder an empty list of microbatches.")
The code raises a ValueError if non_padding_microbatches is empty. This can legitimately happen if a data parallel worker receives a minibatch that results in zero non-padding microbatches (e.g., it contains only very short sequences that get packed by other workers), while other workers do have microbatches. In this scenario, the worker with no data will be padded with dummy microbatches to match the count of other workers, leading to an empty non_padding_microbatches list. Instead of crashing, the function should gracefully handle this by returning an empty TensorBatch.
if not non_padding_microbatches:
    # This can happen if a worker has no data and only padding batches.
    # We must return an empty batch with the correct structure.
    if not batches:
        # This should not be reached if there are padding batches, but as a safeguard:
        return TensorBatch({})
    ref_batch = batches[0]
    empty_data = {
        key: torch.empty((0, *value.shape[1:]), dtype=value.dtype, device=value.device)
        for key, value in ref_batch.items()
    }
    empty_batch = type(ref_batch)(empty_data)
    empty_batch.metadata = ref_batch.metadata
    return empty_batch
def balanced_binpacking(token_counts: List[int], max_tokens_per_microbatch: int) -> List[List[int]]:
    """Chunk a list of token counts into microbatches so that each
    microbatch's total token count does not exceed `max_tokens_per_microbatch`,
    and the microbatches are roughly balanced.
Verl uses a Karmarkar-Karp bin balancing algorithm instead, which minimizes differences across K bins, and estimates K with total_token_count / max_tokens_per_microbatch. But it doesn't guarantee that every bin's token count is actually below the max token configuration. The upside is that they define the bin count upfront and don't need to inject dummy padding batches.
I'm doing balanced binpacking with a capacity constraint, which can guarantee the max token count constraint, but it requires adding dummy padding batches since each worker can end up with a different number of bins.
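For reference, here is a minimal sketch of one way to do balanced binpacking with a hard capacity (worst-fit decreasing). It reproduces the docstring example from the diff above, but the function name and the exact heuristic are illustrative assumptions, not necessarily what this PR implements:

from typing import List

def balanced_binpacking_sketch(token_counts: List[int], max_tokens_per_microbatch: int) -> List[List[int]]:
    """Worst-fit decreasing with a hard capacity: place each sequence (longest first)
    into the least-loaded bin that still fits, opening a new bin otherwise."""
    order = sorted(range(len(token_counts)), key=lambda i: -token_counts[i])
    bins: List[List[int]] = []  # sequence indices per microbatch
    loads: List[int] = []       # total token count per microbatch
    for idx in order:
        count = token_counts[idx]
        fitting = [b for b in range(len(bins)) if loads[b] + count <= max_tokens_per_microbatch]
        if fitting:
            best = min(fitting, key=lambda b: loads[b])  # least-loaded bin that still fits
            bins[best].append(idx)
            loads[best] += count
        else:
            bins.append([idx])
            loads.append(count)
    return bins

# Reproduces the docstring example quoted above:
assert balanced_binpacking_sketch([8, 3, 5, 6, 2, 7], 11) == [[0, 4], [5, 1], [3, 2]]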
# Expect:
# - dp=0: 3 microbatches with [10], [10], [11]
# - dp=1: 3 microbatches with [10, 5], [10, 5], [<dummy padding batch>]
train_data.metadata["global_step"] = 0

results = ray.get(actor_group.async_run_ray_method("mesh", "forward", train_data))
results = ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", train_data))
I tested this with some print statements and it works as expected, but any ideas how to assert these in this test?
hmm yeah good question - is there a way you could mock the actual model forward and make some assertion on model length that way?
lemme try
critic_mini_batch_size: 256
micro_train_batch_size_per_gpu: 1
micro_forward_batch_size_per_gpu: 1
max_tokens_per_microbatch: -1  # TODO: Maybe split this between forward and train; -1 means no token-based chunking
do we need some config validation to make sure that this is larger than the max response length? Maybe in utils/utils.py?
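A rough sketch of what such a check could look like; the prompt/response length field names used here (generator.max_prompt_length, generator.max_response_length) are assumptions for illustration, not necessarily the actual config keys:

def validate_max_tokens_per_microbatch(cfg) -> None:
    """Hypothetical helper: ensure one full-length sequence can fit in a microbatch."""
    max_tokens = cfg.trainer.max_tokens_per_microbatch
    if max_tokens == -1:
        return  # token-based chunking disabled, nothing to check
    # NOTE: these length field names are assumptions for illustration.
    max_seq_len = cfg.generator.max_prompt_length + cfg.generator.max_response_length
    if max_tokens < max_seq_len:
        raise ValueError(
            f"trainer.max_tokens_per_microbatch ({max_tokens}) is smaller than the maximum "
            f"sequence length (prompt + response = {max_seq_len}), so a single long sequence "
            "could never fit into any microbatch."
        )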
for microbatch_idx, microbatch in enumerate(microbatch_iterator):
    microbatch_experience = BatchIterator.batch_to_experience(microbatch)
    microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)
    microbatch_weight = len(microbatch) / len(minibatch)
This is a bit confusing for cases where there are padding batches, because the padding batch cannot be empty and must have seqlen >= 2.
Example: with 2 real microbatches of 15, and then 1 padding batch, the weights would be: 15/30, 15/30, 2/30.
- This counts on the dummy microbatch having a fully 0 loss mask, which zeroes out its contribution.
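As a toy illustration (not the trainer's actual loss code), assuming a masked-sum per-microbatch loss, the padding microbatch's nonzero weight is harmless because its all-zero loss mask makes its loss exactly zero:

import torch

def masked_loss(per_token_loss: torch.Tensor, loss_mask: torch.Tensor) -> torch.Tensor:
    # Masked sum: an all-zero mask yields exactly 0, regardless of the loss values.
    return (per_token_loss * loss_mask).sum()

loss_a = masked_loss(torch.randn(15, 128), torch.ones(15, 128))  # real microbatch
loss_b = masked_loss(torch.randn(15, 128), torch.ones(15, 128))  # real microbatch
loss_pad = masked_loss(torch.randn(2, 2), torch.zeros(2, 2))     # dummy padding microbatch -> 0
# Weights as in the example above: 15/30, 15/30, 2/30; the 2/30 term contributes nothing.
total = (15 / 30) * loss_a + (15 / 30) * loss_b + (2 / 30) * loss_pad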
why does seqlen have to be >= 2?
Summary
- Introduces token-based chunking for microbatches, enforcing a max_tokens_per_microbatch limit on the total number of real tokens in a microbatch.
- Replaces the even, sample-based chunking BatchIterator with a BalancedBatchIterator that respects the microbatch token limits.
- After this PR, sharding the global training batch and forming minibatches still uses even, sample-based chunking. The new functionality adds token-based chunking at the bottom microbatching layer, which enforces a total token maximum per microbatch.
Problem
Previously, microbatches were created by chunking sequences evenly so that each microbatch had a fixed batch size (micro_train_batch_size_per_gpu), which could lead to:
- unbalanced token counts (and therefore GPU memory usage) across microbatches
- out-of-memory errors when a microbatch contains many long sequences
Solution
Introduce balanced_binpacking and TokenBasedBatchIterator utilities that:
- chunk each minibatch into microbatches whose total token count stays under max_tokens_per_microbatch, while keeping the microbatches roughly balanced
- pad with dummy microbatches so that every data parallel worker runs the same number of microbatch steps
API changes
- Adds a config.trainer.max_tokens_per_microbatch configuration. Defaults to the old behavior (sample-based microbatching).
- Adds TokenBasedBatchIterator and renames BatchIterator -> SampleBasedBatchIterator.
- Updates Worker.forward to also use the BatchIterator interface, rather than calling TrainingBatch.chunk directly.
- Uses TokenBasedBatchIterator when max_tokens_per_microbatch is set for the forward, Critic.ppo_train and Policy.ppo_train methods.
Example
Example ppo_train call from the unit test:
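A rough sketch based on the unit-test snippet quoted in the review thread above; the construction of actor_group (a 2-worker data parallel group) and train_data is omitted here:

import ray

# Expected microbatching from the test, with the token limit set:
# - dp=0: 3 microbatches with [10], [10], [11]
# - dp=1: 3 microbatches with [10, 5], [10, 5], [<dummy padding batch>]
train_data.metadata["global_step"] = 0
results = ray.get(actor_group.async_run_ray_method("mesh", "forward", train_data))
results = ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", train_data))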
Follow-up
This only does balanced microbatching at the individual worker level. We could further balance globally, at the level of the minibatches sent to each worker, to ensure every DP worker has roughly the same amount of work.
There are also different algorithms that we could plug into the microbatch chunking step.
We should also update the Megatron codepath as a follow-up.