[Tinker] Refactor trainer and worker (to move algo to trainer and infra to worker) #859
base: tinkerify
Conversation
```python
lcm_dp_size = math.lcm(lcm_dp_size, self.critic_model.actor_infos[0].rank.dp_size)
if self.ref_model is not None:
    lcm_dp_size = math.lcm(lcm_dp_size, self.ref_model.actor_infos[0].rank.dp_size)
lcm_dp_size = self.dispatch.get_lcm_dp_size()
```
This also needs to be moved out of the trainer. The trainer shouldn't know/care about dp size.
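For reference, a minimal sketch of what the dispatch-side computation could look like if the logic moves over essentially unchanged (the free-function form and parameter names are illustrative; only the dp-size lookup mirrors the removed trainer code):

```python
import math


def get_lcm_dp_size(policy_model, critic_model=None, ref_model=None) -> int:
    # Least common multiple of the dp sizes of all active models,
    # mirroring the logic removed from the trainer.
    lcm_dp_size = policy_model.actor_infos[0].rank.dp_size
    for model in (critic_model, ref_model):
        if model is not None:
            lcm_dp_size = math.lcm(lcm_dp_size, model.actor_infos[0].rank.dp_size)
    return lcm_dp_size
```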
erictang000 left a comment
looks pretty good to me! will work on the megatron worker update off of this branch
Also calling out some actively ongoing work on dynamic batch sizing that touches the relevant codepaths:
#817
#847
maybe the best thing to do would be to rebase #847 on top of the tinkerify branch? (and move the micro batching logic updates from there down into forward_backward)
```python
scale = 1.0 / self._micro_batches_accumulated
for param in self.model.parameters():
    if param.grad is not None:
        param.grad.mul_(scale)
```
I get that this is cleaner because the forward pass doesn't have to worry about loss scaling (which requires knowing in advance how we break down the micro batches or how often we call optim_step).
But are you sure that this is exactly numerically equivalent and compatible with autocasting/mixed precision? Scaling the loss prior to backward seems cleaner in that sense; can we verify that this is numerically identical to scaling the loss in PyTorch (and that there aren't overflow edge cases here)?
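One way to sanity-check the equivalence, as a standalone toy sketch (none of these names come from the codebase; under autocast/fp16 a GradScaler would also sit in the loop, which this doesn't cover):

```python
import copy
import torch

model_a = torch.nn.Linear(4, 1)
model_b = copy.deepcopy(model_a)
micro_batches = [torch.randn(3, 4) for _ in range(4)]
n = len(micro_batches)

# Scheme A: scale each micro-batch loss before backward.
for x in micro_batches:
    (model_a(x).square().mean() / n).backward()

# Scheme B: accumulate unscaled losses, then rescale gradients once.
for x in micro_batches:
    model_b(x).square().mean().backward()
for p in model_b.parameters():
    p.grad.mul_(1.0 / n)

# Bitwise equality isn't guaranteed (the scaling happens at different points),
# but in fp32 the difference is at the level of rounding noise.
for pa, pb in zip(model_a.parameters(), model_b.parameters()):
    assert torch.allclose(pa.grad, pb.grad, atol=1e-6)
```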
Coming back to this now that I'm implementing this for Megatron: I think it's overall cleaner and more flexible to handle loss scaling for gradient accumulation in forward_backward, where we could just set the default to weight each micro batch evenly and scale the loss for each micro batch by (1/num_microbatches). We could even theoretically provide an optional parameter for forward_backward that specifies a loss scaling term (default 1), so that users could flexibly scale the loss for different mini-batches (e.g. if you have 2 mini batches of different sizes, you might want the 2nd mini batch's loss scaled in proportion to its size).
```python
def _forward_backward_micro(self, experience, microbatch_weight):
    ...
    loss = loss * microbatch_weight
    loss.backward()
    ...

def forward_backward(self, data, loss_weight=1.0):
    micro_batch_iterator = BatchIterator(data, micro_batch_size, drop_last=False)
    for micro_batch in micro_batch_iterator:
        metrics = self._forward_backward_micro(
            micro_batch, microbatch_weight=loss_weight / len(micro_batch_iterator)
        )
    ...
```

This would still be compatible with the upstream tinker API, which doesn't let you specify a loss weight.
wdyt? @tyler-griggs
```python
else:
    base_log_probs = None
# Critic forward (dispatch handles offload/backload automatically)
if self.has_critic:
```
One reason this logic was so messy before was that you could technically overlap the forward passes for critic/ref/policy, which is something you lose here. This probably wasn't obvious or used very much in most of our configs since everything was colocated, but it could matter more at scale.
Once we make the API async (and handle request queueing on the server side) we could add that functionality back. Do you think that's worth adding a TODO?
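For illustration only (the method names here are placeholders, not the current API), the async version could overlap the three forwards roughly like this:

```python
import asyncio


async def forward_all(dispatch, batch, has_critic, has_ref):
    # Kick off policy/ref/critic forwards concurrently instead of serially;
    # with colocated workers this changes little, but disaggregated pools
    # could genuinely overlap the passes.
    tasks = [dispatch.compute_log_probs(batch)]
    if has_ref:
        tasks.append(dispatch.compute_ref_log_probs(batch))
    if has_critic:
        tasks.append(dispatch.compute_values(batch))
    return await asyncio.gather(*tasks)
```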
PR Overview
This PR contains a pretty significant architectural refactor. The primary changes are:
1. Move algorithm logic into the trainer and infra logic into the workers.
   i. Remove `ppo_train` (from FSDP workers for now, @erictang000 to remove from Megatron).
2. Introduce `WorkerDispatch`: an intermediate layer between the training loop and workers. Later, this is the entity that will sit beneath the Tinker API server. It handles a "pool" of workers: currently it dispatches requests to the pool (by mesh or by pass_through) and handles onloading / offloading workers. See the sketch after this list.
   i. I'm very open to renaming suggestions.
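As a concrete illustration of that description (a sketch under assumed names and signatures, not the PR's actual interface):

```python
def split_batch(data, num_shards):
    # Hypothetical helper: even split of a list-like batch across workers.
    shard_size = (len(data) + num_shards - 1) // num_shards
    return [data[i * shard_size:(i + 1) * shard_size] for i in range(num_shards)]


class WorkerDispatch:
    """Sketch only: sits between the training loop (later the Tinker API
    server) and a pool of workers, so the trainer never reasons about
    placement, sharding, or offload directly."""

    def __init__(self, worker_pools):
        # e.g. {"policy": [...], "critic": [...], "ref": [...]}
        self.worker_pools = worker_pools

    def dispatch(self, model_name, method, data=None, by_mesh=True):
        workers = self.worker_pools[model_name]
        self._backload(workers)  # onload weights/optimizer state if offloaded
        if by_mesh and data is not None:
            # Shard the batch across the data-parallel mesh, one shard per worker.
            shards = split_batch(data, len(workers))
            results = [getattr(w, method)(s) for w, s in zip(workers, shards)]
        else:
            # pass_through: send the identical call to every worker.
            results = [getattr(w, method)(data) for w in workers]
        self._offload(workers)  # free GPU memory for whatever runs next
        return results

    def _backload(self, workers):
        for w in workers:
            w.backload_to_gpu()  # placeholder worker method name

    def _offload(self, workers):
        for w in workers:
            w.offload_to_cpu()  # placeholder worker method name
```

The point of the layer is that the trainer (and later the Tinker API server) only issues model-level calls; everything about worker placement and memory stays behind `dispatch`.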
Architecture
Here's a diagram of what the breakdown of responsibilities now looks like. Note that some of this will still need to change. E.g., the trainer will not call `dispatch` in the future; it will call into the Tinker API server.
Tests
Deleted Tests
- `test_ppo_train.py` - tested the removed `ppo_train` method
Updated Tests
- `test_training_step.py` - uses `WorkerDispatch` for policy tests
- `test_worker_offload.py` - updated to work with new interfaces
- `test_save_load_checkpoint.py` - updated imports
- `test_trainer.py` - rewrote `test_normalize_mini_batch_size`
Gradient scaling
Since the worker no longer knows about mini batches, we scale gradients instead of the loss:
Old (scale loss during backward)
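A minimal sketch of the old scheme (reconstructed, not copied from the diff; the loss function, batch format, and names are assumptions):

```python
import torch


def train_mini_batch_old(model, optimizer, micro_batches):
    # Scale each micro-batch loss by 1/N before backward, then step once.
    optimizer.zero_grad()
    n = len(micro_batches)
    for inputs, targets in micro_batches:
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        (loss / n).backward()
    optimizer.step()
```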
New (scale gradients at optim_step)
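And a sketch of the new scheme, mirroring the `optim_step` hunk quoted in the review above (same illustrative names):

```python
import torch


def train_mini_batch_new(model, optimizer, micro_batches):
    # Backward on the raw loss per micro-batch, count how many were
    # accumulated, rescale the gradients once, then step.
    optimizer.zero_grad()
    micro_batches_accumulated = 0
    for inputs, targets in micro_batches:
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        micro_batches_accumulated += 1
    scale = 1.0 / micro_batches_accumulated
    for param in model.parameters():
        if param.grad is not None:
            param.grad.mul_(scale)
    optimizer.step()
```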
Both produce:
grad = (1/N) * Σ ∂loss_i/∂param
What's next for training refactor?