Adds TitanRefModel in place of HF-based Reference Model #94
Changes from 1 commit (c3f2c8c)
@@ -0,0 +1,204 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import logging
import math
import os

from collections import deque
from collections.abc import Mapping
from dataclasses import dataclass, field, fields

from typing import Any

import torch

from forge.controller import ForgeActor
from monarch.actor import current_rank, current_size, endpoint
from omegaconf import DictConfig, OmegaConf
from torch import nn

from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.config.job_config import Comm, Model, Parallelism
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig
from transformers import AutoModelForCausalLM


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@dataclass
class ReferenceActor(ForgeActor):
    model: Model = field(default_factory=Model)
    # parallelism: Parallelism = field(default_factory=Parallelism)
    # comm: Comm = field(default_factory=Comm)

    # For RefModel
    ref_model: ForgeActor | None = None
    device: torch.device | None = None

    # For processing
    running: bool = False
    queue: deque | None = None

    def __post_init__(self):
        """Initializes config types and env variables.

        torchrun normally handles env variables, but we need to do it
        ourselves in monarch for now.
        """
        # Instantiate dict fields as their declared dataclass types
        for f in fields(self):
            attr = getattr(self, f.name)
            if isinstance(attr, Mapping):
                setattr(self, f.name, f.type(**attr))
            elif not isinstance(attr, f.type):
                raise TypeError(
                    f"{f.name} should be a {f.type} type or a dict like object"
                )

        # This might need to be changed to a distributed-friendly container.
        # We also don't have a traditional scheduler yet.
        self.queue = deque()

        self.rank = current_rank().rank
        self.size = math.prod(current_size().values())

        env = {
            "RANK": str(self.rank),
            "LOCAL_RANK": str(self.rank),
            "LOCAL_WORLD_SIZE": str(self.size),
            "GROUP_RANK": str(self.size),
            "GROUP_WORLD_SIZE": str(self.size),
            "ROLE_RANK": str(self.rank),
            "ROLE_WORLD_SIZE": str(self.size),
            "ROLE_NAME": "rank",
            "WORLD_SIZE": str(self.size),
            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
        }
        os.environ.update(env)
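
[Aside] As the docstring notes, these are the variables torchrun would normally export. A minimal standalone check, not part of this file, that torch.distributed can consume them through the env:// rendezvous (single-process stand-in values are assumptions):

    import os

    import torch.distributed as dist

    # Single-process stand-ins for what __post_init__ derives from monarch.
    os.environ.setdefault("RANK", "0")
    os.environ.setdefault("WORLD_SIZE", "1")
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")

    # init_method="env://" reads RANK and WORLD_SIZE from the environment,
    # i.e. the same variables populated above.
    dist.init_process_group(backend="gloo", init_method="env://")
    print(dist.get_rank(), dist.get_world_size())
    dist.destroy_process_group()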

    @endpoint
    async def setup(self):
        # Only forward the torchtitan config sections; the actor-local fields
        # (ref_model, device, running, queue) are not ForgeJobConfig fields.
        non_config_fields = {"ref_model", "device", "running", "queue"}
        engine_config = {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if f.name not in non_config_fields
        }
        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))

        # Spawn the RefModel service. Note: spawn_service and
        # default_service_cfg are assumed to come from forge's service
        # utilities; they are not imported in this file.
        self.ref_model = await spawn_service(
            default_service_cfg,
            RefModel,
            model_name=self.model.name,
            device=self.device,
        )

        # Kick off background processing
        asyncio.create_task(self.run_processing.call())

    @endpoint
    async def forward(self, token_ids: list[int]) -> torch.Tensor:
        """
        Enqueue the tokens and await the response.
        """
        fut = asyncio.Future()
        self.queue.append((token_ids, fut))
        return await fut
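
[Aside] The hand-off above can be seen in isolation: the caller parks on a Future that a separate consumer resolves. A minimal self-contained demo (all names here are illustrative, not from the PR):

    import asyncio
    from collections import deque

    queue = deque()

    async def producer():
        fut = asyncio.get_running_loop().create_future()
        queue.append(("hello", fut))
        return await fut  # parked until the consumer calls set_result

    async def consumer():
        while not queue:
            await asyncio.sleep(0)  # yield until a request is queued
        req, fut = queue.popleft()
        fut.set_result(req.upper())  # stand-in for the model call

    async def main():
        result, _ = await asyncio.gather(producer(), consumer())
        print(result)  # HELLO

    asyncio.run(main())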

    @endpoint
    async def run_processing(self):
        """
        Simple loop that passes queued requests along to the ref model.
        """

        # TODO: Consider creating a unified base class for this pattern
        # (a hypothetical sketch follows after this class)
        self.running = True

        while self.running:
            # Yield to the event loop while the queue is empty; popleft()
            # on an empty deque would raise IndexError.
            if not self.queue:
                await asyncio.sleep(0.01)
                continue
            request, fut = self.queue.popleft()
            model_output = await self.ref_model.forward(request)
            fut.set_result(model_output)

    @endpoint
    async def cleanup(self) -> None:
        self.running = False

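[Aside] The TODO in run_processing suggests a unified base class for this enqueue-and-drain pattern. A hypothetical sketch of what that could look like (not part of this PR; all names are illustrative):

    import asyncio
    from collections import deque

    class QueuedProcessor:
        """Owns the queue and drain loop; subclasses implement process_one()."""

        def __init__(self):
            self.queue: deque = deque()
            self.running = False

        async def submit(self, request):
            fut = asyncio.get_running_loop().create_future()
            self.queue.append((request, fut))
            return await fut

        async def run(self):
            self.running = True
            while self.running:
                if not self.queue:
                    await asyncio.sleep(0.01)
                    continue
                request, fut = self.queue.popleft()
                fut.set_result(await self.process_one(request))

        async def process_one(self, request):
            raise NotImplementedError
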
class RefModel(ForgeActor):
    def __init__(self, model_name, device: torch.device | None = None):
        super().__init__()
        self.model_name = model_name

        # Set device
        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        # Initialize the model
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        ).to(self.device)

        # Set model to eval mode for reference computations
        self.model.eval()

        self.logger.info(f"Model initialized on {self.device}")

    @endpoint
    async def forward(self, token_ids: list[int]) -> torch.Tensor:
        # Use provided token_ids directly
        input_ids = (
            torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
        )
        # Attention mask of all 1s since these are actual tokens (no padding);
        # ones_like already inherits the device of input_ids
        attention_mask = torch.ones_like(input_ids)

        # Compute log probabilities using the shared utility function below
        sequence_log_probs = compute_sequence_logprobs(
            self.model, input_ids, attention_mask, requires_grad=False
        )

        # Remove the batch dimension for a single response
        return sequence_log_probs.squeeze()


Review thread on compute_sequence_logprobs:
Reviewer: "You should update this to match the one from main. This should probably go in the trainer.py file."
Author: "This was copy paste from main (at the time Joe's)."

def compute_sequence_logprobs(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    requires_grad: bool = True,
) -> torch.Tensor:
    context_manager = torch.enable_grad() if requires_grad else torch.no_grad()

    with context_manager:
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        logits = outputs.logits

        # Apply log softmax to get log probabilities
        log_probs = torch.log_softmax(logits, dim=-1)

        # Extract log probabilities for the actual tokens
        # (excluding the first token for next-token prediction)
        shifted_input_ids = input_ids[:, 1:]  # Remove first token
        shifted_log_probs = log_probs[:, :-1, :]  # Remove last logit

        # Gather log probabilities for actual tokens
        token_log_probs = torch.gather(
            shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
        ).squeeze(-1)

        # Sum log probabilities across sequence (masked by attention)
        shifted_attention_mask = attention_mask[:, 1:]
        sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)

    return sequence_log_probs
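
[Aside] To make the shift-and-gather step concrete, here is a small self-contained check (toy values, not from the PR):

    import torch

    torch.manual_seed(0)
    logits = torch.randn(1, 4, 5)               # batch=1, seq_len=4, vocab=5
    input_ids = torch.tensor([[1, 3, 0, 2]])

    log_probs = torch.log_softmax(logits, dim=-1)
    # Token t is predicted from position t-1: drop the first target and the
    # last prediction before pairing them up.
    targets = input_ids[:, 1:]                  # tokens 3, 0, 2
    preds = log_probs[:, :-1, :]                # predictions at positions 0..2
    token_lp = preds.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    print(token_lp.sum(dim=-1))                 # summed sequence log-probability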