-
Notifications
You must be signed in to change notification settings - Fork 222
[skyrl-train] Add max_tokens_per_microbatch configuration
#847
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
cb0af6e
18a1792
8b029aa
267104c
e9efe6a
ee1f3e1
a3b3320
9f54515
8e301d3
b2f6499
ae7d981
9335c33
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,7 +32,12 @@ | |
| from loguru import logger | ||
| from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch | ||
| from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl | ||
| from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics | ||
| from skyrl_train.workers.worker_utils import ( | ||
| BaseBatchIterator, | ||
| SampleBasedBatchIterator, | ||
| TokenBasedBatchIterator, | ||
| reduce_metrics, | ||
| ) | ||
| from skyrl_train.dataset.replay_buffer import Experience | ||
| from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch | ||
| from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient | ||
|
|
@@ -43,6 +48,15 @@ | |
| _SET_AFFINITY = False | ||
|
|
||
|
|
||
| def _get_microbatch_iterator( | ||
| data: TrainingInputBatch, micro_batch_size: int, max_tokens_per_microbatch: int | ||
| ) -> BaseBatchIterator: | ||
| if max_tokens_per_microbatch > 0: | ||
| return TokenBasedBatchIterator(data, max_tokens_per_microbatch=max_tokens_per_microbatch) | ||
| else: | ||
| return SampleBasedBatchIterator(data, sample_batch_size=micro_batch_size, drop_last=False) | ||
|
|
||
|
|
||
| # Adapted from OpenRLHF: https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/trainer/ray/launcher.py#L17 | ||
| class DistributedTorchRayActor: | ||
| def __init__( | ||
|
|
@@ -305,16 +319,22 @@ def forward( | |
| ) -> TrainingOutputBatch: | ||
| """Run forward pass on the input batch in inference mode. | ||
|
|
||
| This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.trainer.micro_forward_batch_size_per_gpu`. | ||
| This is a wrapper around `_forward_micro_batch` that runs in micro batches. | ||
| Uses token-based chunking if `max_tokens_per_microbatch` is configured, otherwise | ||
| falls back to sample-based chunking with `micro_forward_batch_size_per_gpu`. | ||
| """ | ||
| # run in micro batches of cfg.trainer.micro_forward_batch_size_per_gpu | ||
| # TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc. | ||
| micro_batches = data.chunk(self.cfg.trainer.micro_forward_batch_size_per_gpu) | ||
| microbatch_iterator = _get_microbatch_iterator( | ||
| data, | ||
| micro_batch_size=self.cfg.trainer.micro_forward_batch_size_per_gpu, | ||
| max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch, | ||
| ) | ||
|
|
||
| outputs = [] | ||
| for micro_batch in micro_batches: | ||
| outputs.append(self._forward_micro_batch(micro_batch)) | ||
| output = TrainingOutputBatch.cat(outputs) | ||
| for microbatch in microbatch_iterator: | ||
| outputs.append(self._forward_micro_batch(microbatch)) | ||
| output = microbatch_iterator.reorder_and_combine_batches(outputs) | ||
|
|
||
| if output.device is not None and output.device != torch.device("cpu"): | ||
| output = output.to("cpu") | ||
| return output | ||
|
|
@@ -724,7 +744,7 @@ def optim_step(self) -> float: | |
|
|
||
| def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: | ||
| global_step = train_data.metadata["global_step"] | ||
| minibatch_iterator = BatchIterator( | ||
| minibatch_iterator = SampleBasedBatchIterator( | ||
| train_data, sample_batch_size=self.policy_mini_batch_size_per_gpu, drop_last=False | ||
| ) | ||
|
|
||
|
|
@@ -784,16 +804,23 @@ def record_status(status: Dict[str, float]): | |
| disable=not self.strategy.is_rank_0(), | ||
| ) | ||
| for minibatch in minibatch_pbar: | ||
| microbatch_iterator = BatchIterator( | ||
| minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False | ||
| microbatch_iterator = _get_microbatch_iterator( | ||
| minibatch, | ||
| micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, | ||
| max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch, | ||
| ) | ||
| num_microbatches = len(microbatch_iterator) | ||
| microbatch_weight = 1.0 / num_microbatches | ||
|
|
||
| temp = [] | ||
| for microbatch_idx, microbatch in enumerate(microbatch_iterator): | ||
| microbatch_experience = BatchIterator.batch_to_experience(microbatch) | ||
| microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch) | ||
| microbatch_weight = len(microbatch) / len(minibatch) | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a bit confusing for cases where there's padding batches. This is because the padding batch cannot be empty and must have seqlen >= 2. Example: with 2 real microbatches of 15, and then 1 padding, the weights would be: 15/30, 15/30, 2/30.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why does seqlen have to be >= 2? |
||
| status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight) | ||
|
|
||
| print("!!! loss", self.mesh_rank, status["final_loss"]) | ||
|
|
||
| temp.append(microbatch["attention_mask"].sum().item()) | ||
|
|
||
| # Record status for all but the last microbatch in the minibatch. | ||
| # The last microbatch should be recorded after the optimizer step. | ||
| if microbatch_idx < num_microbatches - 1: | ||
|
|
@@ -808,6 +835,8 @@ def record_status(status: Dict[str, float]): | |
| if grad_norm is not None: | ||
| status["raw_grad_norm"] = grad_norm | ||
|
|
||
| # print(self.mesh_rank, temp) | ||
|
|
||
| if self.record_memory: | ||
| self.save_memory_snapshot(global_step, local_step) | ||
|
|
||
|
|
@@ -991,7 +1020,7 @@ def save_hf_model(self, export_dir: str, tokenizer): | |
|
|
||
| def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch: | ||
| global_step = train_data.metadata["global_step"] | ||
| minibatch_iterator = BatchIterator( | ||
| minibatch_iterator = SampleBasedBatchIterator( | ||
| train_data, sample_batch_size=self.critic_mini_batch_size_per_gpu, drop_last=False | ||
| ) | ||
|
|
||
|
|
@@ -1019,14 +1048,16 @@ def record_status(status: Dict[str, float]): | |
| disable=not self.strategy.is_rank_0(), | ||
| ) | ||
| for minibatch in minibatch_pbar: | ||
| microbatch_iterator = BatchIterator( | ||
| minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False | ||
| microbatch_iterator = _get_microbatch_iterator( | ||
| minibatch, | ||
| micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, | ||
| max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch, | ||
| ) | ||
| num_microbatches = len(microbatch_iterator) | ||
| microbatch_weight = 1.0 / num_microbatches | ||
|
|
||
| for microbatch_idx, microbatch in enumerate(microbatch_iterator): | ||
| microbatch_experience = BatchIterator.batch_to_experience(microbatch) | ||
| microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch) | ||
| microbatch_weight = len(microbatch) / len(minibatch) | ||
| status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight) | ||
|
|
||
| if microbatch_idx < num_microbatches - 1: | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need some config validation to make sure that this is larger than the max response length? Maybe in utils/utils.py?