Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 3 additions & 13 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.util import compute_logprobs
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from omegaconf import DictConfig
Expand Down Expand Up @@ -128,16 +129,6 @@ def collate(batches: list[list[Episode]]):
return inputs, targets


def compute_logprobs(
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
context_length = logits.shape[1] - input_ids.shape[1]
logits = logits[:, context_length - 1 : -1]
logprobs = torch.log_softmax(logits / temperature, dim=-1).to(input_ids.device)
logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
return logprobs


def simple_grpo_loss(
logits: torch.Tensor,
response: torch.Tensor,
Expand Down Expand Up @@ -317,11 +308,10 @@ async def continuous_rollouts():
)

# Calculate reference logprobs
ref_logits = await ref_model.forward.choose(input_ids)
ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
ref_logprobs = await ref_model.forward.choose(input_ids, max_req_tokens)
for i, episode in enumerate(group.episodes):
episode.ref_logprobs = ref_logprobs[i]
del ref_logits, ref_logprobs, input_ids
del ref_logprobs, input_ids

# Calculate advantages and add to replay buffer
advantages = await compute_advantages.compute.choose(group)
Expand Down
9 changes: 7 additions & 2 deletions src/forge/actors/reference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.controller import ForgeActor
from forge.util import compute_logprobs


@dataclass
Expand Down Expand Up @@ -73,7 +74,9 @@ async def setup(self):
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))

@endpoint
async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
async def forward(
self, input_ids: torch.Tensor, max_req_tokens: int
) -> torch.Tensor:
model_parts = self.engine.model_parts
parallel_dims = self.engine.parallel_dims
input_ids = input_ids.to("cuda")
Expand All @@ -99,4 +102,6 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
logits = model_parts[0](input_ids)
if isinstance(logits, DTensor):
logits = logits.full_tensor()
return logits
logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])

return logprobs
2 changes: 2 additions & 0 deletions src/forge/util/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.
from .distributed import get_world_size_and_rank
from .logging import get_logger, log_once, log_rank_zero
from .math import compute_logprobs
from .metric_logging import get_metric_logger

__all__ = [
Expand All @@ -13,4 +14,5 @@
"log_once",
"log_rank_zero",
"get_metric_logger",
"compute_logprobs",
]
11 changes: 11 additions & 0 deletions src/forge/util/math.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import torch


def compute_logprobs(
    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
    """Gather the per-token log-probabilities of ``input_ids`` from model logits.

    ``logits`` covers the full (prompt + response) sequence, while
    ``input_ids`` holds only the response tokens whose log-probabilities are
    wanted. The prompt length is inferred from the difference of the two
    sequence dimensions.

    Args:
        logits: Tensor of shape ``(batch, full_seq_len, vocab_size)``.
        input_ids: Response token ids of shape ``(batch, response_len)``.
        temperature: Softmax temperature applied to the logits before
            normalization. Defaults to ``1.0`` (no scaling).

    Returns:
        Tensor of shape ``(batch, response_len)`` with the log-probability of
        each response token, on the same device as ``input_ids``.
    """
    prompt_len = logits.shape[1] - input_ids.shape[1]
    # The logit at position t predicts the token at position t + 1, so the
    # window aligned with the response tokens starts one step earlier and
    # drops the final position.
    shifted = logits[:, prompt_len - 1 : -1]
    log_dist = torch.log_softmax(shifted / temperature, dim=-1)
    log_dist = log_dist.to(input_ids.device)
    # Pick out the log-probability assigned to each actual response token.
    picked = log_dist.gather(2, input_ids.unsqueeze(-1))
    return picked.squeeze(-1)
Loading