apps/toy_rl/sumdigits.py (192 changes: 128 additions & 64 deletions)
@@ -10,7 +10,6 @@
import random
import time
import uuid
from collections.abc import Iterator
from dataclasses import dataclass
from typing import Any

@@ -23,18 +22,27 @@
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown

from forge.losses.reinforce_loss import ReinforceLoss
from forge.losses.grpo_loss import SimpleGRPOLoss
from forge.util.metric_logging import get_metric_logger

from forge.util.ops import selective_log_softmax
from monarch.actor import endpoint
from omegaconf import DictConfig

from torch.utils.data import IterableDataset
from torchstore.state_dict_utils import DELIM
from transformers import AutoModelForCausalLM
from vllm.transformers_utils.tokenizer import get_tokenizer


def pad_sequence(
tensor: torch.Tensor, target_len: int, pad_value: float = 0.0
) -> torch.Tensor:
diff = target_len - tensor.size(0)
if diff > 0:
return F.pad(tensor, (0, diff), value=pad_value)
return tensor
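
# Illustrative only (not part of the diff): the helper right-pads a 1-D tensor
# to `target_len` and returns it unchanged if it is already long enough, e.g.
# pad_sequence(torch.tensor([5, 7, 9]), 5, 0) -> tensor([5, 7, 9, 0, 0]).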


# TODO: Episode and Group are duplicated and need cleanup.
@dataclass
class Episode:
@@ -199,6 +207,45 @@ def new_group(
return cls(str(group_id), episodes)


class RefModel(ForgeActor):
def __init__(self, model_name, device: torch.device | None = None):
super().__init__()
self.model_name = model_name

if device is None:
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else:
self.device = device

self.model = AutoModelForCausalLM.from_pretrained(
model_name,
dtype=torch.bfloat16,
trust_remote_code=True,
).to(self.device)
self.model.eval()

self.logger.info(f"Model initialized on {self.device}")

@endpoint
async def forward(self, episode: Episode) -> torch.Tensor:
input_ids = (
pad_sequence(episode.input_ids, episode.max_seq_len - 1, episode.pad_id)
.to(self.device)
.unsqueeze(0)
)
target_ids = (
pad_sequence(episode.target_ids, episode.max_seq_len - 1, episode.pad_id)
.to(self.device)
.unsqueeze(0)
)
mask = input_ids != episode.pad_id

with torch.inference_mode():
logits = self.model(input_ids=input_ids, attention_mask=mask).logits

return selective_log_softmax(logits, target_ids).squeeze(0)
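

# The implementation of forge.util.ops.selective_log_softmax is not shown in
# this diff. The sketch below is an assumption about its behavior (the usual
# "gather per-token log-probs from a log-softmax" helper), included only to make
# the RefModel output clear; it relies on the module's existing torch / F imports.
def _selective_log_softmax_sketch(
    logits: torch.Tensor,  # [batch, seq_len, vocab_size]
    target_ids: torch.Tensor,  # [batch, seq_len]
) -> torch.Tensor:
    # Log-probabilities over the vocabulary, then pick the entry for each
    # target token, giving per-token log-probs of shape [batch, seq_len].
    logprobs = F.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)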


@dataclass
class Trainer(ForgeActor):
"""Reinforce Loss Trainer implementation for policy optimization."""
@@ -224,7 +271,9 @@ async def setup(self):
self.model.parameters(), lr=self.learning_rate
)
self.optimizer.zero_grad()
self.loss = ReinforceLoss()

# beta = 0.01 for quicker convergence
self.loss = SimpleGRPOLoss(0.01)
self.logger.info(f"Trainer model initialized on {self.device}")

@endpoint
@@ -238,14 +287,16 @@ def train_step(self, episodes: list[Episode]) -> float:
batch_loss_masks = []
batch_weights = []
batch_sampling_log_probs = []
batch_ref_logprobs = []
for episode in episodes:
input_ids = self.pad_sequence(episode.input_ids, max_seq_len, pad_id)
target_ids = self.pad_sequence(episode.target_ids, max_seq_len, pad_id)
loss_mask = self.pad_sequence(episode.loss_mask, max_seq_len, 0.0)
sampling_log_probs = self.pad_sequence(
input_ids = pad_sequence(episode.input_ids, max_seq_len, pad_id)
target_ids = pad_sequence(episode.target_ids, max_seq_len, pad_id)
loss_mask = pad_sequence(episode.loss_mask, max_seq_len, 0.0)
sampling_log_probs = pad_sequence(
episode.sampling_log_probs, max_seq_len, 0.0
)
weights = self.pad_sequence(episode.weighted_advantages, max_seq_len, 0.0)
weights = pad_sequence(episode.weighted_advantages, max_seq_len, 0.0)
ref_logprobs = episode.ref_logprobs

# Exclude padded response tokens from loss
valid_mask = target_ids != pad_id
@@ -258,22 +309,26 @@
batch_loss_masks.append(loss_mask)
batch_weights.append(weights)
batch_sampling_log_probs.append(sampling_log_probs)
batch_ref_logprobs.append(ref_logprobs)

# Stack into batched tensors
input_ids = torch.stack(batch_input_ids).to(self.device)
target_ids = torch.stack(batch_target_ids).to(self.device)
loss_masks = torch.stack(batch_loss_masks).to(self.device)
weights = torch.stack(batch_weights).to(self.device)
sampling_log_probs = torch.stack(batch_sampling_log_probs).to(self.device)
ref_logprobs = torch.stack(batch_ref_logprobs).to(self.device)

# Create attention mask
attention_mask = input_ids != pad_id

# Forward pass
logits = self.model(input_ids=input_ids, attention_mask=attention_mask).logits

trainer_log_probs = selective_log_softmax(logits, target_ids)
# Compute loss only on response tokens
loss = self.loss(logits, target_ids, loss_masks, weights, sampling_log_probs)
# loss = self.loss(logits, target_ids, loss_masks, weights, sampling_log_probs)
loss = self.loss(trainer_log_probs, ref_logprobs, weights, loss_masks)
loss.backward()

torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
@@ -296,68 +351,66 @@ async def push_weights(self, version: int):
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
)

def pad_sequence(
self, tensor: torch.Tensor, target_len: int, pad_value: float = 0.0
) -> torch.Tensor:
diff = target_len - tensor.size(0)
if diff > 0:
return F.pad(tensor, (0, diff), value=pad_value)
return tensor
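

# SimpleGRPOLoss itself lives in forge.losses.grpo_loss and its code is not part
# of this diff. The sketch below is an assumption about roughly what such a loss
# computes, matching the self.loss(trainer_log_probs, ref_logprobs, weights,
# loss_masks) call in train_step above; it is not the library's implementation.
def _simple_grpo_loss_sketch(
    logprobs: torch.Tensor,  # [batch, seq_len] policy per-token log-probs
    ref_logprobs: torch.Tensor,  # [batch, seq_len] reference per-token log-probs
    advantages: torch.Tensor,  # [batch, seq_len] advantage weights
    mask: torch.Tensor,  # [batch, seq_len] 1.0 on response tokens, 0.0 elsewhere
    beta: float = 0.01,  # KL penalty coefficient (cf. the beta comment in setup)
) -> torch.Tensor:
    # Per-token k3 estimate of KL(policy || reference); always non-negative.
    kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
    # REINFORCE-style surrogate: equals the advantage in value, but routes the
    # gradient through the policy log-probs.
    policy_term = torch.exp(logprobs - logprobs.detach()) * advantages
    per_token_loss = -(policy_term - beta * kl)
    # Average over response tokens of each sequence, then over the batch.
    denom = mask.sum(dim=1).clamp(min=1.0)
    return ((per_token_loss * mask).sum(dim=1) / denom).mean()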


@dataclass
class RewardActor(ForgeActor):
"""Reward actor that uses a list of scoring functions."""

@endpoint
async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
if response == target:
return 1.0
return 0.0
reward = 1.0 if response.strip() == "10" else 0.0
return reward


@dataclass
class SumDigitsDataset(IterableDataset):
class SumDigitsDataset:
def __init__(self, tokenizer, max_samples=1000):
self.min_digit_length = 2
self.max_digit_length = 3
self.max_numbers = max_samples
self.data = self.generate_random_number()
self._tokenizer = tokenizer

def __iter__(self) -> Iterator[Any]:
for data in self.data:
answer = str(sum(int(x) for x in data))
system_prompt = """
A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant only gives very concise answers.
"""
request: str = f"What is the sum of the digits of {data}"
as_chat = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": request},
]
formatted_request = self._tokenizer.apply_chat_template(
as_chat,
tokenize=False,
add_generation_prompt=True,
)
yield {
"question": formatted_request,
"request": formatted_request,
"answer": answer,
"target": answer,
}

def generate_random_number(self) -> Iterator[str]:
while True:
yield self.generate_one()
def generate_sample(self, step: int) -> dict[str, str]:
"""Generate a single sample based on training step for progressive difficulty."""
data = self.generate_one(step)
answer = str(sum(int(x) for x in data))

def generate_one(self) -> str:
return "".join(
str(random.randint(0, 9))
for _ in range(random.randint(self.min_digit_length, self.max_digit_length))
system_prompt = """
A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistant only gives very concise answers (just the number, no explanation).
"""
request: str = f"What is the sum of the digits of {data}"
as_chat = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": request},
]
formatted_request = self._tokenizer.apply_chat_template(
as_chat,
tokenize=False,
add_generation_prompt=True,
)
return {
"question": formatted_request,
"request": formatted_request,
"answer": answer,
"target": answer,
}

def generate_one(self, step: int) -> str:
"""Generate number based on training step for curriculum learning."""
if step < 200:
# Early training: 2-digit numbers (10-99)
min_val, max_val = 10, 99
elif step < 1000:
# Mid training: up to 4-digit numbers (0-1000)
min_val, max_val = 0, 1000
elif step < 3000:
# Later training: up to 6-digit numbers (0-100000)
min_val, max_val = 0, 100000
else:
# Late training: up to 8-digit numbers (0-10000000)
min_val, max_val = 0, 10000000

number = random.randint(min_val, max_val)
return str(number)
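
# Illustrative only (values are random draws; the ranges come from generate_one
# above):
#   step 0    -> e.g. "37"       (2-digit)
#   step 500  -> e.g. "842"      (up to 4 digits)
#   step 1500 -> e.g. "90217"    (up to 6 digits)
#   step 5000 -> e.g. "4830127"  (up to 8 digits)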


@dataclass
@@ -369,13 +422,15 @@ class DatasetActor(ForgeActor):
@endpoint
def setup(self):
self._tokenizer = get_tokenizer(self.model)
self._iterator = iter(SumDigitsDataset(self._tokenizer))
self._dataset = SumDigitsDataset(self._tokenizer)

@endpoint
async def sample(self) -> dict[str, str] | None:
async def sample(self, step: int = 0) -> dict[str, str] | None:
"""Sample with progressive difficulty based on training step."""
try:
return next(self._iterator)
except StopIteration:
return self._dataset.generate_sample(step)
except Exception as e:
self.logger.error(f"Error generating sample: {e}")
return None

@endpoint
@@ -397,14 +452,22 @@ async def main(cfg: DictConfig):

# ---- Setup services ---- #
await ts.initialize()
(dataloader, policy, trainer, replay_buffer, reward_actor,) = await asyncio.gather(
(
dataloader,
policy,
trainer,
replay_buffer,
reward_actor,
ref_model,
) = await asyncio.gather(
DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset),
Policy.options(**cfg.services.policy).as_service(**cfg.policy),
Trainer.options(**cfg.services.trainer).as_service(**cfg.trainer),
ReplayBuffer.options(**cfg.services.replay_buffer).as_service(
**cfg.replay_buffer
),
RewardActor.options(**cfg.services.reward_actor).as_service(),
RefModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
)

print("All services initialized successfully!")
@@ -414,7 +477,8 @@ async def continuous_rollouts():
rollout_count = 0
pad_id = await dataloader.pad_token.choose()
while True:
sample = await dataloader.sample.choose()
# Pass rollout_count for curriculum learning
sample = await dataloader.sample.choose(rollout_count)
if sample is None:
print("Dataloader is empty, exiting continuous rollout")
return
@@ -446,13 +510,13 @@ async def continuous_rollouts():
)
]
)
episode.ref_logprobs = await ref_model.forward.choose(episode)
episode.reward = await reward_actor.evaluate_response.choose(
prompt=prompt, response=response.text, target=target
)
episode.advantage = episode.reward # simple case for now
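# Note (not part of the diff): full GRPO would typically use a group-relative
# advantage here (reward minus the group's mean reward, often divided by the
# group's std); this toy app simply reuses the raw reward as the advantage.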
for episode in group.episodes:
await replay_buffer.add.choose(episode)

avg_response_len = (
sum(len(e.response_tokens) for e in group.episodes) / group_size
)
apps/toy_rl/sumdigits.yaml (12 changes: 8 additions & 4 deletions)
@@ -1,10 +1,10 @@
# Toy app Training Configuration

# Global configuration
group_size: 8
batch_size: 16
max_req_tokens: 512
max_res_tokens: 512
group_size: 6
batch_size: 12
max_req_tokens: 64
max_res_tokens: 64
model: "Qwen/Qwen2.5-0.5B-Instruct"

# Dataset configuration
@@ -60,3 +60,7 @@ services:
procs_per_replica: 1
num_replicas: 1
with_gpus: false
ref_model:
procs_per_replica: 1
num_replicas: 1
with_gpus: true
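
# Not shown in this hunk: main() also reads a top-level `ref_model:` section
# (RefModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model)).
# A plausible shape, assuming RefModel's model_name kwarg is filled from config
# (the key name here is a guess, not taken from this diff):
#
# ref_model:
#   model_name: ${model}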