diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 7f31c26c9..e453e1e1b 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -23,6 +23,7 @@
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import shutdown
 from forge.data.rewards import MathReward, ThinkingReward
+from forge.util import compute_logprobs
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
 from omegaconf import DictConfig
@@ -128,16 +129,6 @@ def collate(batches: list[list[Episode]]):
     return inputs, targets
 
 
-def compute_logprobs(
-    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
-) -> torch.Tensor:
-    context_length = logits.shape[1] - input_ids.shape[1]
-    logits = logits[:, context_length - 1 : -1]
-    logprobs = torch.log_softmax(logits / temperature, dim=-1).to(input_ids.device)
-    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
-    return logprobs
-
-
 def simple_grpo_loss(
     logits: torch.Tensor,
     response: torch.Tensor,
@@ -317,11 +308,10 @@ async def continuous_rollouts():
         )
 
         # Calculate reference logprobs
-        ref_logits = await ref_model.forward.choose(input_ids)
-        ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
+        ref_logprobs = await ref_model.forward.choose(input_ids, max_req_tokens)
         for i, episode in enumerate(group.episodes):
             episode.ref_logprobs = ref_logprobs[i]
-        del ref_logits, ref_logprobs, input_ids
+        del ref_logprobs, input_ids
 
         # Calculate advantages and add to replay buffer
         advantages = await compute_advantages.compute.choose(group)
diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py
index 108414c88..b31a8a0c3 100644
--- a/src/forge/actors/reference_model.py
+++ b/src/forge/actors/reference_model.py
@@ -20,6 +20,7 @@
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
 from forge.controller import ForgeActor
+from forge.util import compute_logprobs
 
 
 @dataclass
@@ -73,7 +74,9 @@ async def setup(self):
         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
 
     @endpoint
-    async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
+    async def forward(
+        self, input_ids: torch.Tensor, max_req_tokens: int
+    ) -> torch.Tensor:
         model_parts = self.engine.model_parts
         parallel_dims = self.engine.parallel_dims
         input_ids = input_ids.to("cuda")
@@ -99,4 +102,6 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         logits = model_parts[0](input_ids)
         if isinstance(logits, DTensor):
             logits = logits.full_tensor()
-        return logits
+        logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
+
+        return logprobs
diff --git a/src/forge/util/__init__.py b/src/forge/util/__init__.py
index 5fb03b0f9..a1466c089 100644
--- a/src/forge/util/__init__.py
+++ b/src/forge/util/__init__.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 from .distributed import get_world_size_and_rank
 from .logging import get_logger, log_once, log_rank_zero
+from .math import compute_logprobs
 from .metric_logging import get_metric_logger
 
 __all__ = [
@@ -13,4 +14,5 @@
     "log_once",
     "log_rank_zero",
     "get_metric_logger",
+    "compute_logprobs",
 ]
diff --git a/src/forge/util/math.py b/src/forge/util/math.py
new file mode 100644
index 000000000..302ab0fa3
--- /dev/null
+++ b/src/forge/util/math.py
@@ -0,0 +1,29 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+
+def compute_logprobs(
+    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
+) -> torch.Tensor:
+    """Compute per-token log-probabilities of ``input_ids`` under ``logits``.
+
+    Args:
+        logits: Model output of shape ``(batch, seq_len, vocab)``; ``seq_len``
+            covers the prompt followed by the response tokens.
+        input_ids: Response token ids of shape ``(batch, response_len)``.
+        temperature: Softmax temperature applied to the logits.
+
+    Returns:
+        Tensor of shape ``(batch, response_len)`` holding the log-probability
+        of each response token.
+    """
+    context_length = logits.shape[1] - input_ids.shape[1]
+    logits = logits[:, context_length - 1 : -1]
+    logprobs = torch.log_softmax(logits / temperature, dim=-1).to(input_ids.device)
+    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
+    return logprobs