Conversation

@tyler-griggs (Member)

PR Overview

This PR contains a fairly significant architectural refactor. The primary changes are:

  1. Move algorithm logic into the trainer and infra logic into the workers.

    • Remove ppo_train (from FSDP workers for now, @erictang000 to remove from Megatron)
    • Move micro-batching logic from the trainer to the worker (the worker doesn't know about mini batches, the trainer doesn't know about micro batches)
    • Worker onloading/offloading moved out of trainer
  2. Introduce WorkerDispatch: an intermediate layer between the training loop and the workers. Later, this is the entity that will sit beneath the Tinker API server. It manages a "pool" of workers: currently it dispatches requests to the pool (by mesh or by pass_through) and handles onloading / offloading workers. A rough sketch of this interface follows the list.
    i. I'm very open to renaming suggestions.
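
To make the split concrete, here is a minimal sketch of that layer (the routing argument, group attributes, and backload call are assumptions for illustration, not the PR's actual signatures):

# Illustrative sketch only: the routing argument, group attributes, and
# backload() method are assumptions, not the PR's real interfaces.
class WorkerDispatch:
    """Sits between the training loop and the pool of workers."""

    def __init__(self, policy_group, critic_group=None, ref_group=None):
        self._groups = {"policy": policy_group, "critic": critic_group, "ref": ref_group}

    def _prepare(self, name):
        # Backload the target actor group onto GPU before routing the call to it;
        # in the real dispatch this also covers offloading idle groups.
        group = self._groups[name]
        group.backload()
        return group

    def forward_backward(self, name, mini_batch):
        # Route a mini batch to the named actor group; the worker splits it into
        # micro batches internally.
        return self._prepare(name).forward_backward(mini_batch)

    def optim_step(self, name):
        # The worker scales its accumulated gradients and steps its optimizer.
        return self._prepare(name).optim_step()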


Architecture

Here's a diagram of what the breakdown of responsibilities now looks like (a short trainer-side sketch follows the diagram). Note that some of this will still need to change. E.g., the trainer will not call dispatch in the future; it will call into the Tinker API server.

┌─────────────────────────────────────────────────────────────┐
│                    TRAINER (Algorithm)                       │
│  - PPO algorithm implementation                              │
│  - Knows only mini batches                                   │
│  - Calls dispatch.forward_backward() + dispatch.optim_step() │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│                 WORKER DISPATCH (Coordination)               │
│  - Manages all actor groups (policy, critic, ref)           │
│  - Handles GPU state (offload/backload) automatically       │
│  - Routes calls to appropriate workers                      │
│  - Handles DP sharding via MeshDispatch                     │
└─────────────────────────────────────────────────────────────┘
                              │
                              ▼
┌─────────────────────────────────────────────────────────────┐
│                    WORKERS (Execution)                       │
│  - Execute forward/backward passes                          │
│  - Handle micro-batching internally                         │
│  - Scale gradients at optim_step                            │
│  - Model-specific implementations (FSDP, Megatron)          │
└─────────────────────────────────────────────────────────────┘
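
And a hedged sketch of the trainer side under this breakdown, assuming the same illustrative "policy" routing argument as above (mini_batches and metrics are placeholders):

# Hypothetical trainer-side loop: the trainer only iterates over mini batches and
# calls into dispatch; it never sees micro batches or GPU onload/offload.
for mini_batch in mini_batches:
    metrics = dispatch.forward_backward("policy", mini_batch)
    dispatch.optim_step("policy")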

Tests

Deleted Tests

  • test_ppo_train.py - tested the removed ppo_train method

Updated Tests

  • test_training_step.py - uses WorkerDispatch for policy tests
  • test_worker_offload.py - updated to work with new interfaces
  • test_save_load_checkpoint.py - updated imports
  • test_trainer.py - rewrote test_normalize_mini_batch_size

Gradient scaling

Since the worker no longer knows about mini batches, we scale gradients instead of the loss:

Old (scale loss during backward)

for i in 1..N:
    grad += (1/N) * ∂loss_i/∂param
optimizer.step(grad)

New (scale gradients at optim_step)

for i in 1..N:
    grad += ∂loss_i/∂param
grad *= 1/N
optimizer.step(grad)

Both produce: grad = (1/N) * Σ ∂loss_i/∂param
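
As a stand-alone PyTorch sketch of that equivalence (toy model, not the worker code in this PR):

import torch

# Compare scaling each micro-batch loss by 1/N vs. scaling the accumulated
# gradients once before the optimizer step.
torch.manual_seed(0)
micro_batches = [torch.randn(4, 3) for _ in range(4)]
N = len(micro_batches)

old = torch.nn.Linear(3, 1)          # old scheme: scale the loss during backward
new = torch.nn.Linear(3, 1)          # new scheme: scale grads at optim_step
new.load_state_dict(old.state_dict())

for x in micro_batches:
    (old(x).mean() / N).backward()   # grad += (1/N) * dloss_i/dparam

for x in micro_batches:
    new(x).mean().backward()         # grad += dloss_i/dparam
for p in new.parameters():
    p.grad.mul_(1.0 / N)             # grad *= 1/N

print(torch.allclose(old.weight.grad, new.weight.grad))  # True (up to fp rounding)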


What's next for training refactor?

  • Remove weight sync logic from the trainer; it should not be explicitly triggered by the trainer
  • Remove other “infra” calls from the trainer, such as empty_cache
  • Create separate entry points for a) launching workers and b) launching training. Currently trainer.py sets up the workers (build_models) and launches training
  • Update Megatron (@erictang000) to bring it to the same state instead of branching on which backend is used

lcm_dp_size = math.lcm(lcm_dp_size, self.critic_model.actor_infos[0].rank.dp_size)
if self.ref_model is not None:
    lcm_dp_size = math.lcm(lcm_dp_size, self.ref_model.actor_infos[0].rank.dp_size)
lcm_dp_size = self.dispatch.get_lcm_dp_size()
@tyler-griggs (Member, Author)

This also needs to be moved out of trainer. The trainer shouldn't know/care about dp size.
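
A hedged sketch of where this could live instead, reusing the actor_infos / rank.dp_size attributes the trainer reads today (the policy_model attribute name is an assumption):

import math

# Illustrative only: the LCM-of-DP-sizes computation moves behind WorkerDispatch so
# the trainer never touches dp_size. The policy_model attribute name is assumed.
class WorkerDispatch:
    ...
    def get_lcm_dp_size(self):
        lcm_dp_size = 1
        for model in (self.policy_model, self.critic_model, self.ref_model):
            if model is not None:
                lcm_dp_size = math.lcm(lcm_dp_size, model.actor_infos[0].rank.dp_size)
        return lcm_dp_size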

@erictang000 (Collaborator) left a comment

Looks pretty good to me! Will work on the Megatron worker update off of this branch.

Also calling out some actively ongoing work on dynamic batch sizing that touches the relevant codepaths:
#817
#847

Maybe the best thing to do would be to rebase #847 on top of the tinkerify branch? (And move the micro-batching logic updates from there down into forward_backward.)

scale = 1.0 / self._micro_batches_accumulated
for param in self.model.parameters():
    if param.grad is not None:
        param.grad.mul_(scale)
Collaborator
I get that this is cleaner: you don't have to worry about loss scaling during the forward pass (which requires knowing in advance how we break down the micro batches or how often we call optim_step).

But are you sure that this is exactly numerically equivalent and compatible with autocasting/mixed precision? Scaling the loss prior to backward seems cleaner in that sense; can we verify that this is numerically identical to scaling the loss in PyTorch (and that there aren't overflow edge cases here)?
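
One possible starting point for that check, sketched with CPU bf16 autocast for portability (a toy case, not a substitute for testing the real GPU mixed-precision path):

import torch

torch.manual_seed(0)
micro_batches = [torch.randn(4, 3) for _ in range(4)]
N = len(micro_batches)

def accumulate(model, scale_loss):
    # Forward under autocast; either scale the loss per micro batch (old scheme)
    # or scale the accumulated grads once at the end (new scheme).
    for x in micro_batches:
        with torch.autocast("cpu", dtype=torch.bfloat16):
            loss = model(x).float().mean()
        (loss / N if scale_loss else loss).backward()
    if not scale_loss:
        for p in model.parameters():
            p.grad.mul_(1.0 / N)

a = torch.nn.Linear(3, 1)
b = torch.nn.Linear(3, 1)
b.load_state_dict(a.state_dict())
accumulate(a, scale_loss=True)
accumulate(b, scale_loss=False)
print((a.weight.grad - b.weight.grad).abs().max())  # ~0 for this toy case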

@erictang000 (Collaborator) · Jan 13, 2026

Coming back to this now that I'm implementing this for Megatron: I think it overall seems cleaner and more flexible to handle loss scaling for gradient accumulation in forward_backward, where we could just set the default to weight each micro batch evenly and scale the loss for each micro batch by 1/num_microbatches. We could even provide an optional forward_backward parameter that specifies a loss scaling term (default 1), so that users could flexibly scale the loss differently for different mini batches (e.g., if you have 2 mini batches of different sizes, you might want to make sure the 2nd mini batch's loss is scaled to its size).

def _forward_backward_micro(self, experience, microbatch_weight):
    ...
    loss = loss * microbatch_weight
    loss.backward()
    ...

def forward_backward(self, data, loss_weight=1.0):
    micro_batch_iterator = BatchIterator(data, micro_batch_size, drop_last=False)
    for micro_batch in micro_batch_iterator:
        metrics = self._forward_backward_micro(micro_batch, microbatch_weight=loss_weight / len(micro_batch_iterator))
    ...

This would still be compatible with the upstream Tinker API, which doesn't let you specify a loss weight.

wdyt? @tyler-griggs

else:
    base_log_probs = None
# Critic forward (dispatch handles offload/backload automatically)
if self.has_critic:
Collaborator

One reason this logic was so messy before is that you could technically overlap the forward passes for the critic/ref/policy, which is something you lose here. This probably wasn't obvious or used much in most of our configs since everything was colocated, but it could matter more at scale.

Once we make the API async (and handle a request queue on the server side) we could add that functionality back. Do you think that's worth adding a TODO?
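
For what it's worth, a toy asyncio sketch of what that overlap could look like once the API is async (the coroutines are stand-ins, not the real dispatch methods):

import asyncio

# Stand-in coroutines simulating remote forward calls; with an async dispatch the
# policy/ref/critic forwards could be issued concurrently instead of sequentially.
async def policy_forward(batch):
    await asyncio.sleep(0.1)   # pretend remote forward latency
    return "action_log_probs"

async def ref_forward(batch):
    await asyncio.sleep(0.1)
    return "base_log_probs"

async def critic_forward(batch):
    await asyncio.sleep(0.1)
    return "values"

async def overlapped_forwards(batch):
    # All three forwards run concurrently; total wall time ≈ max, not sum.
    return await asyncio.gather(
        policy_forward(batch), ref_forward(batch), critic_forward(batch)
    )

print(asyncio.run(overlapped_forwards(batch={})))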
