Merged
91 changes: 13 additions & 78 deletions apps/grpo/main.py
@@ -13,49 +13,20 @@
import torch
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.reference_actor import compute_sequence_logprobs, TitanRefModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from torchtitan.config.job_config import Model as TitanJobModelConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def compute_sequence_logprobs(
model: torch.nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
requires_grad: bool = True,
) -> torch.Tensor:
context_manager = torch.enable_grad() if requires_grad else torch.no_grad()

with context_manager:
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits

# Apply log softmax to get log probabilities
log_probs = torch.log_softmax(logits, dim=-1)

# Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
shifted_input_ids = input_ids[:, 1:] # Remove first token
shifted_log_probs = log_probs[:, :-1, :] # Remove last logit

# Gather log probabilities for actual tokens
token_log_probs = torch.gather(
shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
).squeeze(-1)

# Sum log probabilities across sequence (masked by attention)
shifted_attention_mask = attention_mask[:, 1:]
sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)

return sequence_log_probs


@dataclass
class Group:
response: str # The response text for tokenization
@@ -273,48 +244,6 @@ async def __call__(self, groups: list[Group]) -> list[float]:
return advantages


class RefModel(ForgeActor):
def __init__(self, model_name, device: torch.device | None = None):
super().__init__()
self.model_name = model_name

# Set device
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device

# Initialize model and tokenizer
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
).to(self.device)

# Set model to eval mode for reference computations
self.model.eval()

self.logger.info(f"Model initialized on {self.device}")

@endpoint
async def forward(self, token_ids: list[int]) -> torch.Tensor:
# Use provided token_ids directly
input_ids = (
torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
)
# Create attention mask of all 1s since we have actual tokens (no padding)
attention_mask = torch.ones_like(input_ids).to(self.device)

# Compute log probabilities using shared utility function
sequence_log_probs = compute_sequence_logprobs(
self.model, input_ids, attention_mask, requires_grad=False
)

return (
sequence_log_probs.squeeze()
) # Remove batch dimension for single response


class DatasetActor(ForgeActor):
"""Actor wrapper for HuggingFace dataset to provide async interface."""

@@ -345,7 +274,8 @@ async def __next__(self) -> dict[str, str] | None:
async def main():
"""Main GRPO training loop with rollout and training processes."""
group_size = 1
model = "Qwen/Qwen3-1.7B"
model = "Qwen/Qwen3-0.6B"
titan_model = TitanJobModelConfig(name="qwen3", flavor="0.6B")

# ---- Setup WandB Logger ---- #
logger = get_metric_logger(
@@ -403,8 +333,8 @@ async def main():
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
RefModel,
model_name=model,
TitanRefModel,
model=titan_model,
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
@@ -431,9 +361,14 @@ async def continuous_rollouts():
target=target,
policy_version=version,
)
actions = await policy.generate.choose(prompt)
responses = await policy.generate.choose(prompt)
actions = responses.outputs
for action in actions:
ref_logprobs = await ref_model.forward.choose(action.token_ids)
request_tokens = responses.prompt_token_ids
response_tokens = action.token_ids
ref_logprobs = await ref_model.forward.choose(
request=request_tokens, response=response_tokens
)
reward = await reward_actor.evaluate_response.choose(
prompt=prompt, response=action.text, target=target
)
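
The rollout loop above now hands the reference model the prompt and completion token ids separately instead of a single concatenated sequence. Below is a minimal sketch of why that split is useful, assuming the reference model scores only the response tokens conditioned on the prompt — an assumption for illustration; `TitanRefModel`'s actual implementation lives in `forge/actors/reference_actor.py` and is not part of this diff. It mirrors the deleted `compute_sequence_logprobs` helper above, restricted to the response slice:

```python
import torch


def response_logprobs(
    model: torch.nn.Module,
    request: list[int],   # prompt token ids
    response: list[int],  # completion token ids
) -> torch.Tensor:
    """Sum of log-probs of `response` tokens given `request` (illustrative sketch)."""
    input_ids = torch.tensor([request + response], dtype=torch.long)
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits  # [1, seq_len, vocab]
    log_probs = torch.log_softmax(logits, dim=-1)
    # Logits at position i predict token i + 1, so the response tokens
    # (positions len(request) .. end) are predicted by the preceding positions.
    start = len(request)
    pred = log_probs[:, start - 1 : -1, :]   # [1, len(response), vocab]
    target = input_ids[:, start:]            # [1, len(response)]
    token_log_probs = torch.gather(pred, -1, target.unsqueeze(-1)).squeeze(-1)
    return token_log_probs.sum(dim=-1)       # [1]
```
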
1 change: 0 additions & 1 deletion apps/rl/main.py
@@ -20,7 +20,6 @@
from forge.controller import spawn_actors
from omegaconf import DictConfig


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

6 changes: 5 additions & 1 deletion src/forge/actors/__init__.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"]
__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"]


def __getattr__(name):
@@ -24,5 +24,9 @@ def __getattr__(name):
from .replay_buffer import ReplayBuffer

return ReplayBuffer
elif name == "TitanRefModel":
Member commented: 🙃

from .reference_actor import TitanRefModel

return TitanRefModel
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
14 changes: 7 additions & 7 deletions src/forge/actors/policy.py
@@ -13,6 +13,12 @@
from typing import Dict, List

import torch

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM
@@ -37,12 +43,6 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig


logger = logging.getLogger(__name__)

@@ -310,7 +310,7 @@ async def run(self):
for request_output in processed_outputs.request_outputs:
if request_output.finished:
_, fut = self.requests.pop(request_output.request_id)
fut.set_result(request_output.outputs)
fut.set_result(request_output)
Contributor Author commented: Adopted from #97

Member commented: Why this instead of raw outputs?

Contributor Author commented: Pragmatically: Less merge conflict with Philip's PR

I don't have strong preference, but it does make the output self contained which is nice when we need to pass the results around
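
A minimal, self-contained sketch of the trade-off discussed above, using simplified stand-ins for vLLM's `RequestOutput`/`CompletionOutput` (the dataclasses below are illustrative, not the real vLLM types): resolving the future with the whole request output keeps prompt and completion tokens together, which is what the rollout loop in `apps/grpo/main.py` relies on.

```python
import asyncio
from dataclasses import dataclass, field


# Simplified stand-ins for vLLM's RequestOutput / CompletionOutput, for illustration only.
@dataclass
class CompletionOutput:
    text: str
    token_ids: list[int]


@dataclass
class RequestOutput:
    prompt_token_ids: list[int]
    outputs: list[CompletionOutput] = field(default_factory=list)


async def demo() -> None:
    fut: asyncio.Future = asyncio.get_running_loop().create_future()

    # Producer side (Policy.run): resolve with the whole request output,
    # not just request_output.outputs.
    request_output = RequestOutput(
        prompt_token_ids=[1, 2, 3],
        outputs=[CompletionOutput(text="4 5", token_ids=[4, 5])],
    )
    fut.set_result(request_output)

    # Consumer side (the rollout loop): one awaited object carries both the
    # prompt tokens and every completion's tokens.
    result = await fut
    print(result.prompt_token_ids)      # [1, 2, 3]
    print(result.outputs[0].token_ids)  # [4, 5]


asyncio.run(demo())
```

Callers that only want the completions can still read `.outputs` off the awaited result, so nothing is lost by returning the richer object.
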


@endpoint
async def update_weights(self):