Flatten GRPO main: Group and Episode #400
@@ -28,6 +28,7 @@
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.data.rewards import MathReward, ThinkingReward
+from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
@@ -41,73 +42,51 @@
 @dataclass
 class Episode:
     # TODO: add adtional layer for multi-turn
     episode_id: str
-    request: str
-    policy_version: int
     pad_id: int
     request_len: int
     response_len: int
     target: Any | None = None
-    # processed data
-    response: str | None = None
-    request_tokens: list[int] | None = None
-    response_tokens: list[int] | None = None
+    # Processed data
+    completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
     advantage: float | None = None

     @property
-    def request_tensor(self):
-        tensor = torch.tensor(self.request_tokens, dtype=torch.long)
+    def policy_version(self) -> int | None:
+        return self.completion.generator_version
+
+    @property
+    def request_tensor(self) -> torch.Tensor:
+        request_tokens: torch.Tensor = self.completion.prompt_ids
+        tensor = torch.tensor(request_tokens, dtype=torch.long)
         if tensor.shape[0] < self.request_len:  # left pad
             diff = self.request_len - tensor.shape[0]
             tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
         return tensor

     @property
-    def response_tensor(self):
-        tensor = torch.tensor(self.response_tokens, dtype=torch.long)
+    def response_tensor(self) -> torch.Tensor:
+        response_tokens: torch.Tensor = self.completion.token_ids
+        tensor = torch.tensor(response_tokens, dtype=torch.long)
         if tensor.shape[0] < self.response_len:  # right pad
             diff = self.response_len - tensor.shape[0]
             tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor


-@dataclass
-class Group:
-    group_id: str
-    episodes: list[Episode]
-
-    @classmethod
-    def new_group(
-        cls,
-        group_id: int,
-        group_size: int,
-        request: str,
-        policy_version: int,
-        pad_id: int,
-        request_len: int,
-        response_len: int,
-        target: Any = None,
-    ):
-        episodes = []
-        for _ in range(group_size):
-            episodes.append(
-                Episode(
-                    episode_id=str(uuid.uuid4()),
-                    request=request,
-                    policy_version=policy_version,
-                    pad_id=pad_id,
-                    request_len=request_len,
-                    response_len=response_len,
-                    target=target,
-                )
-            )
-        return cls(str(group_id), episodes)
+# Represents the group (G) of episodes in GRPO
+Group = list[Episode]


-def collate(batches: list[list[Episode]]):
+def collate(
+    batches: list[Group],
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
     """
     Collates a list of batches into a single batch of inputs and targets.
     Each batch is a list of episodes, and each episode is a dict of tensors.
     """
     inputs = []
     targets = []
     for batch in batches:
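As an aside, a minimal sketch of how an Episode's left-padded request_tensor and right-padded response_tensor line up into one row of the input_ids buffer built later in continuous_rollouts. The sizes and token ids here are toy values for illustration, not values from this PR:

```python
import torch
import torch.nn.functional as F

# Hypothetical values; in the PR these come from the config and the Completion.
pad_id, request_len, response_len = 0, 6, 4

prompt_ids = torch.tensor([5, 7, 9])   # stands in for completion.prompt_ids
token_ids = torch.tensor([11, 13])     # stands in for completion.token_ids

# Mirrors Episode.request_tensor (left pad) and Episode.response_tensor (right pad)
request = F.pad(prompt_ids, (request_len - len(prompt_ids), 0), value=pad_id)
response = F.pad(token_ids, (0, response_len - len(token_ids)), value=pad_id)

row = torch.cat([request, response])
print(row)  # tensor([ 0,  0,  0,  5,  7,  9, 11, 13,  0,  0])
```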
@@ -221,7 +200,7 @@ class ComputeAdvantages(ForgeActor):
     @endpoint
     async def compute(self, group: Group) -> list[float]:
         # TODO: add batch processing
-        rewards = torch.tensor([[e.reward for e in group.episodes]])
+        rewards = torch.tensor([[e.reward for e in group]])
         mean = rewards.mean(1, keepdim=True)
         std = rewards.std(1, keepdim=True)
         advantages = (rewards - mean) / (std + 1e-4)
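A small worked example of the normalization in compute() above, with made-up rewards for a group of four episodes: each advantage is the episode reward's z-score within its own group.

```python
import torch

# Toy rewards for one group of 4 episodes (hypothetical values).
rewards = torch.tensor([[0.0, 1.0, 1.0, 0.0]])
mean = rewards.mean(1, keepdim=True)   # 0.5
std = rewards.std(1, keepdim=True)     # ~0.577 (unbiased over the group)
advantages = (rewards - mean) / (std + 1e-4)
print(advantages)  # approximately [[-0.866, 0.866, 0.866, -0.866]]
```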
@@ -386,44 +365,32 @@ async def continuous_rollouts():
             t.step("data_loading")

             prompt, target = sample["request"], sample["target"]
-            responses = await policy.generate.route(prompt)
-            # TODO: this shall be part of the responses metadata instead of a separate call
-            version = await policy.get_version.route()
+            responses: list[Completion] = await policy.generate.route(prompt)

             t.step("policy_generation")

+            assert (
+                len(responses) > 0
+            ), "Sanity check: Responses should NEVER return empty"
+            assert (
+                version := responses[0].generator_version
+            ) is not None, "Response must indicate a version"
-            group = Group.new_group(
-                group_id=rollout_count,
-                group_size=group_size,
-                request=prompt,
-                policy_version=version,
-                pad_id=pad_id,
-                request_len=max_req_tokens,
-                response_len=max_res_tokens,
-                target=target,
-            )

+            # Construct episodes and calculate rewards
+            episodes = []
             input_ids = torch.ones(
                 (group_size, max_req_tokens + max_res_tokens),
                 dtype=torch.long,
                 device="cuda",
             )
-            # Populate episode info and calculate rewards
-            for i, (episode, response) in enumerate(zip(group.episodes, responses)):
-                episode.request_tokens = response.prompt_ids
-                episode.response_tokens = response.token_ids
-                episode.response = response.text
-                input_ids[i, :max_req_tokens] = episode.request_tensor
-                input_ids[i, max_req_tokens:] = episode.response_tensor
+            for i, response in enumerate(responses):
+                episode = Episode(
+                    episode_id=str(uuid.uuid4()),
+                    pad_id=pad_id,
+                    request_len=max_req_tokens,
+                    response_len=max_res_tokens,
+                    target=target,
+                    completion=response,
+                )
                 episode.reward = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
+                episodes.append(episode)
+
+                # Build input_ids for reference logprobs
+                input_ids[i, :max_req_tokens] = episode.request_tensor
+                input_ids[i, max_req_tokens:] = episode.response_tensor

             t.step("reward_evaluation")
@@ -432,14 +399,13 @@ async def continuous_rollouts():
             )
             t.step("reference_model_calculate_logprobs")

-            for i, episode in enumerate(group.episodes):
+            for i, episode in enumerate(episodes):
                 episode.ref_logprobs = ref_logprobs[i]
             del ref_logprobs, input_ids
             t.step("compute_logprobs")

             # Calculate advantages and add to replay buffer
-            advantages = await compute_advantages.compute.call_one(group)
-            for episode, advantage in zip(group.episodes, advantages):
+            advantages = await compute_advantages.compute.call_one(episodes)
+            for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)

Review discussion on the compute_advantages.compute.call_one change:

- Not related to this diff but now since we're scrutinizing the main flow again, I think making
  I wonder, if for now it should be just inlined in the
- @JenniferWang these are good points. I want to propose an idea (not for you to implement @Jack-Khuu just brainstorming if this makes sense)
- Chained calls would be cool 👀
- This looks legit; +1 on chained calls
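The inlining comment above is truncated, but one plausible reading is computing advantages directly in the rollout loop rather than routing through the ComputeAdvantages actor. A rough sketch of that idea, reusing the same normalization as compute(); this helper is hypothetical and not part of the PR:

```python
import torch


def assign_advantages(episodes) -> None:
    """Hypothetical inline version of ComputeAdvantages.compute: normalize
    rewards within the group and write each advantage back onto its Episode."""
    rewards = torch.tensor([[e.reward for e in episodes]])
    mean = rewards.mean(1, keepdim=True)
    std = rewards.std(1, keepdim=True)
    advantages = (rewards - mean) / (std + 1e-4)
    for episode, advantage in zip(episodes, advantages.squeeze(0).tolist()):
        episode.advantage = advantage
```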