[Tinker] Refactor trainer and worker (to move algo to trainer and infra to worker) #859
Merged: tyler-griggs merged 22 commits into NovaSky-AI:main from tyler-griggs:tgriggs/rm_ppo_train on Jan 19, 2026.
Changes shown are from 12 of the 22 commits.

Commits (all by tyler-griggs):
- f5a64c4 refactor trainer worker to fwd/bwd and optim_step
- e63f0bd Merge remote-tracking branch 'real/main' into tgriggs/refac_worker
- 58221a1 format
- 0f4d118 fix ci, clean dead code
- a353fb0 optim_step cleanup
- ec96fd3 working on rm training step
- 8102e33 Merge remote-tracking branch 'real/main' into HEAD
- e855d41 Merge remote-tracking branch 'real/main' into tgriggs/rm_training_step
- 1959fed removing training step
- 6909e3b fix tests
- 4e79626 x
- 68e1ade pulling apart infra and alg logic
- 5e68c8a md file, megatron, dispatch test
- 5b04a24 Merge origin/main into tgriggs/rm_ppo_train
- 8e4a105 Fix formatting in worker files
- 6705e35 Remove ppo_train from FSDP workers, use gradient scaling
- bcf977b Fix test_save_load_checkpoint.py for new forward_backward API
- f6850f6 Restore mini-batch loop in _execute_training_step
- b757ed2 Fix GPU memory and API issues
- 4b30b61 fixes
- 9c35196 Refactor trainer to route all offload/backload through dispatch
- 35e5bf4 Remove project-summary.md from PR
New file added by this diff (hunk header `@@ -0,0 +1,314 @@`):

```python
"""
WorkerDispatch: Manages all actor groups with automatic offload/onload.

When colocate_all=True, automatically handles GPU placement:
- Offloads other models when one is requested
- Tracks which model is currently on GPU

The trainer interacts with this as if all models are always on GPU.
"""

import math
from typing import Dict, Optional, List
from dataclasses import dataclass

import ray
from omegaconf import DictConfig

from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
from skyrl_train.workers.worker import PPORayActorGroup
from skyrl_train.distributed.dispatch import concatenate_outputs_after_mesh_dispatch


@dataclass
class GPUState:
    """Tracks what's on GPU for a model."""

    model_on_gpu: bool = False
    optimizer_on_gpu: bool = False


class WorkerDispatch:
    """
    Unified dispatch layer that manages all actor groups (policy, critic, ref).

    Handles automatic offload/onload when colocate_all=True.
    """

    def __init__(
        self,
        cfg: DictConfig,
        policy_actor_group: PPORayActorGroup,
        critic_actor_group: Optional[PPORayActorGroup] = None,
        ref_actor_group: Optional[PPORayActorGroup] = None,
    ):
        self.cfg = cfg
        self.colocate_all = cfg.trainer.placement.colocate_all
        self.colocate_policy_ref = cfg.trainer.placement.colocate_policy_ref

        # Actor groups by name.
        # TODO: Remove these role-specific identifiers. We will move to using model IDs
        # and add support for generic models beyond these.
        self._actor_groups: Dict[str, PPORayActorGroup] = {"policy": policy_actor_group}
        if critic_actor_group is not None:
            self._actor_groups["critic"] = critic_actor_group
        if ref_actor_group is not None:
            self._actor_groups["ref"] = ref_actor_group

        # GPU state tracking (only matters when colocated)
        self._gpu_state: Dict[str, GPUState] = {
            name: GPUState() for name in self._actor_groups.keys()
        }

    # === Properties ===

    def get_lcm_dp_size(self) -> int:
        """Get LCM of all models' dp_size."""
        dp_size = self._actor_groups["policy"].actor_infos[0].rank.dp_size
        if "critic" in self._actor_groups:
            dp_size = math.lcm(dp_size, self._actor_groups["critic"].actor_infos[0].rank.dp_size)
        if "ref" in self._actor_groups:
            dp_size = math.lcm(dp_size, self._actor_groups["ref"].actor_infos[0].rank.dp_size)
        return dp_size

    # === GPU State Management ===

    def _should_manage_offload(self, model: str) -> bool:
        """Check if we need to manage offload for this model."""
        if self.colocate_all:
            return True
        if self.colocate_policy_ref and model in ("policy", "ref"):
            return True
        return False

    def _get_colocation_group(self, model: str) -> List[str]:
        """Get which models share a GPU with the given model."""
        if self.colocate_all:
            return list(self._actor_groups.keys())
        elif self.colocate_policy_ref and model in ("policy", "ref"):
            return [m for m in ["policy", "ref"] if m in self._actor_groups]
        return [model]

    def _ensure_on_gpu(self, model: str, need_optimizer: bool = True, need_model: bool = True) -> None:
        """Ensure model is on GPU, offloading others in the same colocation group if needed."""
        if not self._should_manage_offload(model):
            return

        if model not in self._actor_groups:
            return

        group = self._get_colocation_group(model)

        # Offload others in the same colocation group
        for other in group:
            if other != model and other in self._actor_groups:
                state = self._gpu_state[other]
                if state.model_on_gpu or state.optimizer_on_gpu:
                    self._actor_groups[other].offload_to_cpu()
                    self._gpu_state[other] = GPUState()

        # Backload requested model
        state = self._gpu_state[model]
        needs_backload = (need_model and not state.model_on_gpu) or (need_optimizer and not state.optimizer_on_gpu)

        if needs_backload:
            self._actor_groups[model].backload_to_gpu(
                backload_optimizer=need_optimizer,
                backload_model=need_model,
            )
            if need_model:
                self._gpu_state[model].model_on_gpu = True
            if need_optimizer:
                self._gpu_state[model].optimizer_on_gpu = True

    def _offload(self, model: str, offload_optimizer: bool = True, offload_model: bool = True) -> None:
        """Offload model to CPU."""
        if not self._should_manage_offload(model):
            return

        if model not in self._actor_groups:
            return

        self._actor_groups[model].offload_to_cpu(
            offload_optimizer=offload_optimizer,
            offload_model=offload_model,
        )

        if offload_model:
            self._gpu_state[model].model_on_gpu = False
        if offload_optimizer:
            self._gpu_state[model].optimizer_on_gpu = False

    def mark_all_offloaded(self) -> None:
        """Mark all models as offloaded (call after build_models when colocate_all)."""
        for model in self._actor_groups:
            self._gpu_state[model] = GPUState()

    # === Forward Pass (Inference) ===

    def forward(self, model: str, data: TrainingInputBatch) -> TrainingOutputBatch:
        """Run inference forward pass. Only loads model (not optimizer)."""
        self._ensure_on_gpu(model, need_optimizer=False, need_model=True)

        refs = self._actor_groups[model].async_run_ray_method("mesh", "forward", data=data)
        results = ray.get(refs)

        output = concatenate_outputs_after_mesh_dispatch(
            self._actor_groups[model].actor_infos, results
        )
        return output
```
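For orientation, here is a minimal sketch of how a trainer might drive the inference path; the function and variable names below are assumptions for illustration, not code from this PR:

```python
from skyrl_train.training_batch import TrainingInputBatch

# Hypothetical sketch (names assumed, not code from this PR).
# `dispatch` is a constructed WorkerDispatch; `batch` is prepared elsewhere.
def compute_ref_and_values(dispatch, batch: TrainingInputBatch):
    # Under colocate_all=True, each call first offloads the other colocated
    # models, then backloads the requested one (model only, no optimizer).
    ref_output = dispatch.forward("ref", batch)
    value_output = dispatch.forward("critic", batch)
    return ref_output, value_output
```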
The file continues with the training, checkpointing, and initialization paths:

```python
    # === Training ===

    def forward_backward(self, model: str, data: TrainingInputBatch) -> Dict[str, float]:
        """Run forward/backward pass. Needs model + optimizer."""
        self._ensure_on_gpu(model, need_optimizer=True, need_model=True)

        refs = self._actor_groups[model].async_run_ray_method("mesh", "forward_backward", data)
        statuses = ray.get(refs)

        self._save_memory_snapshot(model, "forward_backward")
        return statuses[0]

    def optim_step(self, model: str) -> Optional[float]:
        """Run optimizer step. Model should already be on GPU from forward_backward."""
        refs = self._actor_groups[model].async_run_ray_method("pass_through", "optim_step")
        grad_norms = ray.get(refs)

        self._save_memory_snapshot(model, "optim_step")
        return grad_norms[0]

    # TODO(tgriggs): Remove this when Megatron supports forward_backward and optim_step.
    def ppo_train(self, model: str, data: TrainingInputBatch) -> Dict[str, float]:
        """Run full PPO training loop (for Megatron)."""
        self._ensure_on_gpu(model, need_optimizer=True, need_model=True)

        refs = self._actor_groups[model].async_run_ray_method("mesh", "ppo_train", data)
        statuses = ray.get(refs)

        return statuses[0].metadata["train_status"]

    def _save_memory_snapshot(self, model: str, tag: str) -> None:
        """Save memory snapshot on workers."""
        ray.get(
            self._actor_groups[model].async_run_ray_method(
                "pass_through", "save_memory_snapshot", tag=f"{model}_{tag}"
            )
        )

    # === Checkpointing ===

    def save_checkpoint(self, model: str, ckpt_dir: str, tokenizer=None) -> None:
        """Save checkpoint for model."""
        self._ensure_on_gpu(model, need_optimizer=True, need_model=True)

        ray.get(
            self._actor_groups[model].async_run_ray_method(
                "pass_through", "save_checkpoint", ckpt_dir=ckpt_dir, tokenizer=tokenizer
            )
        )

    def load_checkpoint(
        self,
        model: str,
        ckpt_dir: str,
        load_optimizer_states: bool = True,
        load_lr_scheduler_states: bool = True,
    ) -> None:
        """Load checkpoint for model."""
        self._ensure_on_gpu(model, need_optimizer=load_optimizer_states, need_model=True)

        ray.get(
            self._actor_groups[model].async_run_ray_method(
                "pass_through",
                "load_checkpoint",
                ckpt_dir=ckpt_dir,
                load_optimizer_states=load_optimizer_states,
                load_lr_scheduler_states=load_lr_scheduler_states,
            )
        )

    def save_hf_model(self, model: str, export_dir: str, tokenizer) -> None:
        """Save model in HuggingFace format."""
        self._ensure_on_gpu(model, need_optimizer=False, need_model=True)

        ray.get(
            self._actor_groups[model].async_run_ray_method(
                "pass_through", "save_hf_model", export_dir, tokenizer
            )
        )

    # === Model Initialization ===

    def init_model(self, model: str, model_path: str, num_training_steps: Optional[int] = None) -> None:
        """Initialize model from path. Offloads others in colocation group first."""
        # Offload others in colocation group before init
        if self._should_manage_offload(model):
            group = self._get_colocation_group(model)
            for other in group:
                if other != model and other in self._actor_groups:
                    state = self._gpu_state[other]
                    if state.model_on_gpu or state.optimizer_on_gpu:
                        self._actor_groups[other].offload_to_cpu()
                        self._gpu_state[other] = GPUState()

        kwargs = {"model_path": model_path}
        if num_training_steps is not None:
            kwargs["num_training_steps"] = num_training_steps

        ray.get(self._actor_groups[model].async_init_model(**kwargs))

        # After init, model is on GPU
        self._gpu_state[model].model_on_gpu = True
        self._gpu_state[model].optimizer_on_gpu = (model != "ref")  # ref has no optimizer
```
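The forward_backward/optim_step pair above replaces the worker-side ppo_train loop on FSDP workers (per the commit messages, 6705e35 removes ppo_train from FSDP workers and f6850f6 restores the mini-batch loop in the trainer's _execute_training_step). A hedged sketch of what that trainer-side loop might look like; the function and variable names are assumed:

```python
from typing import List
from skyrl_train.training_batch import TrainingInputBatch

# Hypothetical sketch (names assumed, not code from this PR).
# Gradients accumulate on the workers across forward_backward calls;
# a single optim_step then applies them.
def train_step(dispatch, mini_batches: List[TrainingInputBatch]):
    for mini_batch in mini_batches:
        status = dispatch.forward_backward("policy", mini_batch)  # per-batch metrics
    grad_norm = dispatch.optim_step("policy")  # one optimizer step per train step
    return grad_norm
```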
And the weight-sync and utility paths:

```python
    # === Weight Sync ===

    def init_weight_sync_state(self, inference_engine_client) -> None:
        """Initialize weight sync state for policy model."""
        ray.get(
            self._actor_groups["policy"].async_run_ray_method(
                "pass_through", "init_weight_sync_state", inference_engine_client
            )
        )

    def broadcast_to_inference_engines(self, inference_engine_client) -> None:
        """Broadcast policy weights to inference engines."""
        ray.get(
            self._actor_groups["policy"].async_run_ray_method(
                "pass_through", "broadcast_to_inference_engines", inference_engine_client
            )
        )

    def prepare_for_weight_sync(self) -> None:
        """Prepare for weight sync: ensure policy model is on GPU, offload optimizer."""
        if not self.colocate_all:
            return
        # Ensure policy model is on GPU (will offload others in colocation group)
        self._ensure_on_gpu("policy", need_optimizer=False, need_model=True)
        # Offload optimizer if it's on GPU
        if self._gpu_state["policy"].optimizer_on_gpu:
            self._offload("policy", offload_optimizer=True, offload_model=False)

    def finish_weight_sync(self) -> None:
        """Finish weight sync: offload model."""
        if not self.colocate_all:
            return
        self._offload("policy", offload_optimizer=False, offload_model=True)

    # === Utility ===

    def empty_cache(self, model: Optional[str] = None) -> None:
        """Empty GPU cache for model(s)."""
        if model is not None:
            ray.get(self._actor_groups[model].async_run_ray_method("pass_through", "empty_cache"))
        else:
            refs = []
            for group in self._actor_groups.values():
                refs.extend(group.async_run_ray_method("pass_through", "empty_cache"))
            ray.get(refs)

    def get_node_ids(self) -> List[str]:
        """Get unique node IDs from all actor groups."""
        all_node_ids = []
        for group in self._actor_groups.values():
            node_ids = ray.get(group.async_run_ray_method("pass_through", "get_ray_node_id"))
            all_node_ids.extend(node_ids)
        return list(set(all_node_ids))
```
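Putting the weight-sync helpers together, a trainer would bracket the broadcast so that, under colocate_all, only the bare policy model sits on GPU during the sync. A minimal sketch with assumed names:

```python
# Hypothetical sketch (names assumed, not code from this PR).
# `inference_client` is an assumed handle to the inference engine client.
def sync_weights(dispatch, inference_client) -> None:
    dispatch.prepare_for_weight_sync()      # policy model on GPU, optimizer offloaded
    dispatch.broadcast_to_inference_engines(inference_client)
    dispatch.finish_weight_sync()           # policy model offloaded again
```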