1 change: 1 addition & 0 deletions skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -184,6 +184,7 @@ trainer:
critic_mini_batch_size: 256
micro_train_batch_size_per_gpu: 1
micro_forward_batch_size_per_gpu: 1
max_tokens_per_microbatch: -1 # TODO: Maybe split this between forward and train; -1 means no token-based chunking
Collaborator: do we need some config validation to make sure that this is larger than the max response length? Maybe in utils/utils.py?
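A minimal sketch of what such a check could look like; the function name, its placement, and the max_seq_len argument are assumptions for illustration, not the existing utils/utils.py API:

# Hypothetical sketch, not the actual SkyRL validation code: reject a token
# budget that cannot fit even one maximum-length sequence.
def validate_max_tokens_per_microbatch(max_tokens_per_microbatch: int, max_seq_len: int) -> None:
    if max_tokens_per_microbatch <= 0:
        return  # -1 disables token-based chunking, so there is nothing to validate
    if max_tokens_per_microbatch < max_seq_len:
        raise ValueError(
            f"trainer.max_tokens_per_microbatch={max_tokens_per_microbatch} is smaller than "
            f"the maximum sequence length ({max_seq_len}), so a single sample could never "
            "fit into one microbatch."
        )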

update_ref_every_epoch: false
use_sample_packing: true
eval_batch_size: 1024
6 changes: 3 additions & 3 deletions skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
@@ -30,7 +30,7 @@
from skyrl_train.utils.utils import update_model_config, str_to_torch_dtype
from skyrl_train.utils.constants import SKYRL_WORKER_NCCL_TIMEOUT_IN_S
from skyrl_train.training_batch import TrainingOutputBatch
from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
from skyrl_train.workers.worker_utils import BaseBatchIterator, SampleBasedBatchIterator, reduce_metrics
from skyrl_train.workers.worker import (
PolicyWorkerBase,
RefWorkerBase,
@@ -512,7 +512,7 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
Since we want megatron to handle gradient accumulation over micro batches, we directly pass mini batches into the
worker MegatronModelWrapper.forward_backward_mini_batch method.
"""
dataloader = BatchIterator(
dataloader = SampleBasedBatchIterator(
train_data, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
)

@@ -538,7 +538,7 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":
# TODO: Convert this into 2 loops for minibatches and microbatches.
micro_buffer = []
for local_step, microbatch in enumerate(pbar):
experience = BatchIterator.batch_to_experience(microbatch)
experience = BaseBatchIterator.batch_to_experience(microbatch)
experience.to_device(torch.cuda.current_device())
sequences = experience.sequences
attention_mask = experience.attention_mask
65 changes: 48 additions & 17 deletions skyrl-train/skyrl_train/workers/worker.py
@@ -32,7 +32,12 @@
from loguru import logger
from skyrl_train.distributed.ulysses import set_ulysses_sequence_parallel_group, apply_monkey_patch
from skyrl_train.utils.ppo_utils import PolicyLossRegistry, ppo_critic_loss, compute_approx_kl
from skyrl_train.workers.worker_utils import BatchIterator, reduce_metrics
from skyrl_train.workers.worker_utils import (
BaseBatchIterator,
SampleBasedBatchIterator,
TokenBasedBatchIterator,
reduce_metrics,
)
from skyrl_train.dataset.replay_buffer import Experience
from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
from skyrl_train.inference_engines.inference_engine_client import InferenceEngineClient
@@ -43,6 +48,15 @@
_SET_AFFINITY = False


def _get_microbatch_iterator(
data: TrainingInputBatch, micro_batch_size: int, max_tokens_per_microbatch: int
) -> BaseBatchIterator:
if max_tokens_per_microbatch > 0:
return TokenBasedBatchIterator(data, max_tokens_per_microbatch=max_tokens_per_microbatch)
else:
return SampleBasedBatchIterator(data, sample_batch_size=micro_batch_size, drop_last=False)
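
For intuition, a conceptual sketch of token-based chunking follows; it assumes a greedy packing strategy and is illustrative only, not the actual TokenBasedBatchIterator implementation from worker_utils:

# Illustrative only -- not the actual TokenBasedBatchIterator implementation.
# Greedily pack samples into microbatches so that the summed token counts
# (e.g. attention-mask sums) stay within max_tokens_per_microbatch.
from typing import List

def pack_by_token_budget(token_counts: List[int], max_tokens: int) -> List[List[int]]:
    """Return microbatches as lists of sample indices."""
    microbatches: List[List[int]] = []
    current: List[int] = []
    current_tokens = 0
    for idx, n_tokens in enumerate(token_counts):
        if current and current_tokens + n_tokens > max_tokens:
            microbatches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        microbatches.append(current)
    return microbatches

# pack_by_token_budget([900, 300, 800, 200], max_tokens=1024) -> [[0], [1], [2, 3]]

Note that under this kind of greedy packing, a single sample longer than the budget would still occupy its own microbatch, which is one reason a validation check like the one suggested above is useful.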


# Adapted from OpenRLHF: https://github.com/OpenRLHF/OpenRLHF/blob/main/openrlhf/trainer/ray/launcher.py#L17
class DistributedTorchRayActor:
def __init__(
@@ -305,16 +319,22 @@ def forward(
) -> TrainingOutputBatch:
"""Run forward pass on the input batch in inference mode.

This is a wrapper around `_forward_micro_batch` that runs in micro batches of `cfg.trainer.micro_forward_batch_size_per_gpu`.
This is a wrapper around `_forward_micro_batch` that runs in micro batches.
Uses token-based chunking if `max_tokens_per_microbatch` is configured, otherwise
falls back to sample-based chunking with `micro_forward_batch_size_per_gpu`.
"""
# run in micro batches of cfg.trainer.micro_forward_batch_size_per_gpu
# TODO (sumanthrh): this can be in the policy/critic impl if the micro batch size can be specific to policy, critic, etc.
micro_batches = data.chunk(self.cfg.trainer.micro_forward_batch_size_per_gpu)
microbatch_iterator = _get_microbatch_iterator(
data,
micro_batch_size=self.cfg.trainer.micro_forward_batch_size_per_gpu,
max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch,
)

outputs = []
for micro_batch in micro_batches:
outputs.append(self._forward_micro_batch(micro_batch))
output = TrainingOutputBatch.cat(outputs)
for microbatch in microbatch_iterator:
outputs.append(self._forward_micro_batch(microbatch))
output = microbatch_iterator.reorder_and_combine_batches(outputs)

if output.device is not None and output.device != torch.device("cpu"):
output = output.to("cpu")
return output
@@ -724,7 +744,7 @@ def optim_step(self) -> float:

def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
global_step = train_data.metadata["global_step"]
minibatch_iterator = BatchIterator(
minibatch_iterator = SampleBasedBatchIterator(
train_data, sample_batch_size=self.policy_mini_batch_size_per_gpu, drop_last=False
)

@@ -784,16 +804,23 @@ def record_status(status: Dict[str, float]):
disable=not self.strategy.is_rank_0(),
)
for minibatch in minibatch_pbar:
microbatch_iterator = BatchIterator(
minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
microbatch_iterator = _get_microbatch_iterator(
minibatch,
micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu,
max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch,
)
num_microbatches = len(microbatch_iterator)
microbatch_weight = 1.0 / num_microbatches

for microbatch_idx, microbatch in enumerate(microbatch_iterator):
microbatch_experience = BatchIterator.batch_to_experience(microbatch)
microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)
microbatch_weight = len(microbatch) / len(minibatch)
Contributor Author:
This is a bit confusing for cases where there are padding batches, because the padding batch cannot be empty and must have seqlen >= 2.

Example: with 2 real microbatches of 15 and then 1 padding microbatch, the weights would be 15/30, 15/30, 2/30.

  • This relies on the dummy microbatch having a fully-zero loss mask, which zeros out its contribution despite the nonzero weight.
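
A minimal sketch of this weighting with hypothetical numbers (not the actual worker code) shows why the fully-zero loss mask makes the padding weight harmless:

# Hypothetical numbers: two real microbatches of 15 samples and one padding
# microbatch whose loss mask is entirely zero, inside a 30-sample minibatch.
minibatch_size = 30
# (num_samples, summed_masked_loss, num_unmasked_tokens) per microbatch
microbatches = [
    (15, 15.0, 15),  # real: mean masked loss of 1.0
    (15, 15.0, 15),  # real: mean masked loss of 1.0
    (2, 0.0, 0),     # padding: zero loss mask -> zero loss
]

total_loss = 0.0
for num_samples, masked_loss_sum, num_unmasked in microbatches:
    weight = num_samples / minibatch_size        # 15/30, 15/30, 2/30
    mean_loss = masked_loss_sum / max(num_unmasked, 1)
    total_loss += weight * mean_loss             # padding term contributes 0.0

print(total_loss)  # 1.0 -- the 2/30 padding weight never reaches the loss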

Collaborator:
why does seqlen have to be >= 2?

status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)

print("!!! loss", self.mesh_rank, status["final_loss"])

temp.append(microbatch["attention_mask"].sum().item())

# Record status for all but the last microbatch in the minibatch.
# The last microbatch should be recorded after the optimizer step.
if microbatch_idx < num_microbatches - 1:
@@ -808,6 +835,8 @@ def record_status(status: Dict[str, float]):
if grad_norm is not None:
status["raw_grad_norm"] = grad_norm

if self.record_memory:
self.save_memory_snapshot(global_step, local_step)

@@ -991,7 +1020,7 @@ def save_hf_model(self, export_dir: str, tokenizer):

def ppo_train(self, train_data: TrainingInputBatch) -> TrainingOutputBatch:
global_step = train_data.metadata["global_step"]
minibatch_iterator = BatchIterator(
minibatch_iterator = SampleBasedBatchIterator(
train_data, sample_batch_size=self.critic_mini_batch_size_per_gpu, drop_last=False
)

@@ -1019,14 +1048,16 @@ def record_status(status: Dict[str, float]):
disable=not self.strategy.is_rank_0(),
)
for minibatch in minibatch_pbar:
microbatch_iterator = BatchIterator(
minibatch, sample_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu, drop_last=False
microbatch_iterator = _get_microbatch_iterator(
minibatch,
micro_batch_size=self.cfg.trainer.micro_train_batch_size_per_gpu,
max_tokens_per_microbatch=self.cfg.trainer.max_tokens_per_microbatch,
)
num_microbatches = len(microbatch_iterator)
microbatch_weight = 1.0 / num_microbatches

for microbatch_idx, microbatch in enumerate(microbatch_iterator):
microbatch_experience = BatchIterator.batch_to_experience(microbatch)
microbatch_experience = BaseBatchIterator.batch_to_experience(microbatch)
microbatch_weight = len(microbatch) / len(minibatch)
status = self.forward_backward(microbatch_experience, microbatch_weight=microbatch_weight)

if microbatch_idx < num_microbatches - 1: