Skip to content

Commit 57a3de3

Browse files
authored
Adds TitanRefModel in place of HF based Reference Model (meta-pytorch#94)
* Wrapping RefModel * Debugging Cuda issue * Still debugging * Commit prior to cleanup incase rebase gets rough * Initial clean up * More cleaning before mega rebase
1 parent 223c58a commit 57a3de3

File tree

5 files changed

+376
-87
lines changed

5 files changed

+376
-87
lines changed

apps/grpo/main.py

Lines changed: 13 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -13,49 +13,20 @@
1313
import torch
1414
from datasets import load_dataset
1515
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
16+
from forge.actors.reference_actor import compute_sequence_logprobs, TitanRefModel
1617
from forge.actors.replay_buffer import ReplayBuffer
1718
from forge.controller.actor import ForgeActor
1819
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
1920
from forge.data.rewards import MathReward, ThinkingReward
2021
from forge.util.metric_logging import get_metric_logger
2122
from monarch.actor import endpoint
23+
from torchtitan.config.job_config import Model as TitanJobModelConfig
2224
from transformers import AutoModelForCausalLM, AutoTokenizer
2325

2426
logger = logging.getLogger(__name__)
2527
logger.setLevel(logging.DEBUG)
2628

2729

28-
def compute_sequence_logprobs(
29-
model: torch.nn.Module,
30-
input_ids: torch.Tensor,
31-
attention_mask: torch.Tensor,
32-
requires_grad: bool = True,
33-
) -> torch.Tensor:
34-
context_manager = torch.enable_grad() if requires_grad else torch.no_grad()
35-
36-
with context_manager:
37-
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
38-
logits = outputs.logits
39-
40-
# Apply log softmax to get log probabilities
41-
log_probs = torch.log_softmax(logits, dim=-1)
42-
43-
# Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
44-
shifted_input_ids = input_ids[:, 1:] # Remove first token
45-
shifted_log_probs = log_probs[:, :-1, :] # Remove last logit
46-
47-
# Gather log probabilities for actual tokens
48-
token_log_probs = torch.gather(
49-
shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
50-
).squeeze(-1)
51-
52-
# Sum log probabilities across sequence (masked by attention)
53-
shifted_attention_mask = attention_mask[:, 1:]
54-
sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)
55-
56-
return sequence_log_probs
57-
58-
5930
@dataclass
6031
class Group:
6132
response: str # The response text for tokenization
@@ -273,48 +244,6 @@ async def __call__(self, groups: list[Group]) -> list[float]:
273244
return advantages
274245

275246

276-
class RefModel(ForgeActor):
277-
def __init__(self, model_name, device: torch.device | None = None):
278-
super().__init__()
279-
self.model_name = model_name
280-
281-
# Set device
282-
if device is None:
283-
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
284-
else:
285-
self.device = device
286-
287-
# Initialize model and tokenizer
288-
self.model = AutoModelForCausalLM.from_pretrained(
289-
model_name,
290-
torch_dtype=torch.bfloat16,
291-
trust_remote_code=True,
292-
).to(self.device)
293-
294-
# Set model to eval mode for reference computations
295-
self.model.eval()
296-
297-
self.logger.info(f"Model initialized on {self.device}")
298-
299-
@endpoint
300-
async def forward(self, token_ids: list[int]) -> torch.Tensor:
301-
# Use provided token_ids directly
302-
input_ids = (
303-
torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
304-
)
305-
# Create attention mask of all 1s since we have actual tokens (no padding)
306-
attention_mask = torch.ones_like(input_ids).to(self.device)
307-
308-
# Compute log probabilities using shared utility function
309-
sequence_log_probs = compute_sequence_logprobs(
310-
self.model, input_ids, attention_mask, requires_grad=False
311-
)
312-
313-
return (
314-
sequence_log_probs.squeeze()
315-
) # Remove batch dimension for single response
316-
317-
318247
class DatasetActor(ForgeActor):
319248
"""Actor wrapper for HuggingFace dataset to provide async interface."""
320249

@@ -345,7 +274,8 @@ async def __next__(self) -> dict[str, str] | None:
345274
async def main():
346275
"""Main GRPO training loop with rollout and training processes."""
347276
group_size = 1
348-
model = "Qwen/Qwen3-1.7B"
277+
model = "Qwen/Qwen3-0.6B"
278+
titan_model = TitanJobModelConfig(name="qwen3", flavor="0.6B")
349279

350280
# ---- Setup WandB Logger ---- #
351281
logger = get_metric_logger(
@@ -403,8 +333,8 @@ async def main():
403333
),
404334
spawn_service(
405335
ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
406-
RefModel,
407-
model_name=model,
336+
TitanRefModel,
337+
model=titan_model,
408338
),
409339
spawn_service(
410340
ServiceConfig(procs_per_replica=1, num_replicas=1),
@@ -431,9 +361,14 @@ async def continuous_rollouts():
431361
target=target,
432362
policy_version=version,
433363
)
434-
actions = await policy.generate.choose(prompt)
364+
responses = await policy.generate.choose(prompt)
365+
actions = responses.outputs
435366
for action in actions:
436-
ref_logprobs = await ref_model.forward.choose(action.token_ids)
367+
request_tokens = responses.prompt_token_ids
368+
response_tokens = action.token_ids
369+
ref_logprobs = await ref_model.forward.choose(
370+
request=request_tokens, response=response_tokens
371+
)
437372
reward = await reward_actor.evaluate_response.choose(
438373
prompt=prompt, response=action.text, target=target
439374
)

apps/rl/main.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from forge.controller import spawn_actors
2121
from omegaconf import DictConfig
2222

23-
2423
logger = logging.getLogger(__name__)
2524
logger.setLevel(logging.INFO)
2625

src/forge/actors/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"]
7+
__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"]
88

99

1010
def __getattr__(name):
@@ -24,5 +24,9 @@ def __getattr__(name):
2424
from .replay_buffer import ReplayBuffer
2525

2626
return ReplayBuffer
27+
elif name == "TitanRefModel":
28+
from .reference_actor import TitanRefModel
29+
30+
return TitanRefModel
2731
else:
2832
raise AttributeError(f"module {__name__} has no attribute {name}")

src/forge/actors/policy.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,12 @@
1313
from typing import Dict, List
1414

1515
import torch
16+
17+
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
18+
19+
from forge.data.sharding import VLLMSharding
20+
from forge.interfaces import Policy as PolicyInterface
21+
from forge.types import ProcessConfig
1622
from monarch.actor import current_rank, endpoint, ProcMesh
1723
from torchstore import MultiProcessStore
1824
from torchstore._state_dict_utils import DELIM
@@ -37,12 +43,6 @@
3743
from vllm.v1.structured_output import StructuredOutputManager
3844
from vllm.worker.worker_base import WorkerWrapperBase
3945

40-
from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
41-
42-
from forge.data.sharding import VLLMSharding
43-
from forge.interfaces import Policy as PolicyInterface
44-
from forge.types import ProcessConfig
45-
4646

4747
logger = logging.getLogger(__name__)
4848

@@ -317,7 +317,7 @@ async def run(self):
317317
for request_output in processed_outputs.request_outputs:
318318
if request_output.finished:
319319
_, fut = self.requests.pop(request_output.request_id)
320-
fut.set_result(request_output.outputs)
320+
fut.set_result(request_output)
321321

322322
@endpoint
323323
async def update_weights(self) -> int:

0 commit comments

Comments
 (0)