74 changes: 1 addition & 73 deletions apps/grpo/main.py
@@ -12,6 +12,7 @@
import torch
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.reference_actor import compute_sequence_logprobs, RefModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, spawn_service
@@ -21,37 +22,6 @@
from transformers import AutoModelForCausalLM, AutoTokenizer


def compute_sequence_logprobs(
model: torch.nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
requires_grad: bool = True,
) -> torch.Tensor:
context_manager = torch.enable_grad() if requires_grad else torch.no_grad()

with context_manager:
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits

# Apply log softmax to get log probabilities
log_probs = torch.log_softmax(logits, dim=-1)

# Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
shifted_input_ids = input_ids[:, 1:] # Remove first token
shifted_log_probs = log_probs[:, :-1, :] # Remove last logit

# Gather log probabilities for actual tokens
token_log_probs = torch.gather(
shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
).squeeze(-1)

# Sum log probabilities across sequence (masked by attention)
shifted_attention_mask = attention_mask[:, 1:]
sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)

return sequence_log_probs


@dataclass
class Group:
response: str # The response text for tokenization
@@ -269,48 +239,6 @@ async def __call__(self, groups: list[Group]) -> list[float]:
return advantages


class RefModel(ForgeActor):
def __init__(self, model_name, device: torch.device | None = None):
super().__init__()
self.model_name = model_name

# Set device
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device

# Initialize model and tokenizer
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
).to(self.device)

# Set model to eval mode for reference computations
self.model.eval()

self.logger.info(f"Model initialized on {self.device}")

@endpoint
async def forward(self, token_ids: list[int]) -> torch.Tensor:
# Use provided token_ids directly
input_ids = (
torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
)
# Create attention mask of all 1s since we have actual tokens (no padding)
attention_mask = torch.ones_like(input_ids).to(self.device)

# Compute log probabilities using shared utility function
sequence_log_probs = compute_sequence_logprobs(
self.model, input_ids, attention_mask, requires_grad=False
)

return (
sequence_log_probs.squeeze()
) # Remove batch dimension for single response


class DatasetActor(ForgeActor):
"""Actor wrapper for HuggingFace dataset to provide async interface."""

10 changes: 8 additions & 2 deletions apps/rl/main.py
@@ -19,14 +19,14 @@
from forge.cli.config import parse
from forge.controller import spawn_actors
from omegaconf import DictConfig

from torchtitan.config.job_config import Model

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


async def run(cfg: DictConfig):
trainer, buffer = await asyncio.gather(
trainer, buffer, reference = await asyncio.gather(
spawn_actors(
name="trainer",
actor_cls=RLTrainer,
@@ -40,18 +40,24 @@ async def run(cfg: DictConfig):
cfg=cfg.replay_buffer,
processes=cfg.replay_buffer.pop("processes"),
),
spawn_actors(
name="reference_actor",
actor_cls=ReferenceActor,
),
)
print("Actors spawned")

# Initialize everything
await asyncio.gather(
buffer.setup.call(),
trainer.setup.call(),
reference.setup.call(),
)
print("Setup done")

print("shutting down...")
await asyncio.gather(*[a.mesh.stop() for a in [trainer]])
await reference.cleanup.call()


@parse
204 changes: 204 additions & 0 deletions src/forge/actors/reference_actor.py
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import asyncio
import logging
import math
import os

from collections import deque
from collections.abc import Mapping
from dataclasses import dataclass, field, fields

from typing import Any

import torch

from forge.controller import ForgeActor
from monarch.actor import current_rank, current_size, endpoint
from omegaconf import DictConfig, OmegaConf
from torch import nn

from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.config.job_config import Comm, Model, Parallelism
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig
from transformers import AutoModelForCausalLM


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@dataclass
class ReferenceActor(ForgeActor):
model: Model = field(default_factory=Model)
# parallelism: Parallelism = field(default_factory=Parallelism)
# comm: Comm = field(default_factory=Comm)

# For RefModel
ref_model: ForgeActor | None = None
device: torch.device | None = None

# For processing
running: bool = False
queue: deque | None = None

def __post_init__(self):
"""Initializes config types and env variables.

        torchrun normally handles env variables, but we need to do it ourselves
in monarch for now.

"""
# Instantiate dict fields
for f in fields(self):
attr = getattr(self, f.name)
if isinstance(attr, Mapping):
setattr(self, f.name, f.type(**attr))
elif not isinstance(attr, f.type):
raise TypeError(
f"{f.name} should be a {f.type} type or a dict like object"
)

# This might need to be changed to a distributed friendly container
# We also don't have a traditional scheduler?
self.queue = deque()

self.rank = current_rank().rank
self.size = math.prod(current_size().values())

env = {
"RANK": str(self.rank),
"LOCAL_RANK": str(self.rank),
"LOCAL_WORLD_SIZE": str(self.size),
"GROUP_RANK": str(self.size),
"GROUP_WORLD_SIZE": str(self.size),
"ROLE_RANK": str(self.rank),
"ROLE_WORLD_SIZE": str(self.size),
"ROLE_NAME": "rank",
"WORLD_SIZE": str(self.size),
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
}
os.environ.update(env)

@endpoint
async def setup(self):
engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))

# Spawn the RefModel
self.ref_model = await spawn_service(
default_service_cfg,
RefModel,
model_name=self.model.name,
device=self.device,
)

# Kick off background processing
asyncio.create_task(self.run_processing.call())

@endpoint
async def forward(self, token_ids: list[int]) -> torch.Tensor:
"""
        Enqueue the tokens and await the response
"""
fut = asyncio.Future()
self.queue.append((token_ids, fut))
return await fut

@endpoint
async def run_processing(self):
"""
Simple loop to pass things along to the ref model
"""

# TODO: Consider creating a unified base class for this pattern
self.running = True

while self.running:
request, fut = self.queue.popleft()
model_output = await self.ref_model.forward(request)
fut.set_result(model_output)

@endpoint
async def cleanup(self) -> None:
self.running = False


class RefModel(ForgeActor):
def __init__(self, model_name, device: torch.device | None = None):
super().__init__()
self.model_name = model_name

# Set device
if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device

# Initialize model and tokenizer
self.model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
).to(self.device)

# Set model to eval mode for reference computations
self.model.eval()

self.logger.info(f"Model initialized on {self.device}")

@endpoint
async def forward(self, token_ids: list[int]) -> torch.Tensor:
# Use provided token_ids directly
input_ids = (
torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
)
# Create attention mask of all 1s since we have actual tokens (no padding)
attention_mask = torch.ones_like(input_ids).to(self.device)

# Compute log probabilities using shared utility function
sequence_log_probs = compute_sequence_logprobs(
self.model, input_ids, attention_mask, requires_grad=False
)

return (
sequence_log_probs.squeeze()
) # Remove batch dimension for single response


Contributor:

You should update this to match the one from main. This should probably go in the trainer.py file.

Contributor Author:

This was copy-pasted from main (at the time, Joe's).

compute_logprobs in this file is the implementation from #97.

def compute_sequence_logprobs(
model: torch.nn.Module,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
requires_grad: bool = True,
) -> torch.Tensor:
context_manager = torch.enable_grad() if requires_grad else torch.no_grad()

with context_manager:
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
logits = outputs.logits

# Apply log softmax to get log probabilities
log_probs = torch.log_softmax(logits, dim=-1)

# Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
shifted_input_ids = input_ids[:, 1:] # Remove first token
shifted_log_probs = log_probs[:, :-1, :] # Remove last logit

# Gather log probabilities for actual tokens
token_log_probs = torch.gather(
shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
).squeeze(-1)

# Sum log probabilities across sequence (masked by attention)
shifted_attention_mask = attention_mask[:, 1:]
sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)

return sequence_log_probs
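
For reference, here is a minimal sketch of how `compute_sequence_logprobs` could be exercised on its own, outside the actors. The model name `sshleifer/tiny-gpt2` and the sample strings are assumptions for illustration only, not anything this PR uses:

```python
# Illustrative usage sketch -- not part of this PR. Assumes torch and transformers
# are installed; compute_sequence_logprobs is the helper defined above.
import torch  # needed by compute_sequence_logprobs
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "sshleifer/tiny-gpt2"  # assumption: any small causal LM works here
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 style tokenizers have no pad token
model = AutoModelForCausalLM.from_pretrained(model_name)
model.eval()

batch = tokenizer(
    ["hello world", "the quick brown fox"], return_tensors="pt", padding=True
)

# requires_grad=False mirrors RefModel.forward: reference log-probs are fixed
# targets, so no autograd graph is built.
seq_logprobs = compute_sequence_logprobs(
    model, batch["input_ids"], batch["attention_mask"], requires_grad=False
)
print(seq_logprobs.shape)  # torch.Size([2]) -- one summed log-prob per sequence
```

The helper returns one summed log-probability per sequence, which is why `RefModel.forward` squeezes the batch dimension for its single-response case.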
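Separately, `ReferenceActor.forward` and `run_processing` implement an enqueue-and-await handoff: each caller parks an `asyncio.Future` in a deque, and the processing loop resolves it. A self-contained sketch of that pattern with no Forge/Monarch dependencies (the `RequestQueue` class and the `len`-based stand-in result are illustrative assumptions):

```python
import asyncio
from collections import deque


class RequestQueue:
    """Minimal stand-in for ReferenceActor's enqueue-and-await pattern."""

    def __init__(self):
        self.queue = deque()
        self.running = False

    async def forward(self, token_ids: list[int]) -> int:
        # Park a Future in the queue; the caller suspends until it is resolved.
        fut = asyncio.get_running_loop().create_future()
        self.queue.append((token_ids, fut))
        return await fut

    async def run_processing(self):
        # Drain the queue and resolve each caller's Future with a result.
        self.running = True
        while self.running:
            if not self.queue:
                await asyncio.sleep(0)  # yield so forward() can enqueue work
                continue
            token_ids, fut = self.queue.popleft()
            fut.set_result(len(token_ids))  # stand-in for the ref model forward pass


async def main():
    rq = RequestQueue()
    worker = asyncio.create_task(rq.run_processing())
    print(await rq.forward([1, 2, 3]))  # 3
    rq.running = False  # let the processing loop exit
    await worker


asyncio.run(main())
```

Unlike the PR's loop, the sketch checks for an empty deque before popping; `run_processing` as written pops unconditionally and would raise `IndexError` once the queue drains, so the check here is only to make the sketch run standalone.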