Flatten GRPO main: Group and Episode #400
@@ -28,6 +28,7 @@
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.data.rewards import MathReward, ThinkingReward
+from forge.data_models.completion import Completion
 from forge.env import MONARCH_HOSTMESH_V1
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
@@ -42,73 +43,43 @@

 @dataclass
 class Episode:
     # TODO: add adtional layer for multi-turn
     episode_id: str
-    request: str
-    policy_version: int
     pad_id: int
     request_len: int
     response_len: int
     target: Any | None = None
-    # processed data
-    response: str | None = None
-    request_tokens: list[int] | None = None
-    response_tokens: list[int] | None = None
+    # Processed data
+    completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
     advantage: float | None = None

     @property
-    def request_tensor(self):
-        tensor = torch.tensor(self.request_tokens, dtype=torch.long)
+    def policy_version(self) -> int | None:
+        return getattr(self.completion, "generator_version", None)
+
+    @property
+    def request_tensor(self) -> torch.Tensor:
+        request_tokens: torch.Tensor = self.completion.prompt_ids
+        tensor = torch.tensor(request_tokens, dtype=torch.long)
         if tensor.shape[0] < self.request_len:  # left pad
             diff = self.request_len - tensor.shape[0]
             tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
         return tensor

     @property
-    def response_tensor(self):
-        tensor = torch.tensor(self.response_tokens, dtype=torch.long)
+    def response_tensor(self) -> torch.Tensor:
+        response_tokens: torch.Tensor = self.completion.token_ids
+        tensor = torch.tensor(response_tokens, dtype=torch.long)
         if tensor.shape[0] < self.response_len:  # right pad
             diff = self.response_len - tensor.shape[0]
             tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor


-@dataclass
-class Group:
-    group_id: str
-    episodes: list[Episode]
-
-    @classmethod
-    def new_group(
-        cls,
-        group_id: int,
-        group_size: int,
-        request: str,
-        policy_version: int,
-        pad_id: int,
-        request_len: int,
-        response_len: int,
-        target: Any = None,
-    ):
-        episodes = []
-        for _ in range(group_size):
-            episodes.append(
-                Episode(
-                    episode_id=str(uuid.uuid4()),
-                    request=request,
-                    policy_version=policy_version,
-                    pad_id=pad_id,
-                    request_len=request_len,
-                    response_len=response_len,
-                    target=target,
-                )
-            )
-        return cls(str(group_id), episodes)
-
-
-def collate(batches: list[list[Episode]]):
+def collate(
+    batches: list[list[Episode]],
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
     inputs = []
     targets = []
     for batch in batches:
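Aside (not part of the diff): the two tensor properties pad in opposite directions so that, once a request row and a response row are concatenated, the prompt/response boundary always falls at index request_len; the rollout loop later in this PR relies on that when it fills input_ids. A minimal standalone sketch of the same padding behavior, using made-up token IDs and lengths:

import torch
import torch.nn.functional as F

pad_id = 0
request_len, response_len = 6, 4

# Hypothetical token IDs standing in for Completion.prompt_ids / Completion.token_ids
prompt_ids = torch.tensor([11, 12, 13], dtype=torch.long)
token_ids = torch.tensor([21, 22], dtype=torch.long)

# Left-pad the request so it ends exactly at index request_len
request = F.pad(prompt_ids, (request_len - prompt_ids.shape[0], 0), value=pad_id)
# Right-pad the response so it starts exactly at index request_len
response = F.pad(token_ids, (0, response_len - token_ids.shape[0]), value=pad_id)

print(torch.cat([request, response]))
# tensor([ 0,  0,  0, 11, 12, 13, 21, 22,  0,  0])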
@@ -220,9 +191,9 @@ class ComputeAdvantages(ForgeActor):
     """Compute advantages for GRPO using reward signals."""

     @endpoint
-    async def compute(self, group: Group) -> list[float]:
+    async def compute(self, episodes: list[Episode]) -> list[float]:
         # TODO: add batch processing
-        rewards = torch.tensor([[e.reward for e in group.episodes]])
+        rewards = torch.tensor([[e.reward for e in episodes]])
         mean = rewards.mean(1, keepdim=True)
         std = rewards.std(1, keepdim=True)
         advantages = (rewards - mean) / (std + 1e-4)
@@ -387,68 +358,61 @@ async def continuous_rollouts():
     while not shutdown_event.is_set():
         t = Tracer("main_perf/continuous_rollouts")
         t.start()
-        sample = await dataloader.sample.call_one()
+        sample: dict[str, str] = await dataloader.sample.call_one()
         if sample is None:
             print("Dataloader is empty, exiting continuous rollout")
             return

         t.step("data_loading")

         prompt, target = sample["request"], sample["target"]
-        responses = await policy.generate.route(prompt)
-        # TODO: this shall be part of the responses metadata instead of a separate call
-        version = await policy.get_version.route()
-
+        responses: list[Completion] = await policy.generate.route(prompt)
         t.step("policy_generation")

+        assert (
+            len(responses) > 0
+        ), "Sanity check: Responses should NEVER return empty"
+        assert (
+            version := responses[0].generator_version
+        ) is not None, "Response must indicate a version"
-        group = Group.new_group(
-            group_id=rollout_count,
-            group_size=group_size,
-            request=prompt,
-            policy_version=version,
-            pad_id=pad_id,
-            request_len=max_req_tokens,
-            response_len=max_res_tokens,
-            target=target,
-        )

+        # Construct episodes and calculate rewards
+        episodes: list[Episode] = []
         input_ids = torch.ones(
             (group_size, max_req_tokens + max_res_tokens),
             dtype=torch.long,
             device="cuda",
         )
-        # Populate episode info and calculate rewards
-        for i, (episode, response) in enumerate(zip(group.episodes, responses)):
-            episode.request_tokens = response.prompt_ids
-            episode.response_tokens = response.token_ids
-            episode.response = response.text
-            input_ids[i, :max_req_tokens] = episode.request_tensor
-            input_ids[i, max_req_tokens:] = episode.response_tensor
+        for i, response in enumerate(responses):
+            episode = Episode(
+                episode_id=str(uuid.uuid4()),
+                pad_id=pad_id,
+                request_len=max_req_tokens,
+                response_len=max_res_tokens,
+                target=target,
+                completion=response,
+            )
             episode.reward = await reward_actor.evaluate_response.route(
                 prompt=prompt, response=response.text, target=target
             )
+            episodes.append(episode)
+
+            # Build input_ids for reference logprobs
+            input_ids[i, :max_req_tokens] = episode.request_tensor
+            input_ids[i, max_req_tokens:] = episode.response_tensor

         t.step("reward_evaluation")

-        ref_logprobs = await ref_model.forward.route(
+        ref_logprobs: torch.Tensor = await ref_model.forward.route(
             input_ids, max_req_tokens, return_logprobs=True
         )
         t.step("reference_model_calculate_logprobs")

-        for i, episode in enumerate(group.episodes):
+        for i, episode in enumerate(episodes):
             episode.ref_logprobs = ref_logprobs[i]
         del ref_logprobs, input_ids
-        t.step("compute_logprobs")
+        t.step("gc_ref_logprobs")

         # Calculate advantages and add to replay buffer
-        advantages = await compute_advantages.compute.call_one(group)
-        for episode, advantage in zip(group.episodes, advantages):
+        advantages = await compute_advantages.compute.call_one(episodes)
+        for episode, advantage in zip(episodes, advantages):
             episode.advantage = advantage
             await replay_buffer.add.call_one(episode)

Review discussion on the compute_advantages.compute.call_one(episodes) line:

"Not related to this diff but now since we're scrutinizing the main flow again, I think making […] I wonder, if for now it should be just inlined in the […]"

"@JenniferWang these are good points. I want to propose an idea (not for you to implement @Jack-Khuu, just brainstorming if this makes sense) […]"

"Chained calls would be cool 👀"

"This looks legit; +1 on chained calls"
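To make the reviewers' suggestion concrete: because ComputeAdvantages.compute is a pure function of the episode rewards, it could be inlined into continuous_rollouts as a local helper instead of a separate service call. This is only a sketch of that idea, not something this PR implements; the helper name is hypothetical and the math mirrors ComputeAdvantages.compute above:

import torch

def compute_advantages_inline(episodes) -> list[float]:
    # Same normalization as ComputeAdvantages.compute, but as a plain local function
    rewards = torch.tensor([[e.reward for e in episodes]])
    mean = rewards.mean(1, keepdim=True)
    std = rewards.std(1, keepdim=True)
    advantages = (rewards - mean) / (std + 1e-4)
    return advantages.squeeze(0).tolist()

# In continuous_rollouts(), the service round-trip
#     advantages = await compute_advantages.compute.call_one(episodes)
# would then become a direct call:
#     advantages = compute_advantages_inline(episodes)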