From cb0af6e3263f5c0c08b80dff115bd3f1bd3a08e6 Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 6 Jan 2026 13:11:38 -0800
Subject: [PATCH 01/12] add batch iterator interface

Signed-off-by: Justin Yu
---
 .../skyrl_train/workers/worker_utils.py | 67 +++++++++++++------
 1 file changed, 46 insertions(+), 21 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py
index 897d032ea..c1d52ad92 100644
--- a/skyrl-train/skyrl_train/workers/worker_utils.py
+++ b/skyrl-train/skyrl_train/workers/worker_utils.py
@@ -1,7 +1,7 @@
 import math
 from skyrl_train.dataset.replay_buffer import Experience
 from typing import List, Dict
-from skyrl_train.training_batch import TrainingInputBatch
+from skyrl_train.training_batch import TrainingInputBatch, TensorBatch
 
 
 def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]:
@@ -16,33 +16,22 @@ def reduce_metrics(metrics: Dict[str, List[float]]) -> Dict[str, float]:
     return reduced_metrics
 
 
-class BatchIterator:
-    """A simple iterator to yield micro batches of data from the training batch."""
-
-    def __init__(self, data: TrainingInputBatch, sample_batch_size: int, drop_last: bool = False):
+class BaseBatchIterator:
+    def __init__(self, data: TrainingInputBatch):
         self.data = data
-        self.sample_batch_size = sample_batch_size
-        self.total_batch_size = data.batch_size
-        self.drop_last = drop_last
-        assert not drop_last, "drop_last is not supported yet"
-        num_micro_batches = self.total_batch_size / self.sample_batch_size
-        self.num_micro_batches = int(num_micro_batches) if drop_last else math.ceil(num_micro_batches)
-        # TODO: switch to tensordict.map_iter if possible
-        self._chunks = self.data.chunk(self.sample_batch_size)
-        self._iter = iter(self._chunks)
 
     def __len__(self):
-        return self.num_micro_batches
+        raise NotImplementedError
+
+    def __next__(self) -> TrainingInputBatch:
+        raise NotImplementedError
 
     def __iter__(self):
         return self
 
-    def __next__(self) -> TrainingInputBatch:
-        try:
-            return next(self._iter)
-        except StopIteration:
-            self._iter = iter(self._chunks)
-            raise StopIteration
+    def reorder_microbatches(self, microbatches: List[TensorBatch]) -> TensorBatch:
+        """Reorder and combine output microbatches to form a single minibatch output."""
+        raise NotImplementedError
 
     @staticmethod
     def batch_to_experience(batch: TrainingInputBatch):
@@ -67,3 +56,39 @@ def batch_to_experience(batch: TrainingInputBatch):
             metadata=batch.metadata,
         )
         return exp
+
+
+class SampleBasedBatchIterator:
+    """A simple iterator to yield microbatches of the same size from the training batch."""
+
+    def __init__(self, data: TrainingInputBatch, sample_batch_size: int, drop_last: bool = False):
+        self.data = data
+        self.sample_batch_size = sample_batch_size
+        self.total_batch_size = data.batch_size
+        self.drop_last = drop_last
+        assert not drop_last, "drop_last is not supported yet"
+        num_micro_batches = self.total_batch_size / self.sample_batch_size
+        self.num_micro_batches = int(num_micro_batches) if drop_last else math.ceil(num_micro_batches)
+        # TODO: switch to tensordict.map_iter if possible
+        self._chunks = self.data.chunk(self.sample_batch_size)
+        self._iter = iter(self._chunks)
+
+    def __len__(self):
+        return self.num_micro_batches
+
+    def __iter__(self):
+        return self
+
+    def __next__(self) -> TrainingInputBatch:
+        try:
+            return next(self._iter)
+        except StopIteration:
+            self._iter = iter(self._chunks)
+            raise StopIteration
+
+    def reorder_microbatches(self, microbatches: List[TensorBatch]) -> TensorBatch:
+        """Concatenate output microbatches to form a single minibatch output.
+
+        This iterator evenly splits the minibatch into microbatches of the same size,
+        so there's no need to reorder."""
+        return TensorBatch.cat(microbatches)

From 18a1792f28e6b0ed93af4fe0a3c8f6960c966cee Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 6 Jan 2026 13:16:06 -0800
Subject: [PATCH 02/12] update imports

Signed-off-by: Justin Yu
---
 .../workers/megatron/megatron_worker.py | 6 +++---
 skyrl-train/skyrl_train/workers/worker.py | 14 +++++++-------
 skyrl-train/skyrl_train/workers/worker_utils.py | 16 ++++++++--------
 skyrl-train/tests/cpu/test_trainer.py | 4 ++--
 4 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
index 080b9bb42..2587d6f95 100644
--- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
+++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
@@ -30,7 +30,7 @@
 from skyrl_train.utils.utils import update_model_config, str_to_torch_dtype
 from skyrl_train.utils.constants import SKYRL_WORKER_NCCL_TIMEOUT_IN_S
 from skyrl_train.training_batch import TrainingOutputBatch
-from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import SampleBasedBatchIterator, reduce_metrics
 from skyrl_train.workers.worker import (
     PolicyWorkerBase,
     RefWorkerBase,
@@ -512,7 +512,7 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
         Since we want megatron to handle gradient accumulation over micro batches, we directly pass
         mini batches into the worker MegatronModelWrapper.forward_backward_mini_batch method.
         """
-        dataloader = BatchIterator(
+        dataloader = SampleBasedBatchIterator(
             train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
         )
 
@@ -538,7 +538,7 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
         # TODO: Convert this into 2 loops for minibatches and microbatches.
         micro_buffer = []
         for local_step, microbatch in enumerate(pbar):
-            experience = BatchIterator.batch_to_experience(microbatch)
+            experience = SampleBasedBatchIterator.batch_to_experience(microbatch)
             experience.to_device(torch.cuda.current_device())
             sequences = experience.sequences
             attention_mask = experience.attention_mask
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 361a0777e..5afe72a10 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -32,7 +32,7 @@
 from loguru import logger
 from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch
 from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl
-from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import SampleBasedBatchIterator, reduce_metrics
 from skyrl_train.dataset.replay_buffer import Experience
 from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
 from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
@@ -724,7 +724,7 @@ def optim_step(self) -> float:
     def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
         global_step = train_data.metadata["global_step"]
 
-        minibatch_iterator = BatchIterator(
+        minibatch_iterator = SampleBasedBatchIterator(
             train_data, sample_batch_size=self.policy_mini_batch_size_per_gpu, drop_last=False
         )
 
@@ -784,14 +784,14 @@ def record_status(status: Dict[str, float]):
             disable=not self.strategy.is_rank_0(),
         )
         for minibatch in minibatch_pbar:
-            microbatch_iterator = BatchIterator(
+            microbatch_iterator = SampleBasedBatchIterator(
                 minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
             )
             num_microbatches = len(microbatch_iterator)
             microbatch_weight = 1.0 / num_microbatches
 
             for microbatch_idx, microbatch in enumerate(microbatch_iterator):
-                microbatch_experience = BatchIterator.batch_to_experience(microbatch)
+                microbatch_experience = SampleBasedBatchIterator.batch_to_experience(microbatch)
                 status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)
 
                 # Record status for all but the last microbatch in the minibatch.
@@ -991,7 +991,7 @@ def save_hf_model(self, export_dir: str, tokenizer):
     def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
         global_step = train_data.metadata["global_step"]
 
-        minibatch_iterator = BatchIterator(
+        minibatch_iterator = SampleBasedBatchIterator(
             train_data, sample_batch_size=self.critic_mini_batch_size_per_gpu, drop_last=False
         )
 
@@ -1019,14 +1019,14 @@ def record_status(status: Dict[str, float]):
             disable=not self.strategy.is_rank_0(),
         )
         for minibatch in minibatch_pbar:
-            microbatch_iterator = BatchIterator(
+            microbatch_iterator = SampleBasedBatchIterator(
                 minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
             )
             num_microbatches = len(microbatch_iterator)
             microbatch_weight = 1.0 / num_microbatches
 
             for microbatch_idx, microbatch in enumerate(microbatch_iterator):
-                microbatch_experience = BatchIterator.batch_to_experience(microbatch)
+                microbatch_experience = SampleBasedBatchIterator.batch_to_experience(microbatch)
                 status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)
 
                 if microbatch_idx < num_microbatches - 1:
diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py
index c1d52ad92..1417e2475 100644
--- a/skyrl-train/skyrl_train/workers/worker_utils.py
+++ b/skyrl-train/skyrl_train/workers/worker_utils.py
@@ -29,8 +29,8 @@ def __next__(self) -> TrainingInputBatch:
     def __iter__(self):
         return self
 
-    def reorder_microbatches(self, microbatches: List[TensorBatch]) -> TensorBatch:
-        """Reorder and combine output microbatches to form a single minibatch output."""
+    def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch:
+        """Reorder and combine output batches to form a single output."""
         raise NotImplementedError
 
     @staticmethod
@@ -58,8 +58,8 @@ def batch_to_experience(batch: TrainingInputBatch):
         return exp
 
 
-class SampleBasedBatchIterator:
-    """A simple iterator to yield microbatches of the same size from the training batch."""
+class SampleBasedBatchIterator(BaseBatchIterator):
+    """A simple iterator to yield batches of the same size from the training input."""
 
     def __init__(self, data: TrainingInputBatch, sample_batch_size: int, drop_last: bool = False):
         self.data = data
@@ -86,9 +86,9 @@ def __next__(self) -> TrainingInputBatch:
             self._iter = iter(self._chunks)
             raise StopIteration
 
-    def reorder_microbatches(self, microbatches: List[TensorBatch]) -> TensorBatch:
-        """Concatenate output microbatches to form a single minibatch output.
+    def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch:
+        """Concatenate output batches to form a single output.
 
-        This iterator evenly splits the minibatch into microbatches of the same size,
+        This iterator evenly splits the input into batches of the same size,
         so there's no need to reorder."""
-        return TensorBatch.cat(microbatches)
+        return TensorBatch.cat(batches)
diff --git a/skyrl-train/tests/cpu/test_trainer.py b/skyrl-train/tests/cpu/test_trainer.py
index 599f8a3f4..140da8df9 100644
--- a/skyrl-train/tests/cpu/test_trainer.py
+++ b/skyrl-train/tests/cpu/test_trainer.py
@@ -14,7 +14,7 @@
 from skyrl_train.training_batch import TrainingInputBatch
 import numpy as np
 from skyrl_train.workers.worker import PolicyWorkerBase, CriticWorkerBase
-from skyrl_train.workers.worker_utils import BatchIterator
+from skyrl_train.workers.worker_utils import SampleBasedBatchIterator
 from skyrl_train.utils.utils import validate_batch_sizes
 from skyrl_train.config.utils import get_default_config
 from tests.cpu.util import example_dummy_config
@@ -559,7 +559,7 @@ def mock_policy_forward_backward(experience, microbatch_weight):
     policy_worker.record_memory = False
 
     # Calculate expected values based on new accumulation logic
-    dataloader = BatchIterator(
+    dataloader = SampleBasedBatchIterator(
         dummy_databatch, sample_batch_size=cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
     )
     total_micro_batches = len(dataloader)  # Should be 6

From 8b029aaac7b9522f90e1375f6cfbd9ea7939e5c3 Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 6 Jan 2026 13:24:01 -0800
Subject: [PATCH 03/12] reorder microbatch outputs

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/workers/worker.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 5afe72a10..d5e48910a 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -309,12 +309,15 @@ def forward(
         """
         # run in micro batches of cfg.trainer.micro_forward_batch_size_per_gpu
         # TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
-        micro_batches = data.chunk(self.cfg.trainer.micro_forward_batch_size_per_gpu)
+        microbatch_iterator = SampleBasedBatchIterator(
+            data, sample_batch_size=self.cfg.trainer.micro_forward_batch_size_per_gpu, drop_last=False
+        )
 
         outputs = []
-        for micro_batch in micro_batches:
-            outputs.append(self._forward_micro_batch(micro_batch))
-        output = TrainingOutputBatch.cat(outputs)
+        for microbatch in microbatch_iterator:
+            outputs.append(self._forward_micro_batch(microbatch))
+        output = microbatch_iterator.reorder_and_combine_batches(outputs)
+
         if output.device is not None and output.device != torch.device("cpu"):
             output = output.to("cpu")
         return output

From 267104cc97ee43ca6d66f4586aa83ab47e21f980 Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 6 Jan 2026 13:25:07 -0800
Subject: [PATCH 04/12] use base class for batch_to_exp

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/workers/megatron/megatron_worker.py | 4 ++--
 skyrl-train/skyrl_train/workers/worker.py | 6 +++---
 2 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
index 2587d6f95..c2e07d7a1 100644
--- a/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
+++ b/skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
@@ -30,7 +30,7 @@
 from skyrl_train.utils.utils import update_model_config, str_to_torch_dtype
 from skyrl_train.utils.constants import SKYRL_WORKER_NCCL_TIMEOUT_IN_S
 from skyrl_train.training_batch import TrainingOutputBatch
-from skyrl_train.workers.worker_utils import SampleBasedBatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import BaseBatchIterator, SampleBasedBatchIterator, reduce_metrics
 from skyrl_train.workers.worker import (
     PolicyWorkerBase,
     RefWorkerBase,
@@ -538,7 +538,7 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
         # TODO: Convert this into 2 loops for minibatches and microbatches.
         micro_buffer = []
         for local_step, microbatch in enumerate(pbar):
-            experience = SampleBasedBatchIterator.batch_to_experience(microbatch)
+            experience = BaseBatchIterator.batch_to_experience(microbatch)
             experience.to_device(torch.cuda.current_device())
             sequences = experience.sequences
             attention_mask = experience.attention_mask
diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index d5e48910a..997cea6f8 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -32,7 +32,7 @@
 from loguru import logger
 from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch
 from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl
-from skyrl_train.workers.worker_utils import SampleBasedBatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import BaseBatchIterator, SampleBasedBatchIterator, reduce_metrics
 from skyrl_train.dataset.replay_buffer import Experience
 from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
 from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
@@ -794,7 +794,7 @@ def record_status(status: Dict[str, float]):
             microbatch_weight = 1.0 / num_microbatches
 
             for microbatch_idx, microbatch in enumerate(microbatch_iterator):
-                microbatch_experience = SampleBasedBatchIterator.batch_to_experience(microbatch)
+                microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)
                 status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)
 
                 # Record status for all but the last microbatch in the minibatch.
@@ -1029,7 +1029,7 @@ def record_status(status: Dict[str, float]):
             microbatch_weight = 1.0 / num_microbatches
 
             for microbatch_idx, microbatch in enumerate(microbatch_iterator):
-                microbatch_experience = SampleBasedBatchIterator.batch_to_experience(microbatch)
+                microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)
                 status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)
 
                 if microbatch_idx < num_microbatches - 1:

From e9efe6a72ce925d654a0e62a70dca3ad88c67563 Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 6 Jan 2026 13:31:06 -0800
Subject: [PATCH 05/12] add token based batch iterator impl

Signed-off-by: Justin Yu
---
 .../skyrl_train/workers/worker_utils.py | 209 +++++++++++++++++-
 1 file changed, 204 insertions(+), 5 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py
index 1417e2475..5eef3615b 100644
--- a/skyrl-train/skyrl_train/workers/worker_utils.py
+++ b/skyrl-train/skyrl_train/workers/worker_utils.py
@@ -1,6 +1,11 @@
+import heapq
 import math
+from typing import List, Dict, Iterator
+
+import torch
+import torch.distributed as dist
+
 from skyrl_train.dataset.replay_buffer import Experience
-from typing import List, Dict
 from skyrl_train.training_batch import TrainingInputBatch, TensorBatch
 
 
@@ -23,12 +28,9 @@ def __init__(self, data: TrainingInputBatch):
     def __len__(self):
         raise NotImplementedError
 
-    def __next__(self) -> TrainingInputBatch:
+    def __iter__(self) -> Iterator[TrainingInputBatch]:
         raise NotImplementedError
 
-    def __iter__(self):
-        return self
-
     def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch:
         """Reorder and combine output batches to form a single output."""
         raise NotImplementedError
@@ -92,3 +94,200 @@ def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch
         This iterator evenly splits the input into batches of the same size,
         so there's no need to reorder."""
         return TensorBatch.cat(batches)
+
+
+def balanced_binpacking(token_counts: List[int], max_tokens_per_microbatch: int) -> List[List[int]]:
+    """Chunk a list of token counts into microbatches so that each
+    microbatch's total token count does not exceed `max_tokens_per_microbatch`,
+    and the microbatches are roughly balanced.
+
+    Balancing is approximate: each sequence is assigned to the microbatch
+    with the fewest tokens so far.
+
+    Args:
+        token_counts: List of token counts for each sample.
+        max_tokens_per_microbatch: Maximum total tokens allowed per microbatch.
+
+    Returns:
+        A list of microbatches, where each microbatch is a list of indices (ints)
+        referring to entries in `token_counts`.
+
+    >>> balanced_binpacking([10, 10, 5, 5], 15)
+    [[0, 2], [1, 3]]
+    >>> balanced_binpacking([10, 1, 1, 1, 1, 1], 10)
+    [[0], [1, 2, 3, 4, 5]]
+    >>> balanced_binpacking([8, 3, 5, 6, 2, 7], 11)
+    [[0, 4], [5, 1], [3, 2]]
+    """
+    # TODO: Handle max(token_counts) > max_tokens_per_microbatch
+
+    # Create list of (index, token_count) pairs and sort by token count descending
+    seq_lens = [(i, seq_len) for i, seq_len in enumerate(token_counts)]
+    seq_lens.sort(key=lambda x: x[1], reverse=True)
+
+    # Track microbatch indices and their current token counts
+    microbatch_indices: List[List[int]] = []
+
+    # Heap to track the total number of tokens in each microbatch
+    microbatch_tokens_heap = []  # (current_total, bin_idx)
+
+    for idx, seq_len in seq_lens:
+        placed = False
+
+        # Check the least-loaded existing microbatch and place the sequence
+        # there if it fits without exceeding the token limit.
+        if microbatch_tokens_heap:
+            microbatch_len, i = microbatch_tokens_heap[0]
+            new_microbatch_len = microbatch_len + seq_len
+            if new_microbatch_len <= max_tokens_per_microbatch:
+                microbatch_indices[i].append(idx)
+                heapq.heapreplace(microbatch_tokens_heap, (new_microbatch_len, i))
+                placed = True
+
+        # If no microbatch can fit the sequence, create a new microbatch.
+        if not placed:
+            microbatch_indices.append([idx])
+            heapq.heappush(microbatch_tokens_heap, (seq_len, len(microbatch_indices) - 1))
+
+    return microbatch_indices
+
+
+class TokenBasedBatchIterator(BaseBatchIterator):
+    """An iterator that chunks a batch into microbatches based on actual token counts.
+    Packs samples into microbatches, ensuring each microbatch's total token count
+    doesn't exceed max_tokens_per_microbatch. All data parallel workers will have the
+    same number of microbatches (where padding microbatches are added if needed).
+    """
+
+    def __init__(
+        self,
+        data: TrainingInputBatch,
+        max_tokens_per_microbatch: int,
+    ):
+        """
+        Args:
+            data: The training input batch to chunk
+            max_tokens_per_microbatch: Maximum number of tokens per microbatch
+        """
+        self._data = data
+        self._max_tokens_per_microbatch = max_tokens_per_microbatch
+
+        # Compute token counts per sample using attention_mask
+        attention_mask = data["attention_mask"]
+        # Count non-padding tokens per sample
+        self._token_counts = attention_mask.sum(dim=1).cpu().tolist()  # [batch_size]
+
+        # Create microbatches based on token count
+        # TODO: Allow for different chunking strategies.
+        self._microbatches = balanced_binpacking(self._token_counts, self._max_tokens_per_microbatch)
+
+        # Synchronize the number of microbatches across all DP workers
+        max_num_microbatches = self._sync_num_microbatches()
+
+        self._num_padding_microbatches = max_num_microbatches - len(self._microbatches)
+
+    @property
+    def num_microbatches(self) -> int:
+        return len(self._microbatches) + self._num_padding_microbatches
+
+    def _create_microbatch_from_indices(self, indices: List[int]) -> TrainingInputBatch:
+        """Create a TrainingInputBatch from a list of sample indices."""
+        indices_tensor = torch.tensor(indices, dtype=torch.long, device="cpu")
+        selected_data = {}
+        for key, value in self._data.items():
+            selected_data[key] = value[indices_tensor]
+        microbatch = TrainingInputBatch(selected_data)
+        microbatch.metadata = self._data.metadata
+        return microbatch
+
+    def _create_padding_microbatch(self) -> TrainingInputBatch:
+        """Create a padding microbatch."""
+        data = TrainingInputBatch(
+            {
+                "sequences": torch.ones((1, 1), dtype=int, device="cpu"),
+                "attention_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
+                "action_log_probs": torch.zeros((1, 1), device="cpu"),
+                "base_action_log_probs": torch.zeros((1, 1), device="cpu"),
+                "values": torch.zeros((1, 1), device="cpu"),
+                "returns": torch.zeros((1, 1), device="cpu"),
+                "advantages": torch.zeros((1, 1), device="cpu"),
+                "loss_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
+                "response_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
+            }
+        )
+        data.metadata = self._data.metadata
+        return data
+
+    def _sync_num_microbatches(self) -> int:
+        """Ensure all DP workers have the same number of micro batches."""
+        local_num_microbatches = len(self._microbatches)
+
+        if not dist.is_initialized():
+            return local_num_microbatches
+
+        # Get the maximum number of batches across all DP workers
+        if torch.cuda.is_available():
+            device = torch.cuda.current_device()
+        else:
+            device = torch.device("cpu")
+        num_microbatches_tensor = torch.tensor(local_num_microbatches, dtype=torch.long, device=device)
+        dist.all_reduce(num_microbatches_tensor, op=dist.ReduceOp.MAX)
+        return num_microbatches_tensor.item()
+
+    def __len__(self):
+        return len(self._microbatches) + self._num_padding_microbatches
+
+    def __iter__(self):
+        for microbatch in self._microbatches:
+            yield self._create_microbatch_from_indices(microbatch)
+        for _ in range(self._num_padding_microbatches):
+            yield self._create_padding_microbatch()
+
+    def reorder_microbatches(self, microbatches: List[TensorBatch]) -> TensorBatch:
+        """Reorder microbatch data into a single batch with the same order as the original data.
+        Example: [[0, 2], [1, 3]] -> [0, 1, 2, 3]
+        Args:
+            microbatches: List of microbatches to reorder.
+        Returns:
+            A single reordered batch.
+        """
+        # TODO: Move this stuff to utility functions.
+        non_padding_microbatches = microbatches[: len(microbatches) - self._num_padding_microbatches]
+
+        if not non_padding_microbatches:
+            # TODO: Can this happen?
+            raise ValueError("Cannot reorder an empty list of microbatches.")
+
+        # Create a reverse mapping of original idx -> (microbatch idx, sample idx)
+        original_idx_to_microbatch_idx = {}
+
+        for microbatch_idx, original_indices in enumerate(self._microbatches):
+            for sample_idx, original_idx in enumerate(original_indices):
+                original_idx_to_microbatch_idx[original_idx] = (microbatch_idx, sample_idx)
+
+        # Get reference microbatch to know keys and tensor shapes
+        ref_microbatch = non_padding_microbatches[0]
+        reordered_data = {}
+
+        for key, ref_value in ref_microbatch.items():
+            # Get shape of a single sample (remove batch dimension)
+            sample_shape = ref_value.shape[1:]
+            device = ref_value.device
+            dtype = ref_value.dtype
+
+            # Pre-allocate output tensor: [batch_size, *sample_shape]
+            batch_size = len(self._token_counts)
+            output_tensor = torch.zeros((batch_size, *sample_shape), dtype=dtype, device=device)
+
+            # Copy each sample directly into the correct position
+            for original_idx in range(batch_size):
+                microbatch_idx, sample_idx = original_idx_to_microbatch_idx[original_idx]
+                source_tensor = non_padding_microbatches[microbatch_idx][key]
+                output_tensor[original_idx] = source_tensor[sample_idx]
+
+            reordered_data[key] = output_tensor
+
+        # Create single TensorBatch with reordered data
+        reordered_batch = type(ref_microbatch)(reordered_data)
+        reordered_batch.metadata = ref_microbatch.metadata
+        return reordered_batch

From ee1f3e12016fe3b369bff31d8891e5ba8d9ac49f Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 6 Jan 2026 13:32:56 -0800
Subject: [PATCH 06/12] add config

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/config/ppo_base_config.yaml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/skyrl-train/skyrl_train/config/ppo_base_config.yaml b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
index 89c16b165..8337c6071 100644
--- a/skyrl-train/skyrl_train/config/ppo_base_config.yaml
+++ b/skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -184,6 +184,7 @@ trainer:
   critic_mini_batch_size: 256
   micro_train_batch_size_per_gpu: 1
   micro_forward_batch_size_per_gpu: 1
+  max_tokens_per_microbatch: -1  # TODO: Maybe split this between forward and train; -1 means no token-based chunking
   update_ref_every_epoch: false
   use_sample_packing: true
   eval_batch_size: 1024

From a3b3320e5480c292d227d21c9b55df3eff9a131e Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 6 Jan 2026 13:41:36 -0800
Subject: [PATCH 07/12] use token based batch iterator

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/workers/worker.py | 24 +++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 997cea6f8..a5042c707 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -32,7 +32,12 @@
 from loguru import logger
 from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch
 from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl
-from skyrl_train.workers.worker_utils import BaseBatchIterator, SampleBasedBatchIterator, reduce_metrics
+from skyrl_train.workers.worker_utils import (
+    BaseBatchIterator,
+    SampleBasedBatchIterator,
+    TokenBasedBatchIterator,
+    reduce_metrics,
+)
 from skyrl_train.dataset.replay_buffer import Experience
 from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
 from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
@@ -787,11 +792,18 @@ def record_status(status: Dict[str, float]):
             disable=not self.strategy.is_rank_0(),
         )
         for minibatch in minibatch_pbar:
-            microbatch_iterator = SampleBasedBatchIterator(
-                minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
-            )
-            num_microbatches = len(microbatch_iterator)
-            microbatch_weight = 1.0 / num_microbatches
+            if self.cfg.trainer.max_tokens_per_microbatch > 0:
+                microbatch_iterator = TokenBasedBatchIterator(
+                    minibatch, max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch
+                )
+                num_microbatches = len(microbatch_iterator)
+                microbatch_weight = 1.0 / num_microbatches  # TODO: return the correct weight from the iterator
+            else:
+                microbatch_iterator = SampleBasedBatchIterator(
+                    minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
+                )
+                num_microbatches = len(microbatch_iterator)
+                microbatch_weight = 1.0 / num_microbatches
 
             for microbatch_idx, microbatch in enumerate(microbatch_iterator):
                 microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)

From 9f54515880b7996c2c218047ed97d701c46e27ed Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 6 Jan 2026 13:52:48 -0800
Subject: [PATCH 08/12] clean up impl

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/workers/worker.py | 34 +++++++++++++----------
 1 file changed, 20 insertions(+), 14 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index a5042c707..81f59747f 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -48,6 +48,15 @@
 _SET_AFFINITY = False
 
 
+def _get_microbatch_iterator(
+    data: TrainingInputBatch, micro_batch_size: int, max_tokens_per_microbatch: int
+) -> BaseBatchIterator:
+    if max_tokens_per_microbatch > 0:
+        return TokenBasedBatchIterator(data, max_tokens_per_microbatch=max_tokens_per_microbatch)
+    else:
+        return SampleBasedBatchIterator(data, sample_batch_size=micro_batch_size, drop_last=False)
+
+
 # Adapted from OpenRLHF: https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/trainer/ray/launcher.py#L17
 class DistributedTorchRayActor:
     def __init__(
@@ -314,8 +323,10 @@ def forward(
         """
         # run in micro batches of cfg.trainer.micro_forward_batch_size_per_gpu
         # TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
-        microbatch_iterator = SampleBasedBatchIterator(
-            data, sample_batch_size=self.cfg.trainer.micro_forward_batch_size_per_gpu, drop_last=False
+        microbatch_iterator = _get_microbatch_iterator(
+            data,
+            micro_batch_size=self.cfg.trainer.micro_forward_batch_size_per_gpu,
+            max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch,
         )
 
         outputs = []
@@ -803,21 +814,16 @@ def record_status(status: Dict[str, float]):
             disable=not self.strategy.is_rank_0(),
         )
         for minibatch in minibatch_pbar:
-            if self.cfg.trainer.max_tokens_per_microbatch > 0:
-                microbatch_iterator = TokenBasedBatchIterator(
-                    minibatch, max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch
-                )
-                num_microbatches = len(microbatch_iterator)
-                microbatch_weight = 1.0 / num_microbatches  # TODO: return the correct weight from the iterator
-            else:
-                microbatch_iterator = SampleBasedBatchIterator(
-                    minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
-                )
-                num_microbatches = len(microbatch_iterator)
-                microbatch_weight = 1.0 / num_microbatches
+            microbatch_iterator = _get_microbatch_iterator(
+                minibatch,
+                micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu,
+                max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch,
+            )
+            num_microbatches = len(microbatch_iterator)
 
             for microbatch_idx, microbatch in enumerate(microbatch_iterator):
                 microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)
+                microbatch_weight = len(microbatch) / len(minibatch)
                 status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)
 
                 # Record status for all but the last microbatch in the minibatch.

From 8e301d3b4c1f80f3a47aa30c711d355559e695a5 Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 6 Jan 2026 13:54:49 -0800
Subject: [PATCH 09/12] update critic ppo_train

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/workers/worker.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 81f59747f..23e0d717e 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -1040,14 +1040,16 @@ def record_status(status: Dict[str, float]):
             disable=not self.strategy.is_rank_0(),
         )
         for minibatch in minibatch_pbar:
-            microbatch_iterator = SampleBasedBatchIterator(
-                minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
+            microbatch_iterator = _get_microbatch_iterator(
+                minibatch,
+                micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu,
+                max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch,
             )
             num_microbatches = len(microbatch_iterator)
-            microbatch_weight = 1.0 / num_microbatches
 
             for microbatch_idx, microbatch in enumerate(microbatch_iterator):
                 microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)
+                microbatch_weight = len(microbatch) / len(minibatch)
                 status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)
 
                 if microbatch_idx < num_microbatches - 1:

From b2f64994a2a0acc51169c2f772d93f3311cbe1cd Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 6 Jan 2026 14:00:08 -0800
Subject: [PATCH 10/12] fix fn definition

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/workers/worker_utils.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py
index 5eef3615b..fc3f5adb8 100644
--- a/skyrl-train/skyrl_train/workers/worker_utils.py
+++ b/skyrl-train/skyrl_train/workers/worker_utils.py
@@ -243,16 +243,18 @@ def __iter__(self):
         for _ in range(self._num_padding_microbatches):
             yield self._create_padding_microbatch()
 
-    def reorder_microbatches(self, microbatches: List[TensorBatch]) -> TensorBatch:
-        """Reorder microbatch data into a single batch with the same order as the original data.
+    def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch:
+        """Reorder and combine output batches into a single batch with
+        the same order as the original input data.
+
         Example: [[0, 2], [1, 3]] -> [0, 1, 2, 3]
+
         Args:
-            microbatches: List of microbatches to reorder.
+            batches: List of microbatches to reorder.
         Returns:
             A single reordered batch.
         """
-        # TODO: Move this stuff to utility functions.
-        non_padding_microbatches = microbatches[: len(microbatches) - self._num_padding_microbatches]
+        non_padding_microbatches = batches[: len(batches) - self._num_padding_microbatches]
 
         if not non_padding_microbatches:
             # TODO: Can this happen?

From ae7d981dcb985c646f92dfad6230098e2f80d26e Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Tue, 6 Jan 2026 14:01:02 -0800
Subject: [PATCH 11/12] fix docstring

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/workers/worker.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index 23e0d717e..ee620ec52 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -319,9 +319,10 @@ def forward(
     ) -> TrainingOutputBatch:
         """Run forward pass on the input batch in inference mode.
 
-        This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.trainer.micro_forward_batch_size_per_gpu`.
+        This is a wrapper around `_forward_micro_batch` that runs in micro batches.
+        Uses token-based chunking if `max_tokens_per_microbatch` is configured, otherwise
+        falls back to sample-based chunking with `micro_forward_batch_size_per_gpu`.
         """
-        # run in micro batches of cfg.trainer.micro_forward_batch_size_per_gpu
         # TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
         microbatch_iterator = _get_microbatch_iterator(
             data,

From 9335c33f13e1e28084c74a1c01073c018fd56b8c Mon Sep 17 00:00:00 2001
From: Justin Yu
Date: Wed, 7 Jan 2026 01:32:31 +0000
Subject: [PATCH 12/12] fix dummy microbatch case

Signed-off-by: Justin Yu
---
 skyrl-train/skyrl_train/workers/worker.py | 7 ++
 .../skyrl_train/workers/worker_utils.py | 30 +++++---
 .../tests/gpu/gpu_ci/test_ppo_train.py | 72 +++++++++++++++++++
 3 files changed, 98 insertions(+), 11 deletions(-)

diff --git a/skyrl-train/skyrl_train/workers/worker.py b/skyrl-train/skyrl_train/workers/worker.py
index ee620ec52..791e0df65 100644
--- a/skyrl-train/skyrl_train/workers/worker.py
+++ b/skyrl-train/skyrl_train/workers/worker.py
@@ -811,11 +811,16 @@ def record_status(status: Dict[str, float]):
             )
             num_microbatches = len(microbatch_iterator)
 
+            temp = []
             for microbatch_idx, microbatch in enumerate(microbatch_iterator):
                 microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)
                 microbatch_weight = len(microbatch) / len(minibatch)
                 status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)
 
+                print("!!! loss", self.mesh_rank, status["final_loss"])
+
+                temp.append(microbatch["attention_mask"].sum().item())
+
                 # Record status for all but the last microbatch in the minibatch.
                 # The last microbatch should be recorded after the optimizer step.
                 if microbatch_idx < num_microbatches - 1:
@@ -830,6 +835,8 @@ def record_status(status: Dict[str, float]):
             if grad_norm is not None:
                 status["raw_grad_norm"] = grad_norm
 
+            # print(self.mesh_rank, temp)
+
             if self.record_memory:
                 self.save_memory_snapshot(global_step, local_step)
diff --git a/skyrl-train/skyrl_train/workers/worker_utils.py b/skyrl-train/skyrl_train/workers/worker_utils.py
index fc3f5adb8..2913635a5 100644
--- a/skyrl-train/skyrl_train/workers/worker_utils.py
+++ b/skyrl-train/skyrl_train/workers/worker_utils.py
@@ -202,17 +202,22 @@ def _create_microbatch_from_indices(self, indices: List[int]) -> TrainingInputBa
 
     def _create_padding_microbatch(self) -> TrainingInputBatch:
         """Create a padding microbatch."""
+        seq_len = 2
+        num_actions = self._data.metadata["response_length"]
+        batch_size = 1
+
         data = TrainingInputBatch(
             {
-                "sequences": torch.ones((1, 1), dtype=int, device="cpu"),
-                "attention_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
-                "action_log_probs": torch.zeros((1, 1), device="cpu"),
-                "base_action_log_probs": torch.zeros((1, 1), device="cpu"),
-                "values": torch.zeros((1, 1), device="cpu"),
-                "returns": torch.zeros((1, 1), device="cpu"),
-                "advantages": torch.zeros((1, 1), device="cpu"),
-                "loss_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
-                "response_mask": torch.zeros((1, 1), dtype=int, device="cpu"),
+                "sequences": torch.randint(0, 100, (batch_size, seq_len), device="cpu"),
+                "attention_mask": torch.ones((batch_size, seq_len), dtype=int, device="cpu"),
+                "action_log_probs": 0.4 * torch.ones((batch_size, num_actions), device="cpu"),
+                "base_action_log_probs": 0.3 * torch.ones((batch_size, num_actions), device="cpu"),
+                "values": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
+                "returns": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
+                "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"),
+                # Loss mask is all zeros because we don't want padding samples to contribute to the loss.
+                "loss_mask": torch.zeros((batch_size, num_actions), dtype=int, device="cpu"),
+                "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"),
             }
         )
         data.metadata = self._data.metadata
@@ -239,9 +244,12 @@ def __len__(self):
 
     def __iter__(self):
         for microbatch in self._microbatches:
-            yield self._create_microbatch_from_indices(microbatch)
+            microbatch_data = self._create_microbatch_from_indices(microbatch)
+            yield microbatch_data
+
         for _ in range(self._num_padding_microbatches):
-            yield self._create_padding_microbatch()
+            padding_microbatch = self._create_padding_microbatch()
+            yield padding_microbatch
 
     def reorder_and_combine_batches(self, batches: List[TensorBatch]) -> TensorBatch:
         """Reorder and combine output batches into a single batch with
diff --git a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py
index c9880a017..9d1782fac 100644
--- a/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py
+++ b/skyrl-train/tests/gpu/gpu_ci/test_ppo_train.py
@@ -218,3 +218,75 @@ def test_gradient_accumulation_scenarios(
         print(f"  - Actual optimizer steps: {actual_optimizer_steps}")
     finally:
         ray.shutdown()
+
+
+import torch
+from skyrl_train.training_batch import TrainingInputBatch
+
+def make_dummy_batch(seq_lens, num_actions=4) -> TrainingInputBatch:
+    """Create a dummy TrainingInputBatch"""
+
+    batch_size = len(seq_lens)
+    max_seq_len = max(seq_lens)
+
+    sequences = torch.zeros((batch_size, max_seq_len), dtype=int, device="cpu")
+    attention_mask = torch.zeros((batch_size, max_seq_len), dtype=int, device="cpu")
+    for i, seq_len in enumerate(seq_lens):
+        sequences[i, :seq_len] = torch.randint(0, 100, (seq_len,), dtype=int, device="cpu")
+        attention_mask[i, :seq_len] = 1
+
+    print(sequences)
+    print(attention_mask)
+
+    # Add all the required fields for training
+    data = TrainingInputBatch(
+        {
+            "sequences": sequences,
+            "attention_mask": attention_mask,
+            "action_log_probs": 0.4 * torch.ones((batch_size, num_actions), device="cpu"),
+            "base_action_log_probs": 0.3 * torch.ones((batch_size, num_actions), device="cpu"),
+            "values": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
+            "returns": 0.5 * torch.ones((batch_size, num_actions), device="cpu"),
+            "advantages": 0.6 * torch.ones((batch_size, num_actions), device="cpu"),
+            "loss_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"),
+            "response_mask": torch.ones((batch_size, num_actions), dtype=int, device="cpu"),
+        }
+    )
+    data.metadata = {"response_length": num_actions}
+    return data
+
+
+
+@pytest.mark.parametrize("worker_type", ["policy", "critic"])
+def test_max_tokens_per_microbatch(ray_init_fixture, cfg, worker_type):
+    try:
+        cfg.trainer.strategy = "fsdp2"  # Strategy logic is not tested here.
+        cfg.trainer.max_tokens_per_microbatch = 15
+
+        # Hard-code to two DP workers so that padding microbatch synchronization is exercised
+        cfg.trainer.placement.policy_num_gpus_per_node = 2
+        cfg.trainer.policy_mini_batch_size = 8
+
+        actor_group = init_worker_with_type(
+            worker_type,
+            shared_pg=None,
+            colocate_all=False,
+            num_gpus_per_node=cfg.trainer.placement.policy_num_gpus_per_node,
+            cfg=cfg,
+        )
+
+        train_data = make_dummy_batch([10, 10, 6, 5, 10, 10, 5, 5], num_actions=4)
+        # Expect:
+        # - dp=0: 3 microbatches with [10], [10], [11]
+        # - dp=1: 3 microbatches with [10, 5], [10, 5], []
+        train_data.metadata["global_step"] = 0
+
+        results = ray.get(actor_group.async_run_ray_method("mesh", "forward", train_data))
+        results = ray.get(actor_group.async_run_ray_method("mesh", "ppo_train", train_data))
+        assert len(results) == cfg.trainer.placement.policy_num_gpus_per_node, "Should get result from each GPU"
+
+        print(results)
+        # TODO: Add assertions for the results.
+
+    finally:
+        ray.shutdown()
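
Aside, for reviewers: the microbatch layout that test_max_tokens_per_microbatch expects can be checked outside the trainer. The sketch below is a condensed but behavior-equivalent rewrite of balanced_binpacking from patch 05 (same greedy placement and the same stable-sort tie-breaking), replaying the test's per-rank token counts; it is an illustration only and does not import skyrl-train.

import heapq
from typing import List

def balanced_binpacking(token_counts: List[int], max_tokens_per_microbatch: int) -> List[List[int]]:
    # Greedy scheme from patch 05: visit sequences longest-first, place each one
    # into the currently least-loaded microbatch if it still fits under the token
    # limit, and open a new microbatch otherwise. A min-heap of
    # (total_tokens, bin_idx) tracks the least-loaded microbatch.
    seq_lens = sorted(enumerate(token_counts), key=lambda x: x[1], reverse=True)
    microbatch_indices: List[List[int]] = []
    heap = []  # (current_total_tokens, bin_idx)
    for idx, seq_len in seq_lens:
        if heap and heap[0][0] + seq_len <= max_tokens_per_microbatch:
            total, i = heap[0]
            microbatch_indices[i].append(idx)
            heapq.heapreplace(heap, (total + seq_len, i))
        else:
            microbatch_indices.append([idx])
            heapq.heappush(heap, (seq_len, len(microbatch_indices) - 1))
    return microbatch_indices

# The test's 8-sample batch [10, 10, 6, 5, 10, 10, 5, 5], split evenly across
# the two data parallel ranks:
for rank, counts in enumerate([[10, 10, 6, 5], [10, 10, 5, 5]]):
    bins = balanced_binpacking(counts, max_tokens_per_microbatch=15)
    totals = [sum(counts[i] for i in b) for b in bins]
    print(f"dp={rank}: {bins} -> token totals {totals}")
# dp=0: [[0], [1], [2, 3]] -> token totals [10, 10, 11]
# dp=1: [[0, 2], [1, 3]] -> token totals [15, 15]

dp=1 ends up with one fewer microbatch than dp=0, which is exactly why TokenBasedBatchIterator all-reduces the microbatch count in _sync_num_microbatches and appends a padding microbatch (with an all-zero loss_mask, per patch 12) on the shorter rank: every data parallel worker has to enter the same number of forward/backward collectives. The uneven packing also motivates patch 08's change from a fixed microbatch_weight = 1.0 / num_microbatches to len(microbatch) / len(minibatch), since token-packed microbatches no longer carry equal sample counts.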