Merged

Changes from 5 commits
101 changes: 24 additions & 77 deletions apps/grpo/main.py
@@ -13,49 +13,25 @@
import torch
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.reference_actor import (
compute_sequence_logprobs,
HuggingFaceRefModel,
RefModel,
TitanRefModel,
)
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from torchtitan.config.job_config import Model as TitanJobModelConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def compute_sequence_logprobs(
model: torch.nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
requires_grad: bool = True,
) -> torch.Tensor:
context_manager = torch.enable_grad() if requires_grad else torch.no_grad()

with context_manager:
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits

# Apply log softmax to get log probabilities
log_probs = torch.log_softmax(logits, dim=-1)

# Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
shifted_input_ids = input_ids[:, 1:] # Remove first token
shifted_log_probs = log_probs[:, :-1, :] # Remove last logit

# Gather log probabilities for actual tokens
token_log_probs = torch.gather(
shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
).squeeze(-1)

# Sum log probabilities across sequence (masked by attention)
shifted_attention_mask = attention_mask[:, 1:]
sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)

return sequence_log_probs
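# Illustrative shape walkthrough (annotation only, not part of this diff):
# position t's logits predict token t+1, hence the shift before the gather.
#   input_ids:         [B, T]     -> shifted_input_ids: [B, T-1]
#   logits:            [B, T, V]  -> shifted_log_probs: [B, T-1, V]
#   gather over vocab  -> token_log_probs:    [B, T-1]
#   masked sum over t  -> sequence_log_probs: [B]
# Hypothetical call, assuming a HF causal LM and an unpadded batch:
#   logp = compute_sequence_logprobs(model, input_ids, attention_mask,
#                                    requires_grad=False)  # shape [B]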


@dataclass
class Group:
response: str # The response text for tokenization
@@ -273,48 +249,6 @@ async def __call__(self, groups: list[Group]) -> list[float]:
return advantages


class RefModel(ForgeActor):
def __init__(self, model_name, device: torch.device | None = None):
super().__init__()
self.model_name = model_name

# Set device
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device

# Initialize model and tokenizer
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
).to(self.device)

# Set model to eval mode for reference computations
self.model.eval()

self.logger.info(f"Model initialized on {self.device}")

@endpoint
async def forward(self, token_ids: list[int]) -> torch.Tensor:
# Use provided token_ids directly
input_ids = (
torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
)
# Create attention mask of all 1s since we have actual tokens (no padding)
attention_mask = torch.ones_like(input_ids).to(self.device)

# Compute log probabilities using shared utility function
sequence_log_probs = compute_sequence_logprobs(
self.model, input_ids, attention_mask, requires_grad=False
)

return (
sequence_log_probs.squeeze()
) # Remove batch dimension for single response


class DatasetActor(ForgeActor):
"""Actor wrapper for HuggingFace dataset to provide async interface."""

@@ -346,6 +280,7 @@ async def main():
"""Main GRPO training loop with rollout and training processes."""
group_size = 1
model = "Qwen/Qwen3-1.7B"
# model = "meta-llama/Meta-Llama-3.1-8B"

# ---- Setup WandB Logger ---- #
logger = get_metric_logger(
Expand Down Expand Up @@ -401,10 +336,16 @@ async def main():
gamma=0.99,
lambda_=0.95,
),
# spawn_service(
# ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
# RefModel,
# model_name=model,
# ),
# GOAL: Swap this in and everything should just "work"
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
RefModel,
model_name=model,
TitanRefModel,
# model=TitanJobModelConfig(name=model),
),
spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
@@ -431,9 +372,15 @@ async def continuous_rollouts():
target=target,
policy_version=version,
)
actions = await policy.generate.choose(prompt)
responses = await policy.generate.choose(prompt)
actions = responses.outputs
for action in actions:
ref_logprobs = await ref_model.forward.choose(action.token_ids)
# ref_logprobs = await ref_model.forward.choose(action.token_ids)
request_tokens = responses.prompt_token_ids
response_tokens = action.token_ids
ref_logprobs = await ref_model.forward.choose(
request=request_tokens, response=response_tokens
)
reward = await reward_actor.evaluate_response.choose(
prompt=prompt, response=action.text, target=target
)
126 changes: 126 additions & 0 deletions apps/grpo/test.py
@@ -0,0 +1,126 @@
import asyncio

from datasets import load_dataset

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.reference_actor import HuggingFaceRefModel, RefModel, TitanRefModel

from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from monarch.actor import endpoint
from torchtitan.config.job_config import Model


class DatasetActor(ForgeActor):
"""Actor wrapper for HuggingFace dataset to provide async interface."""

def __init__(
self, path: str, config_name: str, split: str, streaming: bool, **kwargs
):
super().__init__()

def gsm8k_to_messages(sample):
question = sample["question"]
full_answer: str = sample["answer"]
answer = full_answer.split("#### ")[1]
return {"question": question, "answer": answer}
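# GSM8K stores the final answer after "#### ", e.g. an answer string ending
# in "...so 5 + 3 = 8 apples. #### 8" yields answer == "8".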

ds = load_dataset(path, config_name, split=split, streaming=streaming)
ds = ds.map(gsm8k_to_messages)
ds = ds.shuffle()
self._iterator = iter(ds)

@endpoint
async def __next__(self) -> dict[str, str] | None:
return next(self._iterator)


# Sandbox; will be removed
async def main():
group_size = 1

# For Torchtitan
model = "Qwen/Qwen3-1.7B"
# model = "meta-llama/Meta-Llama-3.1-8B"

# Spawn Reference "Agents"

# # Joe
# hf_model = await spawn_service(
# ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
# HuggingFaceRefModel,
# model_name=model,
# )

# # Philip
# hf_model = await spawn_service(
# ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
# RefModel,
# model_name=model,
# )

titan_model = await spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
TitanRefModel,
)

# Spawn Policy for getting responses
policy = await spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Policy,
config=PolicyConfig(
worker_params=WorkerConfig(model=model),
sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
),
)

# Load Dataset
dataloader = await spawn_service(
ServiceConfig(procs_per_replica=1, num_replicas=1),
DatasetActor,
path="openai/gsm8k",
config_name="main",
split="train",
streaming=True,
)
sample = await dataloader.__next__.choose()
prompt, target = sample["question"], sample["answer"]
print("Sample: ", sample)

# Generate output from policy, then pass to reference agents
responses = await policy.generate.choose(prompt)
actions = responses.outputs
for action in actions:
request_tokens = responses.prompt_token_ids
response_tokens = action.token_ids

print("request_tokens: ", request_tokens)
print("response_tokens: ", response_tokens)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# print("HuggingFace Results")
# hf_logprobs = await hf_model.forward.choose(
# request=request_tokens, response=response_tokens
# )
# print("HF logprob: ", hf_logprobs)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

await asyncio.gather(
shutdown_service(policy),
shutdown_service(dataloader),
# shutdown_service(hf_model),
)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
print("Titan Results")
titan_logprobs: float = await titan_model.forward.choose(
request=request_tokens, response=response_tokens
)
print("Titan logprob: ", titan_logprobs)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# await shutdown_service(titan_model)


if __name__ == "__main__":
asyncio.run(main())
2 changes: 1 addition & 1 deletion apps/rl/main.py
@@ -19,7 +19,7 @@
from forge.cli.config import parse
from forge.controller import spawn_actors
from omegaconf import DictConfig

from torchtitan.config.job_config import Model

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
6 changes: 5 additions & 1 deletion src/forge/actors/__init__.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"]
__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"]


def __getattr__(name):
@@ -24,5 +24,9 @@ def __getattr__(name):
from .replay_buffer import ReplayBuffer

return ReplayBuffer
elif name == "TitanRefModel":
Member: 🙃
from .reference_actor import TitanRefModel

return TitanRefModel
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
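For context, this lazy export relies on the module-level __getattr__ hook (PEP 562): the reference_actor import only runs the first time the attribute is requested, keeping `import forge.actors` cheap. A minimal sketch of a caller hitting this path (hypothetical usage, not part of the diff):

    from forge.actors import TitanRefModel  # resolved lazily via the elif branch above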
15 changes: 8 additions & 7 deletions src/forge/actors/policy.py
@@ -13,6 +13,12 @@
from typing import Dict, List

import torch

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM
@@ -37,12 +43,6 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig


logger = logging.getLogger(__name__)

@@ -310,7 +310,8 @@ async def run(self):
for request_output in processed_outputs.request_outputs:
if request_output.finished:
_, fut = self.requests.pop(request_output.request_id)
fut.set_result(request_output.outputs)
# fut.set_result(request_output.outputs)
fut.set_result(request_output)
Contributor (author): Adopted from #97

Member: Why this instead of raw outputs?

Contributor (author): Pragmatically, less merge conflict with Philip's PR. I don't have a strong preference, but it does make the output self-contained, which is nice when we need to pass the results around.
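To illustrate the trade-off (a sketch using vLLM's standard RequestOutput fields, not code from this PR): resolving the future with the full RequestOutput keeps prompt and completion tokens together, so results can be passed around whole:

    responses = await policy.generate.choose(prompt)   # vllm RequestOutput
    prompt_ids = responses.prompt_token_ids            # prompt tokens travel with the result
    completion_ids = responses.outputs[0].token_ids    # completion tokens for the first sample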


@endpoint
async def update_weights(self):