Off-by-1 GRPO #140
@@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import asyncio
import logging
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Optional

@@ -21,12 +21,11 @@
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from torch import nn
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM, push_state_dict
from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def compute_logprobs(
    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
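The hunk above only shows the signature of compute_logprobs. For context, a helper like this typically looks like the sketch below: the logit at position t scores the token at position t + 1, so the logits have to be shifted by one relative to input_ids, which is presumably the off-by-one the PR title refers to. The exact slicing in this repo may differ.

    import torch

    def compute_logprobs(
        logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
    ) -> torch.Tensor:
        # Assumes logits cover prompt + response while input_ids hold only the
        # response tokens (an assumption about the calling convention here).
        context_length = logits.shape[1] - input_ids.shape[1]
        # Shift by one: the logit at position t - 1 is the prediction for token t.
        shifted_logits = logits[:, context_length - 1 : -1, :]
        logprobs = torch.log_softmax(shifted_logits / temperature, dim=-1)
        # Keep the log-probability of each token that was actually sampled.
        return torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)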
@@ -121,7 +120,7 @@ def new_group(
    target: Any = None,
):
    episodes = []
    for i in range(group_size):
    for _ in range(group_size):
        episodes.append(
            Episode(
                episode_id=str(uuid.uuid4()),

@@ -145,6 +144,8 @@ class Trainer(ForgeActor):
    beta: float = 0.1
    epsilon: float = 0.1
    device: torch.device | None = None
    store: MultiProcessStore | None = None
    state_dict_key: str = "model_state_dict"

    @endpoint
    def setup(self):
@@ -208,11 +209,19 @@ async def train_step(self, batch: list[Episode]):

        self.optimizer.step()

        return {"loss": loss.item()}
        return loss.item()

    @endpoint
    async def push_weights(self):
        pass
    async def push_weights(self, version: int):
        """Update policy model weights with trainer's current weights."""
        start_time = time.time()
Suggested change:

    start_time = time.time()
    # TODO - issues/148 followup
    start_time = time.time()
Just for my future reference, tagging some pieces for observability
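Only the first line of the new push_weights body is visible in this view. Going by the imports the PR adds (push_state_dict, DELIM) and the new store / state_dict_key fields on Trainer, the body presumably continues roughly as in the sketch below; the push_state_dict signature and the key layout are assumptions rather than something taken from the diff.

    # Hypothetical continuation of Trainer.push_weights (not taken from the diff):
    @endpoint
    async def push_weights(self, version: int):
        """Update policy model weights with trainer's current weights."""
        start_time = time.time()
        assert self.store is not None, "Trainer needs a torchstore handle to push weights"
        # Assumed key layout: "<state_dict_key><DELIM><version>" so the policy
        # can fetch the weights for a specific version.
        key = f"{self.state_dict_key}{DELIM}{version}"
        # Assumed signature: push_state_dict(store, state_dict, prefix).
        await push_state_dict(self.store, self.model.state_dict(), key)
        logger.debug(f"Pushed weights for v{version} in {time.time() - start_time:.2f}s")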
@LucasLLC Is this still the recommended way of doing things?
Seems like we are using ts.initialize() now and there is a global singleton torchstore. But I will let @LucasLLC weigh in.
We'll throw away a lot of data this way for fully on-policy training.
For short responses, yeah, definitely: if you look at the WandB logs (buffer_size/rollout), you can see that we build up a buffer of about 100 episodes and then evict the majority of them back and forth during weight updates.
When we start allowing much longer generations and our models are much bigger, this won't be as big of an issue.
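The behavior described here, filling a buffer during rollout and then evicting episodes that were generated by an older policy whenever the weights are updated, can be pictured with the small sketch below. The class and method names are hypothetical stand-ins, not the repo's actual buffer.

    from dataclasses import dataclass, field

    @dataclass
    class BufferedEpisode:
        # Hypothetical stand-in for Episode; only the fields needed for eviction.
        episode_id: str
        policy_version: int

    @dataclass
    class EvictingBuffer:
        episodes: list[BufferedEpisode] = field(default_factory=list)

        def add(self, episode: BufferedEpisode) -> None:
            self.episodes.append(episode)

        def evict_stale(self, current_version: int) -> int:
            # Fully on-policy: keep only episodes from the current policy version
            # and throw the rest away. Returns how many episodes were evicted.
            before = len(self.episodes)
            self.episodes = [
                e for e in self.episodes if e.policy_version == current_version
            ]
            return before - len(self.episodes)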
This should also be a call().
This should also be call() technically, even though choose() works since replicas=1.
> This should also be call() technically, even though choose() works since replicas=1.

As a side note, even if we do call(), what we are doing here is all the trainers training on the same batch, right? The replicas are just for fault tolerance? In this regard, if we want different trainers to train on different batches, the trainers themselves have to pull the batches, right?

An alternative way to do this is to split the batch into microbatches and then call choose() on each microbatch. After the whole batch is done, we then do an all_reduce (or other forms of reduce) to average the weights.
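For reference, the reduce step described above (averaging trainer weights once every trainer has finished its microbatches) would look roughly like the sketch below, assuming a torch.distributed process group is already initialized across the trainer replicas. Whether averaging weights, rather than gradients, is the right reduction for GRPO is exactly the design question being raised here.

    import torch
    import torch.distributed as dist
    from torch import nn

    def average_weights_across_trainers(model: nn.Module) -> None:
        # Sum every parameter across all trainer ranks, then divide by the
        # world size to get the elementwise average (mutates the model in place).
        world_size = dist.get_world_size()
        with torch.no_grad():
            for param in model.parameters():
                dist.all_reduce(param.data, op=dist.ReduceOp.SUM)
                param.data.div_(world_size)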
This isn't the currently recommended way to use the store. You should just use it as a singleton inside of the trainer.
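For context, the singleton pattern being recommended would look roughly like the fragment below inside the trainer. ts.initialize() is the call mentioned earlier in this thread; the put call and key layout are assumptions about the torchstore API, not something confirmed here.

    import torchstore as ts

    class TrainerStoreSketch:
        """Hypothetical fragment showing the singleton usage, not the repo's Trainer."""

        async def setup(self) -> None:
            # Initialize the process-global torchstore singleton once, instead of
            # passing a MultiProcessStore handle into the Trainer dataclass.
            await ts.initialize()

        async def push_state(self, key: str, state_dict: dict) -> None:
            # Assumed API: a keyed async put against the global store; the real
            # torchstore call may differ.
            await ts.put(key, state_dict)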