diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index ff46fea20..ed33e7e2f 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -28,6 +28,7 @@
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.data.rewards import MathReward, ThinkingReward
+from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
@@ -41,73 +42,51 @@
 @dataclass
 class Episode:
-    # TODO: add adtional layer for multi-turn
     episode_id: str
-    request: str
-    policy_version: int
     pad_id: int
     request_len: int
     response_len: int
     target: Any | None = None
-    # processed data
-    response: str | None = None
-    request_tokens: list[int] | None = None
-    response_tokens: list[int] | None = None
+    # Processed data
+    completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
     advantage: float | None = None

     @property
-    def request_tensor(self):
-        tensor = torch.tensor(self.request_tokens, dtype=torch.long)
+    def policy_version(self) -> int | None:
+        return self.completion.generator_version
+
+    @property
+    def request_tensor(self) -> torch.Tensor:
+        request_tokens: torch.Tensor = self.completion.prompt_ids
+        tensor = torch.tensor(request_tokens, dtype=torch.long)
         if tensor.shape[0] < self.request_len:  # left pad
             diff = self.request_len - tensor.shape[0]
             tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
         return tensor

     @property
-    def response_tensor(self):
-        tensor = torch.tensor(self.response_tokens, dtype=torch.long)
+    def response_tensor(self) -> torch.Tensor:
+        response_tokens: torch.Tensor = self.completion.token_ids
+        tensor = torch.tensor(response_tokens, dtype=torch.long)
         if tensor.shape[0] < self.response_len:  # right pad
             diff = self.response_len - tensor.shape[0]
             tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor


-@dataclass
-class Group:
-    group_id: str
-    episodes: list[Episode]
-
-    @classmethod
-    def new_group(
-        cls,
-        group_id: int,
-        group_size: int,
-        request: str,
-        policy_version: int,
-        pad_id: int,
-        request_len: int,
-        response_len: int,
-        target: Any = None,
-    ):
-        episodes = []
-        for _ in range(group_size):
-            episodes.append(
-                Episode(
-                    episode_id=str(uuid.uuid4()),
-                    request=request,
-                    policy_version=policy_version,
-                    pad_id=pad_id,
-                    request_len=request_len,
-                    response_len=response_len,
-                    target=target,
-                )
-            )
-        return cls(str(group_id), episodes)
+# Represents the group (G) of episodes in GRPO
+Group = list[Episode]


-def collate(batches: list[list[Episode]]):
+def collate(
+    batches: list[Group],
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+    """
+    Collates a list of batches into a single batch of inputs and targets.
+    Each batch is a list of episodes, and each episode is a dict of tensors.
+    """
     inputs = []
     targets = []
     for batch in batches:
@@ -221,7 +200,7 @@ class ComputeAdvantages(ForgeActor):
     @endpoint
     async def compute(self, group: Group) -> list[float]:
         # TODO: add batch processing
-        rewards = torch.tensor([[e.reward for e in group.episodes]])
+        rewards = torch.tensor([[e.reward for e in group]])
         mean = rewards.mean(1, keepdim=True)
         std = rewards.std(1, keepdim=True)
         advantages = (rewards - mean) / (std + 1e-4)
@@ -386,44 +365,32 @@ async def continuous_rollouts():
             t.step("data_loading")

             prompt, target = sample["request"], sample["target"]
-            responses = await policy.generate.route(prompt)
-            # TODO: this shall be part of the responses metadata instead of a separate call
-            version = await policy.get_version.route()
-
+            responses: list[Completion] = await policy.generate.route(prompt)
             t.step("policy_generation")

-            assert (
-                len(responses) > 0
-            ), "Sanity check: Responses should NEVER return empty"
-            assert (
-                version := responses[0].generator_version
-            ) is not None, "Response must indicate a version"
-            group = Group.new_group(
-                group_id=rollout_count,
-                group_size=group_size,
-                request=prompt,
-                policy_version=version,
-                pad_id=pad_id,
-                request_len=max_req_tokens,
-                response_len=max_res_tokens,
-                target=target,
-            )
-
+            # Construct episodes and calculate rewards
+            episodes = []
             input_ids = torch.ones(
                 (group_size, max_req_tokens + max_res_tokens),
                 dtype=torch.long,
-                device="cuda",
             )
-            # Populate episode info and calculate rewards
-            for i, (episode, response) in enumerate(zip(group.episodes, responses)):
-                episode.request_tokens = response.prompt_ids
-                episode.response_tokens = response.token_ids
-                episode.response = response.text
-                input_ids[i, :max_req_tokens] = episode.request_tensor
-                input_ids[i, max_req_tokens:] = episode.response_tensor
+            for i, response in enumerate(responses):
+                episode = Episode(
+                    episode_id=str(uuid.uuid4()),
+                    pad_id=pad_id,
+                    request_len=max_req_tokens,
+                    response_len=max_res_tokens,
+                    target=target,
+                    completion=response,
+                )
                 episode.reward = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
+                episodes.append(episode)
+
+                # Build input_ids for reference logprobs
+                input_ids[i, :max_req_tokens] = episode.request_tensor
+                input_ids[i, max_req_tokens:] = episode.response_tensor

             t.step("reward_evaluation")

@@ -432,14 +399,13 @@
             )
             t.step("reference_model_calculate_logprobs")

-            for i, episode in enumerate(group.episodes):
+            for i, episode in enumerate(episodes):
                 episode.ref_logprobs = ref_logprobs[i]
             del ref_logprobs, input_ids
-            t.step("compute_logprobs")

             # Calculate advantages and add to replay buffer
-            advantages = await compute_advantages.compute.call_one(group)
-            for episode, advantage in zip(group.episodes, advantages):
+            advantages = await compute_advantages.compute.call_one(episodes)
+            for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index 8e8c8de17..7a1cdfd15 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -343,7 +343,8 @@ def _preprocess_add_request(
         self, request: EngineCoreRequest
     ) -> tuple[Request, int]:
         """(forge/issues/332) Will require attention when we bump vllm versions
-        https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419"""
+        https://github.com/vllm-project/vllm/blob/0e3bb543f064eb416bca4f6f3013efa3830b12f7/vllm/v1/engine/core.py#L419
+        """
         if request.mm_hashes is not None:
             raise NotImplementedError("Support for mm_hash is not implemented yet.")
         req = Request.from_engine_core_request(request)
@@ -446,11 +447,6 @@ async def update_weights(self, policy_version: int) -> None:
     async def _reset_prefix_cache(self):
         self.scheduler.reset_prefix_cache()

-    @endpoint
-    async def get_version(self) -> int:
-        """Get the current policy version."""
-        return self.policy_version
-
     @endpoint
     async def stop(self):
         self.running = False
diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py
index a92fd5501..db1caf333 100644
--- a/src/forge/actors/replay_buffer.py
+++ b/src/forge/actors/replay_buffer.py
@@ -105,6 +105,7 @@ async def sample(
             for dp_idx in range(self.dp_size)
         ]

+        # Call the underlying collate function to collate the episodes into a batch
        return self.collate(reshaped_episodes)

     @endpoint
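Illustrative sketch (not part of the patch): the padding behavior that Episode.request_tensor / Episode.response_tensor keep after the refactor, using a hypothetical FakeCompletion stand-in that only assumes the Completion fields referenced in the diff (prompt_ids, token_ids, text, generator_version). The pad_id, request_len, and response_len values are arbitrary examples.

    # Minimal sketch of the refactored Episode behavior; FakeCompletion is a
    # hypothetical stub for forge.data_models.completion.Completion.
    from dataclasses import dataclass

    import torch
    import torch.nn.functional as F


    @dataclass
    class FakeCompletion:
        prompt_ids: torch.Tensor
        token_ids: torch.Tensor
        text: str
        generator_version: int


    completion = FakeCompletion(
        prompt_ids=torch.tensor([101, 102, 103]),
        token_ids=torch.tensor([7, 8]),
        text="42",
        generator_version=3,
    )

    pad_id, request_len, response_len = 0, 5, 4

    # Mirrors Episode.request_tensor: left-pad the prompt up to request_len.
    request = F.pad(completion.prompt_ids, (request_len - 3, 0), value=pad_id)
    # Mirrors Episode.response_tensor: right-pad the response up to response_len.
    response = F.pad(completion.token_ids, (0, response_len - 2), value=pad_id)

    assert request.tolist() == [0, 0, 101, 102, 103]
    assert response.tolist() == [7, 8, 0, 0]
    # The policy version now rides along on the completion, replacing the
    # removed get_version() endpoint.
    assert completion.generator_version == 3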