From df0e5a95a019d012ac6b9fde89c82631c7320e8e Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Mon, 13 Oct 2025 23:51:46 -0700 Subject: [PATCH 1/4] Push initial removal; debugging hang --- apps/grpo/main.py | 125 +++++++++------------ src/forge/actors/policy.py | 10 +- src/forge/data_models/prompt.py | 9 ++ src/forge/data_models/scored_completion.py | 11 +- 4 files changed, 72 insertions(+), 83 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 6439ead85..826b922f0 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -28,6 +28,7 @@ from forge.controller.actor import ForgeActor from forge.controller.provisioner import init_provisioner, shutdown from forge.data.rewards import MathReward, ThinkingReward +from forge.data_models.scored_completion import ScoredCompletion from forge.env import MONARCH_HOSTMESH_V1 from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import record_metric, Reduce @@ -42,24 +43,42 @@ @dataclass class Episode: - # TODO: add adtional layer for multi-turn episode_id: str - request: str - policy_version: int pad_id: int request_len: int response_len: int target: Any | None = None - # processed data - response: str | None = None - request_tokens: list[int] | None = None - response_tokens: list[int] | None = None + completion: ScoredCompletion | None = None ref_logprobs: torch.Tensor | None = None - reward: float | None = None advantage: float | None = None @property - def request_tensor(self): + def request(self) -> str | None: + prompt = getattr(self.completion, "prompt", None) + return prompt.get_first_turn_prompt() if prompt is not None else None + + @property + def policy_version(self) -> int | None: + return getattr(self.completion, "generator_version", None) + + @property + def response(self) -> str | None: + return getattr(self.completion, "text", None) + + @property + def request_tokens(self) -> torch.Tensor: + return getattr(self.completion, "prompt_ids", None) + + @property + def response_tokens(self) -> torch.Tensor: + return getattr(self.completion, "token_ids", None) + + @property + def reward(self) -> float: + return getattr(self.completion, "score", None) + + @property + def request_tensor(self) -> torch.Tensor: tensor = torch.tensor(self.request_tokens, dtype=torch.long) if tensor.shape[0] < self.request_len: # left pad diff = self.request_len - tensor.shape[0] @@ -67,7 +86,7 @@ def request_tensor(self): return tensor @property - def response_tensor(self): + def response_tensor(self) -> torch.Tensor: tensor = torch.tensor(self.response_tokens, dtype=torch.long) if tensor.shape[0] < self.response_len: # right pad diff = self.response_len - tensor.shape[0] @@ -75,39 +94,6 @@ def response_tensor(self): return tensor -@dataclass -class Group: - group_id: str - episodes: list[Episode] - - @classmethod - def new_group( - cls, - group_id: int, - group_size: int, - request: str, - policy_version: int, - pad_id: int, - request_len: int, - response_len: int, - target: Any = None, - ): - episodes = [] - for _ in range(group_size): - episodes.append( - Episode( - episode_id=str(uuid.uuid4()), - request=request, - policy_version=policy_version, - pad_id=pad_id, - request_len=request_len, - response_len=response_len, - target=target, - ) - ) - return cls(str(group_id), episodes) - - def collate(batches: list[list[Episode]]): inputs = [] targets = [] @@ -220,9 +206,9 @@ class ComputeAdvantages(ForgeActor): """Compute advantages for GRPO using reward signals.""" @endpoint - async def compute(self, 
group: Group) -> list[float]: + async def compute(self, episodes: list[Episode]) -> list[float]: # TODO: add batch processing - rewards = torch.tensor([[e.reward for e in group.episodes]]) + rewards = torch.tensor([[e.reward for e in episodes]]) mean = rewards.mean(1, keepdim=True) std = rewards.std(1, keepdim=True) advantages = (rewards - mean) / (std + 1e-4) @@ -396,43 +382,36 @@ async def continuous_rollouts(): prompt, target = sample["request"], sample["target"] responses = await policy.generate.route(prompt) - # TODO: this shall be part of the responses metadata instead of a separate call - version = await policy.get_version.route() - t.step("policy_generation") assert ( len(responses) > 0 ), "Sanity check: Responses should NEVER return empty" - assert ( - version := responses[0].generator_version - ) is not None, "Response must indicate a version" - group = Group.new_group( - group_id=rollout_count, - group_size=group_size, - request=prompt, - policy_version=version, - pad_id=pad_id, - request_len=max_req_tokens, - response_len=max_res_tokens, - target=target, - ) + # Construct episodes and calculate rewards + episodes: list[Episode] = [] input_ids = torch.ones( (group_size, max_req_tokens + max_res_tokens), dtype=torch.long, device="cuda", ) - # Populate episode info and calculate rewards - for i, (episode, response) in enumerate(zip(group.episodes, responses)): - episode.request_tokens = response.prompt_ids - episode.response_tokens = response.token_ids - episode.response = response.text - input_ids[i, :max_req_tokens] = episode.request_tensor - input_ids[i, max_req_tokens:] = episode.response_tensor - episode.reward = await reward_actor.evaluate_response.route( + for i, response in enumerate(responses): + reward: float = await reward_actor.evaluate_response.route( prompt=prompt, response=response.text, target=target ) + episode = Episode( + episode_id=str(uuid.uuid4()), + pad_id=pad_id, + request_len=max_req_tokens, + response_len=max_res_tokens, + target=target, + completion=ScoredCompletion.from_completion(response, score=reward), + ) + episodes.append(episode) + + # Build input_ids for reference logprobs + input_ids[i, :max_req_tokens] = episode.request_tensor + input_ids[i, max_req_tokens:] = episode.response_tensor t.step("reward_evaluation") @@ -441,14 +420,14 @@ async def continuous_rollouts(): ) t.step("reference_model_calculate_logprobs") - for i, episode in enumerate(group.episodes): + for i, episode in enumerate(episodes): episode.ref_logprobs = ref_logprobs[i] del ref_logprobs, input_ids - t.step("compute_logprobs") + t.step("gc_ref_logprobs") # Calculate advantages and add to replay buffer - advantages = await compute_advantages.compute.call_one(group) - for episode, advantage in zip(group.episodes, advantages): + advantages = await compute_advantages.compute.call_one(episodes) + for episode, advantage in zip(episodes, advantages): episode.advantage = advantage await replay_buffer.add.call_one(episode) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 3a1b3e86e..a89581ef1 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -341,8 +341,9 @@ async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]: def _preprocess_add_request( self, request: EngineCoreRequest ) -> tuple[Request, int]: - """ (forge/issues/332) Will require attention when we bump vllm versions - https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419""" + """(forge/issues/332) 
Will require attention when we bump vllm versions + https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419 + """ if request.mm_hashes is not None: raise NotImplementedError("Support for mm_hash is not implemented yet.") req = Request.from_engine_core_request(request) @@ -445,11 +446,6 @@ async def update_weights(self, policy_version: int) -> None: async def _reset_prefix_cache(self): self.scheduler.reset_prefix_cache() - @endpoint - async def get_version(self) -> int: - """Get the current policy version.""" - return self.policy_version - @endpoint async def stop(self): self.running = False diff --git a/src/forge/data_models/prompt.py b/src/forge/data_models/prompt.py index c7f79c04e..e16a4d4e9 100644 --- a/src/forge/data_models/prompt.py +++ b/src/forge/data_models/prompt.py @@ -40,6 +40,15 @@ def from_prompt( messages=messages, ) + def get_first_turn_prompt(self) -> str: + """Returns the string prompt of the first turn.""" + if len(self.messages) == 0: + raise ValueError("No messages in prompt.") + elif len(self.messages[0].chunks) == 0: + raise ValueError("No chunks in first message.") + + return self.messages[0].chunks[0] + def prompt_to_messages( prompt: str, system_instruction: str | None = None diff --git a/src/forge/data_models/scored_completion.py b/src/forge/data_models/scored_completion.py index f41ff7b59..68598e3d4 100644 --- a/src/forge/data_models/scored_completion.py +++ b/src/forge/data_models/scored_completion.py @@ -10,10 +10,15 @@ @dataclass -class ScoredCompletion: +class ScoredCompletion(Completion): """A completion with an associated score (from a reward model or human).""" - completion: Completion - score: float # akin to reward + score: float | None = None # akin to reward # TODO: add more fields as needed. 
+ + @classmethod + def from_completion( + cls, completion: Completion, score: float + ) -> "ScoredCompletion": + return cls(**asdict(completion), score=score) From 85fab12e87849bf7ee52ba1d9bb6b24237d9c844 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Tue, 14 Oct 2025 01:36:26 -0700 Subject: [PATCH 2/4] Fix hang, remove test properties --- apps/grpo/main.py | 51 ++++++++-------------- src/forge/data_models/prompt.py | 9 ---- src/forge/data_models/scored_completion.py | 11 ++--- 3 files changed, 21 insertions(+), 50 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 826b922f0..88797a513 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -28,7 +28,7 @@ from forge.controller.actor import ForgeActor from forge.controller.provisioner import init_provisioner, shutdown from forge.data.rewards import MathReward, ThinkingReward -from forge.data_models.scored_completion import ScoredCompletion +from forge.data_models.completion import Completion from forge.env import MONARCH_HOSTMESH_V1 from forge.observability.metric_actors import get_or_create_metric_logger from forge.observability.metrics import record_metric, Reduce @@ -48,38 +48,20 @@ class Episode: request_len: int response_len: int target: Any | None = None - completion: ScoredCompletion | None = None + # Processed data + completion: Completion | None = None ref_logprobs: torch.Tensor | None = None + reward: float | None = None advantage: float | None = None - @property - def request(self) -> str | None: - prompt = getattr(self.completion, "prompt", None) - return prompt.get_first_turn_prompt() if prompt is not None else None - @property def policy_version(self) -> int | None: return getattr(self.completion, "generator_version", None) - @property - def response(self) -> str | None: - return getattr(self.completion, "text", None) - - @property - def request_tokens(self) -> torch.Tensor: - return getattr(self.completion, "prompt_ids", None) - - @property - def response_tokens(self) -> torch.Tensor: - return getattr(self.completion, "token_ids", None) - - @property - def reward(self) -> float: - return getattr(self.completion, "score", None) - @property def request_tensor(self) -> torch.Tensor: - tensor = torch.tensor(self.request_tokens, dtype=torch.long) + request_tokens: torch.Tensor = self.completion.prompt_ids + tensor = torch.tensor(request_tokens, dtype=torch.long) if tensor.shape[0] < self.request_len: # left pad diff = self.request_len - tensor.shape[0] tensor = F.pad(tensor, (diff, 0), value=self.pad_id) @@ -87,14 +69,17 @@ def request_tensor(self) -> torch.Tensor: @property def response_tensor(self) -> torch.Tensor: - tensor = torch.tensor(self.response_tokens, dtype=torch.long) + response_tokens: torch.Tensor = self.completion.token_ids + tensor = torch.tensor(response_tokens, dtype=torch.long) if tensor.shape[0] < self.response_len: # right pad diff = self.response_len - tensor.shape[0] tensor = F.pad(tensor, (0, diff), value=self.pad_id) return tensor -def collate(batches: list[list[Episode]]): +def collate( + batches: list[list[Episode]], +) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: inputs = [] targets = [] for batch in batches: @@ -373,7 +358,7 @@ async def continuous_rollouts(): while not shutdown_event.is_set(): t = Tracer("main_perf/continuous_rollouts") t.start() - sample = await dataloader.sample.call_one() + sample: dict[str, str] = await dataloader.sample.call_one() if sample is None: print("Dataloader is empty, exiting continuous rollout") return @@ -381,7 +366,7 @@ async def 
continuous_rollouts(): t.step("data_loading") prompt, target = sample["request"], sample["target"] - responses = await policy.generate.route(prompt) + responses: list[Completion] = await policy.generate.route(prompt) t.step("policy_generation") assert ( @@ -396,16 +381,16 @@ async def continuous_rollouts(): device="cuda", ) for i, response in enumerate(responses): - reward: float = await reward_actor.evaluate_response.route( - prompt=prompt, response=response.text, target=target - ) episode = Episode( episode_id=str(uuid.uuid4()), pad_id=pad_id, request_len=max_req_tokens, response_len=max_res_tokens, target=target, - completion=ScoredCompletion.from_completion(response, score=reward), + completion=response, + ) + episode.reward = await reward_actor.evaluate_response.route( + prompt=prompt, response=response.text, target=target ) episodes.append(episode) @@ -415,7 +400,7 @@ async def continuous_rollouts(): t.step("reward_evaluation") - ref_logprobs = await ref_model.forward.route( + ref_logprobs: torch.Tensor = await ref_model.forward.route( input_ids, max_req_tokens, return_logprobs=True ) t.step("reference_model_calculate_logprobs") diff --git a/src/forge/data_models/prompt.py b/src/forge/data_models/prompt.py index e16a4d4e9..c7f79c04e 100644 --- a/src/forge/data_models/prompt.py +++ b/src/forge/data_models/prompt.py @@ -40,15 +40,6 @@ def from_prompt( messages=messages, ) - def get_first_turn_prompt(self) -> str: - """Returns the string prompt of the first turn.""" - if len(self.messages) == 0: - raise ValueError("No messages in prompt.") - elif len(self.messages[0].chunks) == 0: - raise ValueError("No chunks in first message.") - - return self.messages[0].chunks[0] - def prompt_to_messages( prompt: str, system_instruction: str | None = None diff --git a/src/forge/data_models/scored_completion.py b/src/forge/data_models/scored_completion.py index 68598e3d4..f41ff7b59 100644 --- a/src/forge/data_models/scored_completion.py +++ b/src/forge/data_models/scored_completion.py @@ -10,15 +10,10 @@ @dataclass -class ScoredCompletion(Completion): +class ScoredCompletion: """A completion with an associated score (from a reward model or human).""" - score: float | None = None # akin to reward + completion: Completion + score: float # akin to reward # TODO: add more fields as needed. - - @classmethod - def from_completion( - cls, completion: Completion, score: float - ) -> "ScoredCompletion": - return cls(**asdict(completion), score=score) From 843420bdd126ce06a1bd8bc6e98259d117ae8b81 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Tue, 14 Oct 2025 10:14:49 -0700 Subject: [PATCH 3/4] Address comments --- apps/grpo/main.py | 27 +++++++++++++++------------ src/forge/actors/replay_buffer.py | 1 + 2 files changed, 16 insertions(+), 12 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 88797a513..f1e8e7c97 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -56,7 +56,7 @@ class Episode: @property def policy_version(self) -> int | None: - return getattr(self.completion, "generator_version", None) + return self.completion.generator_version @property def request_tensor(self) -> torch.Tensor: @@ -77,9 +77,17 @@ def response_tensor(self) -> torch.Tensor: return tensor +# Represents the group (G) of episodes in GRPO +Group = list[Episode] + + def collate( - batches: list[list[Episode]], + batches: list[Group], ) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]: + """ + Collates a list of batches into a single batch of inputs and targets. 
+ Each batch is a list of episodes, and each episode is a dict of tensors. + """ inputs = [] targets = [] for batch in batches: @@ -191,9 +199,9 @@ class ComputeAdvantages(ForgeActor): """Compute advantages for GRPO using reward signals.""" @endpoint - async def compute(self, episodes: list[Episode]) -> list[float]: + async def compute(self, group: Group) -> list[float]: # TODO: add batch processing - rewards = torch.tensor([[e.reward for e in episodes]]) + rewards = torch.tensor([[e.reward for e in group]]) mean = rewards.mean(1, keepdim=True) std = rewards.std(1, keepdim=True) advantages = (rewards - mean) / (std + 1e-4) @@ -358,7 +366,7 @@ async def continuous_rollouts(): while not shutdown_event.is_set(): t = Tracer("main_perf/continuous_rollouts") t.start() - sample: dict[str, str] = await dataloader.sample.call_one() + sample = await dataloader.sample.call_one() if sample is None: print("Dataloader is empty, exiting continuous rollout") return @@ -369,12 +377,8 @@ async def continuous_rollouts(): responses: list[Completion] = await policy.generate.route(prompt) t.step("policy_generation") - assert ( - len(responses) > 0 - ), "Sanity check: Responses should NEVER return empty" - # Construct episodes and calculate rewards - episodes: list[Episode] = [] + episodes = [] input_ids = torch.ones( (group_size, max_req_tokens + max_res_tokens), dtype=torch.long, @@ -400,7 +404,7 @@ async def continuous_rollouts(): t.step("reward_evaluation") - ref_logprobs: torch.Tensor = await ref_model.forward.route( + ref_logprobs = await ref_model.forward.route( input_ids, max_req_tokens, return_logprobs=True ) t.step("reference_model_calculate_logprobs") @@ -408,7 +412,6 @@ async def continuous_rollouts(): for i, episode in enumerate(episodes): episode.ref_logprobs = ref_logprobs[i] del ref_logprobs, input_ids - t.step("gc_ref_logprobs") # Calculate advantages and add to replay buffer advantages = await compute_advantages.compute.call_one(episodes) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index a92fd5501..db1caf333 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -105,6 +105,7 @@ async def sample( for dp_idx in range(self.dp_size) ] + # Call the underlying collate function to collate the episodes into a batch return self.collate(reshaped_episodes) @endpoint From 6c7e600bb4cdfdc74f9817d906e39d7c56ac8807 Mon Sep 17 00:00:00 2001 From: Jack-Khuu Date: Tue, 14 Oct 2025 10:37:54 -0700 Subject: [PATCH 4/4] Remove device --- apps/grpo/main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index f1e8e7c97..e3fb851e6 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -382,7 +382,6 @@ async def continuous_rollouts(): input_ids = torch.ones( (group_size, max_req_tokens + max_res_tokens), dtype=torch.long, - device="cuda", ) for i, response in enumerate(responses): episode = Episode(
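
A minimal standalone sketch of the data flow these patches converge on, for
reference: Episode now wraps the generator's Completion directly, the Group
dataclass is replaced by a plain list of episodes, the reward is stored on the
episode, and advantages are normalized within each group. The Completion and
Episode classes below are simplified stand-ins written for illustration only,
not the actual forge.data_models definitions.

    # Simplified stand-ins for illustration; not the real forge classes.
    from dataclasses import dataclass

    import torch
    import torch.nn.functional as F


    @dataclass
    class Completion:  # stand-in for forge.data_models.completion.Completion
        prompt_ids: list[int]
        token_ids: list[int]
        text: str
        generator_version: int


    @dataclass
    class Episode:  # mirrors the post-patch Episode in apps/grpo/main.py
        episode_id: str
        pad_id: int
        request_len: int
        response_len: int
        completion: Completion
        reward: float | None = None
        advantage: float | None = None

        @property
        def request_tensor(self) -> torch.Tensor:
            # Left-pad the prompt tokens to request_len, as in the patched code.
            tensor = torch.tensor(self.completion.prompt_ids, dtype=torch.long)
            if tensor.shape[0] < self.request_len:
                diff = self.request_len - tensor.shape[0]
                tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
            return tensor


    Group = list[Episode]  # a GRPO group is now just a list of episodes


    def compute_advantages(group: Group) -> list[float]:
        # Per-group reward normalization, as ComputeAdvantages.compute does.
        rewards = torch.tensor([[e.reward for e in group]])
        mean = rewards.mean(1, keepdim=True)
        std = rewards.std(1, keepdim=True)
        return ((rewards - mean) / (std + 1e-4)).squeeze(0).tolist()


    # Usage: wrap each generator response in an Episode, score it, then batch.
    group: Group = [
        Episode("ep-0", pad_id=0, request_len=4, response_len=4,
                completion=Completion([1, 2], [3, 4, 5], "ok", 0), reward=1.0),
        Episode("ep-1", pad_id=0, request_len=4, response_len=4,
                completion=Completion([1, 2], [6, 7], "no", 0), reward=0.0),
    ]
    for episode, advantage in zip(group, compute_advantages(group)):
        episode.advantage = advantage

Keeping the group as a plain list[Episode] avoids the extra Group wrapper while
preserving the same per-group reward normalization in ComputeAdvantages.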