118 changes: 41 additions & 77 deletions apps/grpo/main.py
@@ -28,6 +28,7 @@
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import init_provisioner, shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.data_models.completion import Completion
from forge.env import MONARCH_HOSTMESH_V1
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
@@ -42,73 +43,43 @@

@dataclass
class Episode:
# TODO: add additional layer for multi-turn
episode_id: str
request: str
policy_version: int
pad_id: int
request_len: int
response_len: int
target: Any | None = None
# processed data
response: str | None = None
request_tokens: list[int] | None = None
response_tokens: list[int] | None = None
# Processed data
completion: Completion | None = None
ref_logprobs: torch.Tensor | None = None
reward: float | None = None
advantage: float | None = None

@property
def request_tensor(self):
tensor = torch.tensor(self.request_tokens, dtype=torch.long)
def policy_version(self) -> int | None:
return getattr(self.completion, "generator_version", None)

@property
def request_tensor(self) -> torch.Tensor:
request_tokens: torch.Tensor = self.completion.prompt_ids
tensor = torch.tensor(request_tokens, dtype=torch.long)
if tensor.shape[0] < self.request_len: # left pad
diff = self.request_len - tensor.shape[0]
tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
return tensor

@property
def response_tensor(self):
tensor = torch.tensor(self.response_tokens, dtype=torch.long)
def response_tensor(self) -> torch.Tensor:
response_tokens: torch.Tensor = self.completion.token_ids
tensor = torch.tensor(response_tokens, dtype=torch.long)
if tensor.shape[0] < self.response_len: # right pad
diff = self.response_len - tensor.shape[0]
tensor = F.pad(tensor, (0, diff), value=self.pad_id)
return tensor


@dataclass
class Group:
group_id: str
episodes: list[Episode]

@classmethod
def new_group(
cls,
group_id: int,
group_size: int,
request: str,
policy_version: int,
pad_id: int,
request_len: int,
response_len: int,
target: Any = None,
):
episodes = []
for _ in range(group_size):
episodes.append(
Episode(
episode_id=str(uuid.uuid4()),
request=request,
policy_version=policy_version,
pad_id=pad_id,
request_len=request_len,
response_len=response_len,
target=target,
)
)
return cls(str(group_id), episodes)


def collate(batches: list[list[Episode]]):
def collate(
batches: list[list[Episode]],
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
inputs = []
targets = []
for batch in batches:
@@ -220,9 +191,9 @@ class ComputeAdvantages(ForgeActor):
"""Compute advantages for GRPO using reward signals."""

@endpoint
async def compute(self, group: Group) -> list[float]:
async def compute(self, episodes: list[Episode]) -> list[float]:
# TODO: add batch processing
rewards = torch.tensor([[e.reward for e in group.episodes]])
rewards = torch.tensor([[e.reward for e in episodes]])
mean = rewards.mean(1, keepdim=True)
std = rewards.std(1, keepdim=True)
advantages = (rewards - mean) / (std + 1e-4)
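
For readers skimming the diff, here is a standalone numeric sketch of the group-relative normalization that ComputeAdvantages.compute performs (illustrative only: the toy rewards are invented, the 1e-4 epsilon matches the code above):

import torch

# One group of 4 rollouts of the same prompt, with toy rewards.
rewards = torch.tensor([[1.0, 0.0, 0.5, 1.0]])
mean = rewards.mean(1, keepdim=True)    # 0.625
std = rewards.std(1, keepdim=True)      # ~0.479
advantages = (rewards - mean) / (std + 1e-4)
print(advantages.squeeze(0).tolist())   # roughly [0.78, -1.31, -0.26, 0.78]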
@@ -387,68 +358,61 @@ async def continuous_rollouts():
while not shutdown_event.is_set():
t = Tracer("main_perf/continuous_rollouts")
t.start()
sample = await dataloader.sample.call_one()
sample: dict[str, str] = await dataloader.sample.call_one()
if sample is None:
print("Dataloader is empty, exiting continuous rollout")
return

t.step("data_loading")

prompt, target = sample["request"], sample["target"]
responses = await policy.generate.route(prompt)
# TODO: this shall be part of the responses metadata instead of a separate call
version = await policy.get_version.route()

responses: list[Completion] = await policy.generate.route(prompt)
t.step("policy_generation")

assert (
len(responses) > 0
), "Sanity check: Responses should NEVER return empty"
assert (
version := responses[0].generator_version
) is not None, "Response must indicate a version"
group = Group.new_group(
group_id=rollout_count,
group_size=group_size,
request=prompt,
policy_version=version,
pad_id=pad_id,
request_len=max_req_tokens,
response_len=max_res_tokens,
target=target,
)

# Construct episodes and calculate rewards
episodes: list[Episode] = []
input_ids = torch.ones(
(group_size, max_req_tokens + max_res_tokens),
dtype=torch.long,
device="cuda",
)
# Populate episode info and calculate rewards
for i, (episode, response) in enumerate(zip(group.episodes, responses)):
episode.request_tokens = response.prompt_ids
episode.response_tokens = response.token_ids
episode.response = response.text
input_ids[i, :max_req_tokens] = episode.request_tensor
input_ids[i, max_req_tokens:] = episode.response_tensor
for i, response in enumerate(responses):
episode = Episode(
episode_id=str(uuid.uuid4()),
pad_id=pad_id,
request_len=max_req_tokens,
response_len=max_res_tokens,
target=target,
completion=response,
)
episode.reward = await reward_actor.evaluate_response.route(
prompt=prompt, response=response.text, target=target
)
episodes.append(episode)

# Build input_ids for reference logprobs
input_ids[i, :max_req_tokens] = episode.request_tensor
input_ids[i, max_req_tokens:] = episode.response_tensor

t.step("reward_evaluation")

ref_logprobs = await ref_model.forward.route(
ref_logprobs: torch.Tensor = await ref_model.forward.route(
input_ids, max_req_tokens, return_logprobs=True
)
t.step("reference_model_calculate_logprobs")

for i, episode in enumerate(group.episodes):
for i, episode in enumerate(episodes):
episode.ref_logprobs = ref_logprobs[i]
del ref_logprobs, input_ids
t.step("compute_logprobs")
t.step("gc_ref_logprobs")

# Calculate advantages and add to replay buffer
advantages = await compute_advantages.compute.call_one(group)
for episode, advantage in zip(group.episodes, advantages):
advantages = await compute_advantages.compute.call_one(episodes)
@JenniferWang (Contributor) commented on Oct 14, 2025:
Not related to this diff, but since we're scrutinizing the main flow again: I think making compute_advantages its own Actor is very weird and probably the opposite of an "optimization".

  1. We do not expose the capability to specify the host mesh for a specific actor -- ideally, this should be colocated with the generator replica that produces this batch.
  2. ComputeAdvantages only needs the rewards, so very likely the entire episodes get serialized anyway.

I wonder whether, for now, it should just be inlined in the sample call, or whether we should allocate a proc on the Policy mesh alongside the PolicyWorker to handle the computation, chain the calls, and return the result together in sample.

A Contributor replied:
@JenniferWang these are good points. I want to propose an idea (not for you to implement, @Jack-Khuu, just brainstorming whether this makes sense):

with policy.session() as s:
    host: HostMesh = await s.get_host_mesh()  # returns the host mesh associated with this replica
    advantages = host.run_task(compute_advantages) # where compute_advantages is a function

A Member replied:
Chained calls would be cool 👀

The PR author (Contributor) replied:
This looks legit; +1 on chained calls
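
To make the inlining option from this thread concrete, a minimal sketch of folding the advantage computation into the rollout loop follows. The helper name and signature are hypothetical; only the normalization and the replay_buffer.add.call_one call mirror this diff.

import torch

async def compute_and_buffer_advantages(episodes, replay_buffer) -> None:
    # Hypothetical inline replacement for the ComputeAdvantages actor hop:
    # same group-relative normalization, applied where the episodes already live.
    rewards = torch.tensor([[e.reward for e in episodes]])
    mean = rewards.mean(1, keepdim=True)
    std = rewards.std(1, keepdim=True)
    advantages = (rewards - mean) / (std + 1e-4)
    for episode, advantage in zip(episodes, advantages.squeeze(0).tolist()):
        episode.advantage = advantage
        await replay_buffer.add.call_one(episode)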

for episode, advantage in zip(episodes, advantages):
episode.advantage = advantage
await replay_buffer.add.call_one(episode)

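As a side note on the Episode properties above, here is a small standalone illustration of the left-pad / right-pad behavior of request_tensor and response_tensor (toy values; pad_id and the fixed lengths are invented):

import torch
import torch.nn.functional as F

pad_id, request_len, response_len = 0, 6, 6
prompt_ids = torch.tensor([11, 12, 13])
token_ids = torch.tensor([21, 22, 23, 24])

# Prompts are left-padded so the final prompt token stays adjacent to the response.
request = F.pad(prompt_ids, (request_len - len(prompt_ids), 0), value=pad_id)
# Responses are right-padded out to the fixed response length.
response = F.pad(token_ids, (0, response_len - len(token_ids)), value=pad_id)

print(request.tolist())   # [0, 0, 0, 11, 12, 13]
print(response.tolist())  # [21, 22, 23, 24, 0, 0]
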
10 changes: 3 additions & 7 deletions src/forge/actors/policy.py
@@ -341,8 +341,9 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
def _preprocess_add_request(
self, request: EngineCoreRequest
) -> tuple[Request, int]:
""" (forge/issues/332) Will require attention when we bump vllm versions
https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419"""
"""(forge/issues/332) Will require attention when we bump vllm versions
https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419
"""
if request.mm_hashes is not None:
raise NotImplementedError("Support for mm_hash is not implemented yet.")
req = Request.from_engine_core_request(request)
@@ -445,11 +446,6 @@ async def update_weights(self, policy_version: int) -> None:
async def _reset_prefix_cache(self):
self.scheduler.reset_prefix_cache()

@endpoint
async def get_version(self) -> int:
"""Get the current policy version."""
return self.policy_version

@endpoint
async def stop(self):
self.running = False