Merged
22 commits
f5a64c4  refactor trainer worker to fwd/bwd and optim_step (tyler-griggs, Dec 22, 2025)
e63f0bd  Merge remote-tracking branch 'real/main' into tgriggs/refac_worker (tyler-griggs, Dec 22, 2025)
58221a1  format (tyler-griggs, Dec 22, 2025)
0f4d118  fix ci, clean dead code (tyler-griggs, Dec 23, 2025)
a353fb0  optim_step cleanup (tyler-griggs, Dec 23, 2025)
ec96fd3  working on rm training step (tyler-griggs, Dec 27, 2025)
8102e33  Merge remote-tracking branch 'real/main' into HEAD (tyler-griggs, Jan 1, 2026)
e855d41  Merge remote-tracking branch 'real/main' into tgriggs/rm_training_step (tyler-griggs, Jan 5, 2026)
1959fed  removing training step (tyler-griggs, Jan 5, 2026)
6909e3b  fix tests (tyler-griggs, Jan 5, 2026)
4e79626  x (tyler-griggs, Jan 5, 2026)
68e1ade  pulling apart infra and alg logic (tyler-griggs, Jan 9, 2026)
5e68c8a  md file, megatron, dispatch test (tyler-griggs, Jan 9, 2026)
5b04a24  Merge origin/main into tgriggs/rm_ppo_train (tyler-griggs, Jan 19, 2026)
8e4a105  Fix formatting in worker files (tyler-griggs, Jan 19, 2026)
6705e35  Remove ppo_train from FSDP workers, use gradient scaling (tyler-griggs, Jan 19, 2026)
bcf977b  Fix test_save_load_checkpoint.py for new forward_backward API (tyler-griggs, Jan 19, 2026)
f6850f6  Restore mini-batch loop in _execute_training_step (tyler-griggs, Jan 19, 2026)
b757ed2  Fix GPU memory and API issues (tyler-griggs, Jan 19, 2026)
4b30b61  fixes (tyler-griggs, Jan 19, 2026)
9c35196  Refactor trainer to route all offload/backload through dispatch (tyler-griggs, Jan 19, 2026)
35e5bf4  Remove project-summary.md from PR (tyler-griggs, Jan 19, 2026)
404 changes: 185 additions & 219 deletions skyrl-train/skyrl_train/trainer.py

Large diffs are not rendered by default.
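The trainer.py diff is collapsed above, but the commit messages ("refactor trainer worker to fwd/bwd and optim_step", "Remove ppo_train from FSDP workers, use gradient scaling") describe the shape of the change. Below is a minimal sketch of that split; the names forward_backward and optim_step come from the commits, while everything else (class name, loss, trainer loop) is assumed for illustration:

import torch

class PolicyWorkerSketch:
    """Hypothetical stand-in for the refactored worker: the monolithic
    ppo_train loop is replaced by explicit forward_backward and optim_step
    calls, so the trainer owns mini-batching and gradient accumulation."""

    def __init__(self):
        self.model = torch.nn.Linear(4, 1)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)

    def forward_backward(self, batch, loss_scale: float = 1.0):
        # One micro batch: compute a loss, scale it for accumulation,
        # and accumulate gradients without stepping the optimizer.
        loss = self.model(batch).pow(2).mean() * loss_scale
        loss.backward()
        return loss.detach()

    def optim_step(self):
        # Apply the accumulated gradients once per mini batch.
        self.optimizer.step()
        self.optimizer.zero_grad()

# Trainer-side loop (assumed shape): scale each micro-batch loss by
# 1/num_micro_batches so the accumulated gradient is an average.
worker = PolicyWorkerSketch()
micro_batches = [torch.randn(8, 4) for _ in range(4)]
for mb in micro_batches:
    worker.forward_backward(mb, loss_scale=1.0 / len(micro_batches))
worker.optim_step()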

18 changes: 16 additions & 2 deletions skyrl-train/skyrl_train/workers/megatron/megatron_worker.py
@@ -431,6 +431,20 @@ def _broadcast_no_grad(*args, **kwargs):
             pp_size=mpu.get_pipeline_model_parallel_world_size(),
         )

+    def _normalize_mini_batch_size(self):
+        """
+        Override to set Megatron-specific batch size attributes.
+
+        Megatron's ppo_train method needs policy_mini_batch_size_per_gpu to compute
+        how many micro batches fit in a mini batch for gradient accumulation.
+        """
+        super()._normalize_mini_batch_size()  # Sets _micro_batches_accumulated
+
+        # Megatron-specific: compute mini batch size per GPU for ppo_train
+        n_samples = self.cfg.generator.n_samples_per_prompt
+        dp_size = self.mesh_rank.dp_size
+        self.policy_mini_batch_size_per_gpu = (self.cfg.trainer.policy_mini_batch_size * n_samples) // dp_size
+
     def init_model(self, model_path, num_training_steps: int = 1e9):
         """
         Initialize the model, optimizer, and scheduler for the policy worker.
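To make the per-GPU arithmetic in the hunk above concrete, here is a small worked example; the batch sizes and parallelism degree are invented for illustration and are not taken from this PR:

# Assumed values, for illustration only.
policy_mini_batch_size = 256  # cfg.trainer.policy_mini_batch_size (prompts)
n_samples_per_prompt = 4      # cfg.generator.n_samples_per_prompt
dp_size = 8                   # mesh_rank.dp_size

# Each prompt expands into n_samples_per_prompt sequences, and the mini
# batch is sharded evenly across the data-parallel ranks.
policy_mini_batch_size_per_gpu = (policy_mini_batch_size * n_samples_per_prompt) // dp_size
assert policy_mini_batch_size_per_gpu == 128

# With an assumed micro batch size of 16 per GPU, gradients accumulate
# over 128 // 16 = 8 micro batches before each optimizer step.
micro_batch_size_per_gpu = 16
assert policy_mini_batch_size_per_gpu // micro_batch_size_per_gpu == 8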
@@ -537,8 +551,8 @@ def ppo_train(self, train_data) -> "TrainingOutputBatch":

         # TODO: Convert this into 2 loops for minibatches and microbatches.
         micro_buffer = []
-        for local_step, microbatch in enumerate(pbar):
-            experience = BatchIterator.batch_to_experience(microbatch)
+        for local_step, experience in enumerate(pbar):
+            # BatchIterator now yields Experience objects directly
             experience.to_device(torch.cuda.current_device())
             sequences = experience.sequences
             attention_mask = experience.attention_mask
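The loop change above relies on BatchIterator yielding Experience objects directly instead of raw batches that each call site must convert. A minimal sketch of that pattern, with simplified stand-ins for the real Experience and BatchIterator classes:

from dataclasses import dataclass
import torch

@dataclass
class Experience:
    """Simplified stand-in for the real Experience container."""
    sequences: torch.Tensor
    attention_mask: torch.Tensor

    def to_device(self, device):
        self.sequences = self.sequences.to(device)
        self.attention_mask = self.attention_mask.to(device)
        return self

class BatchIterator:
    """Sketch: conversion to Experience happens once, inside the iterator,
    so call sites no longer invoke batch_to_experience at every step."""

    def __init__(self, batches):
        self.batches = batches

    def __iter__(self):
        for batch in self.batches:
            yield Experience(sequences=batch["sequences"], attention_mask=batch["attention_mask"])

# Call-site shape after the change:
batches = [{"sequences": torch.zeros(2, 8, dtype=torch.long), "attention_mask": torch.ones(2, 8)}]
for local_step, experience in enumerate(BatchIterator(batches)):
    experience.to_device("cpu")  # the PR uses torch.cuda.current_device()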