89 changes: 14 additions & 75 deletions apps/grpo/main.py
@@ -13,49 +13,24 @@
import torch
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.reference_actor import (
    compute_sequence_logprobs,
    HuggingFaceRefModel,
    TitanRefModel,
)
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from torchtitan.config.job_config import Model as TitanJobModelConfig
from transformers import AutoModelForCausalLM, AutoTokenizer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def compute_sequence_logprobs(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    requires_grad: bool = True,
) -> torch.Tensor:
    context_manager = torch.enable_grad() if requires_grad else torch.no_grad()

    with context_manager:
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # Apply log softmax to get log probabilities
        log_probs = torch.log_softmax(logits, dim=-1)

        # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
        shifted_input_ids = input_ids[:, 1:]  # Remove first token
        shifted_log_probs = log_probs[:, :-1, :]  # Remove last logit

        # Gather log probabilities for actual tokens
        token_log_probs = torch.gather(
            shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
        ).squeeze(-1)

        # Sum log probabilities across sequence (masked by attention)
        shifted_attention_mask = attention_mask[:, 1:]
        sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)

    return sequence_log_probs
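
A minimal sketch of what this helper computes, using a toy stand-in model (the TinyLM below is illustrative, not part of the PR): the result is one summed, attention-masked log-probability per sequence.

import torch
from types import SimpleNamespace

class TinyLM(torch.nn.Module):
    """Toy causal LM: fakes [batch, seq, vocab] logits via an embedding."""

    def __init__(self, vocab_size: int = 10):
        super().__init__()
        self.proj = torch.nn.Embedding(vocab_size, vocab_size)

    def forward(self, input_ids, attention_mask=None):
        return SimpleNamespace(logits=self.proj(input_ids))

ids = torch.tensor([[1, 2, 3, 4]])
mask = torch.ones_like(ids)
lp = compute_sequence_logprobs(TinyLM(), ids, mask, requires_grad=False)
print(lp.shape)  # torch.Size([1]) -- one scalar log-prob per sequence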


@dataclass
class Group:
    response: str  # The response text for tokenization
@@ -273,48 +248,6 @@ async def __call__(self, groups: list[Group]) -> list[float]:
        return advantages


class RefModel(ForgeActor):
    def __init__(self, model_name, device: torch.device | None = None):
        super().__init__()
        self.model_name = model_name

        # Set device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        # Initialize model and tokenizer
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).to(self.device)

        # Set model to eval mode for reference computations
        self.model.eval()

        self.logger.info(f"Model initialized on {self.device}")

    @endpoint
    async def forward(self, token_ids: list[int]) -> torch.Tensor:
        # Use provided token_ids directly
        input_ids = (
            torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
        )
        # Create attention mask of all 1s since we have actual tokens (no padding)
        attention_mask = torch.ones_like(input_ids).to(self.device)

        # Compute log probabilities using shared utility function
        sequence_log_probs = compute_sequence_logprobs(
            self.model, input_ids, attention_mask, requires_grad=False
        )

        return (
            sequence_log_probs.squeeze()
        )  # Remove batch dimension for single response


class DatasetActor(ForgeActor):
    """Actor wrapper for HuggingFace dataset to provide async interface."""

@@ -401,10 +334,16 @@ async def main():
            gamma=0.99,
            lambda_=0.95,
        ),
        # spawn_service(
        #     ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
        #     RefModel,
        #     model_name=model,
        # ),
        # GOAL: Swap this in and everything should just "work"
        spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
            RefModel,
            model_name=model,
            TitanRefModel,
            model=TitanJobModelConfig(name=model),
        ),
        spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1),
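The commented-out RefModel service and the "GOAL" note above suggest the reference models are meant to be drop-in replacements for each other. A hedged sketch of the shared endpoint surface that would make the swap work (the RefModelLike name is an assumption, not in the PR):

from typing import Protocol

import torch

class RefModelLike(Protocol):
    # Hypothetical contract, not part of the diff: any reference model
    # exposing this endpoint can be passed to spawn_service unchanged.
    async def forward(self, token_ids: list[int]) -> torch.Tensor:
        """Summed sequence log-probability for the given token ids."""
        ...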
101 changes: 101 additions & 0 deletions apps/grpo/test.py
@@ -0,0 +1,101 @@
import asyncio

from datasets import load_dataset

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.reference_actor import HuggingFaceRefModel, TitanRefModel

from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from monarch.actor import endpoint


class DatasetActor(ForgeActor):
    """Actor wrapper for HuggingFace dataset to provide async interface."""

    def __init__(
        self, path: str, config_name: str, split: str, streaming: bool, **kwargs
    ):
        super().__init__()

        def gsm8k_to_messages(sample):
            question = sample["question"]
            full_answer: str = sample["answer"]
            answer = full_answer.split("#### ")[1]
            return {"question": question, "answer": answer}

        ds = load_dataset(path, config_name, split=split, streaming=streaming)
        ds = ds.map(gsm8k_to_messages)
        ds = ds.shuffle()
        self._iterator = iter(ds)

    @endpoint
    async def __next__(self) -> dict[str, str] | None:
        return next(self._iterator)


# Sandbox; will be removed
async def main():
    group_size = 1

    # For Torchtitan
    model = "Qwen/Qwen3-1.7B"

    # Spawn Reference "Agents"
    hf_model = await spawn_service(
        ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
        HuggingFaceRefModel,
        model_name=model,
    )
    titan_model = await spawn_service(
        ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
        TitanRefModel,
    )

    # Spawn Policy for getting responses
    policy = await spawn_service(
        ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
        Policy,
        config=PolicyConfig(
            worker_params=WorkerConfig(model=model),
            sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
        ),
    )

    # Load Dataset
    dataloader = await spawn_service(
        ServiceConfig(procs_per_replica=1, num_replicas=1),
        DatasetActor,
        path="openai/gsm8k",
        config_name="main",
        split="train",
        streaming=True,
    )
    sample = await dataloader.__next__.choose()
    prompt, target = sample["question"], sample["answer"]
    print("Sample: ", sample)

    # Generate output from policy, then pass to reference agents
    actions = await policy.generate.choose(prompt)
    for action in actions:
        print("Generated Action tok_ids: ", action.token_ids)

        print("HuggingFace Results")
        hf_logprobs: float = await hf_model.forward.choose(action.token_ids)
        print("HF logprob: ", hf_logprobs)

        print("Titan Results")
        titan_logprobs: float = await titan_model.forward.choose(action.token_ids)
        print("Titan logprob: ", titan_logprobs)
        # TODO: Update forward to convert probs (logits) to logprobs
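        # Hedged sketch of the conversion the TODO describes (assuming the
        # Titan endpoint currently returns raw logits of shape
        # [seq_len, vocab_size] and `ids` holds the realized token ids):
        #   log_probs = torch.log_softmax(logits, dim=-1)
        #   token_lp = log_probs[:-1].gather(-1, ids[1:].unsqueeze(-1)).squeeze(-1)
        #   sequence_lp = token_lp.sum()  # would match the HF reference output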

    await asyncio.gather(
        shutdown_service(policy),
        shutdown_service(dataloader),
        shutdown_service(hf_model),
        shutdown_service(titan_model),
    )


if __name__ == "__main__":
    asyncio.run(main())
2 changes: 1 addition & 1 deletion apps/rl/main.py
@@ -19,7 +19,7 @@
from forge.cli.config import parse
from forge.controller import spawn_actors
from omegaconf import DictConfig

from torchtitan.config.job_config import Model

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
6 changes: 5 additions & 1 deletion src/forge/actors/__init__.py
@@ -4,7 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"]
__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"]


def __getattr__(name):
@@ -24,5 +24,9 @@ def __getattr__(name):
        from .replay_buffer import ReplayBuffer

        return ReplayBuffer
    elif name == "TitanRefModel":
Review comment (Member): 🙃
        from .reference_actor import TitanRefModel

        return TitanRefModel
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")
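
This module-level __getattr__ (PEP 562) keeps `import forge.actors` cheap: the actor modules and their heavy dependencies load only on first attribute access. A small usage sketch:

from forge import actors

# First access triggers __getattr__, which performs the lazy import of
# forge.actors.reference_actor (and, transitively, its torch/titan deps).
ref_cls = actors.TitanRefModel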