diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 15234f928..89fb83a12 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -23,10 +23,10 @@ from forge.controller.actor import ForgeActor from forge.controller.service import ServiceConfig, shutdown_service, spawn_service from forge.data.rewards import MathReward, ThinkingReward +from forge.data.utils import exclude_service from forge.util.metric_logging import get_metric_logger from monarch.actor import endpoint from omegaconf import DictConfig -from src.forge.data.utils import exclude_service from torch import nn from torchstore.state_dict_utils import DELIM from transformers import AutoModelForCausalLM diff --git a/apps/rl/llama3_8b.yaml b/apps/rl/llama3_8b.yaml index 45b5eca28..a7d2e3d96 100644 --- a/apps/rl/llama3_8b.yaml +++ b/apps/rl/llama3_8b.yaml @@ -13,8 +13,7 @@ trainer: model: name: llama3 flavor: 8B - tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct - + hf_assets_path: /tmp/Meta-Llama-3.1-8B-Instruct optimizer: name: AdamW @@ -36,7 +35,7 @@ trainer: parallelism: data_parallel_replicate_degree: 1 - data_parallel_shard_degree: -1 + data_parallel_shard_degree: 4 tensor_parallel_degree: 1 pipeline_parallel_degree: 1 context_parallel_degree: 1 @@ -57,76 +56,7 @@ trainer: selective_ac_option: op replay_buffer: - batch_size: 2 + batch_size: 4 max_policy_age: 2 seed: None - -# policy: -# scheduler: -# scheduler: local # local | mast (not supported yet) -# num_hosts: 1 -# num_gpus: 1 -# oncall: torchtune -# identity: pytorch_distributed -# image: forge_workspace:latest -# -# model: "meta-llama/Llama-3.1-8B-Instruct" -# tensor_parallel_size: 2 -# pipeline_parallel_size: 1 -# enforce_eager: false - -# postprocessor: -# scheduler: -# scheduler: local # local | mast (not supported yet) -# num_hosts: 1 -# num_gpus: 1 -# oncall: torchtune -# identity: pytorch_distributed -# image: forge_workspace:latest -# -# comm: -# trace_buf_size: 0 -# -# optimizer: -# name: AdamW -# lr: 1e-5 -# eps: 1e-8 -# -# lr_scheduler: -# warmup_steps: 1 -# -# training: -# local_batch_size: 1 -# seq_len: 2048 -# max_norm: 1.0 -# steps: 5 -# compile: false -# dataset: "c4" -# -# parallelism: -# data_parallel_replicate_degree: 1 -# data_parallel_shard_degree: -1 -# tensor_parallel_degree: 1 -# pipeline_parallel_degree: 1 -# context_parallel_degree: 1 -# expert_parallel_degree: 1 -# disable_loss_parallel: false -# -# checkpoint: -# enable: true -# folder: /tmp/Meta-Llama-3.1-8B-Instruct/ -# interval: 500 -# async_mode: "disabled" -# -# activation_checkpoint: -# mode: selective -# selective_ac_option: op -# - -# profiling: -# enable_profiling: false - -# metrics: -# log_freq: 10 -# enable_tensorboard: true -# save_tb_folder: "tb" + dp_size: 4 diff --git a/apps/rl/main.py b/apps/rl/main.py index 02255b063..7d00eb09e 100644 --- a/apps/rl/main.py +++ b/apps/rl/main.py @@ -13,32 +13,169 @@ import asyncio import logging import sys +from dataclasses import dataclass +from typing import Any +import torch +import torch.nn.functional as F from forge.actors import ReplayBuffer, RLTrainer from forge.cli.config import parse from forge.controller.service import ServiceConfig, shutdown_service, spawn_service from omegaconf import DictConfig +from torch import Tensor logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +@dataclass +class Episode: + # TODO: add adtional layer for multi-turn + episode_id: str + request: str + policy_version: int + pad_id: int + request_len: int + response_len: int + target: Any | None = None + # processed data + response: str | 
None = None + request_tokens: list[int] | None = None + response_tokens: list[int] | None = None + ref_logprobs: Tensor | None = None + reward: float | None = None + advantage: float | None = None + + @property + def request_tensor(self): + tensor = torch.tensor(self.request_tokens, dtype=torch.long) + if tensor.shape[0] < self.request_len: # left pad + diff = self.request_len - tensor.shape[0] + tensor = F.pad(tensor, (diff, 0), value=self.pad_id) + return tensor + + @property + def response_tensor(self): + tensor = torch.tensor(self.response_tokens, dtype=torch.long) + if tensor.shape[0] < self.response_len: # right pad + diff = self.response_len - tensor.shape[0] + tensor = F.pad(tensor, (0, diff), value=self.pad_id) + return tensor + + +def collate(batches: list[list[Episode]]): + inputs = [] + targets = [] + for batch in batches: + request = [e.request_tensor for e in batch] + request = torch.stack(request) # [b x s] + + response = [e.response_tensor for e in batch] + response = torch.stack(response) # [b x s] + + ref_logprobs = [e.ref_logprobs for e in batch] + ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s] + + advantages = [e.advantage for e in batch] + advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1] + + pad_id = batch[0].pad_id + mask = response != pad_id + + input = {"tokens": torch.cat([request, response], dim=1)} + target = { + "response": response, + "ref_logprobs": ref_logprobs, + "advantages": advantages, + "padding_mask": mask, + } + inputs.append(input) + targets.append(target) + return inputs, targets + + +def compute_logprobs( + logits: Tensor, input_ids: torch.Tensor, temperature: float = 1.0 +) -> Tensor: + context_length = logits.shape[1] - input_ids.shape[1] + + # Truncate request logits and drop last + logits = logits[:, context_length - 1 : -1] + + # Compute logprobs + logprobs = torch.log_softmax(logits / temperature, dim=-1) + logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1) + + return logprobs + + +def simple_grpo_loss( + logits: Tensor, + response: Tensor, + ref_logprobs: Tensor, + advantages: Tensor, + padding_mask: Tensor, + beta: float = 0.1, +): + """Simplified GRPO Loss for simplified single step updates + Copied from https://github.com/pytorch/torchtune/blob/main/torchtune/dev/grpo/loss.py. 
+ """ + logprobs = compute_logprobs(logits, response) + per_token_kl = ( + torch.exp(ref_logprobs.detach() - logprobs) + - (ref_logprobs.detach() - logprobs) + - 1 + ) + per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages + per_token_loss = -(per_token_policy_loss - beta * per_token_kl) + loss = ( + (per_token_loss * padding_mask).sum(dim=1) / (padding_mask.sum(dim=1) + 1e-8) + ).mean() + return loss + + async def run(cfg: DictConfig): trainer, replay_buffer = await asyncio.gather( spawn_service( - ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=4), + ServiceConfig(procs_per_replica=4, with_gpus=True, num_replicas=1), RLTrainer, + loss=simple_grpo_loss, **cfg.trainer, ), spawn_service( ServiceConfig(procs_per_replica=1, num_replicas=1), ReplayBuffer, + collate=collate, **cfg.replay_buffer, ), ) - print("Services initialized....") + print("Services initialized...") + + print("Collecting Data...") + g = torch.manual_seed(0) + global_batch_size = cfg.replay_buffer.batch_size * cfg.replay_buffer.dp_size + for i in range(global_batch_size): + req_len, res_len = torch.randint(64, 256, (2,), generator=g).tolist() + e = Episode( + episode_id=i, + request="", + policy_version=0, + pad_id=0, + request_len=256, + response_len=256, + request_tokens=torch.randint(64_000, (req_len,), generator=g).tolist(), + response_tokens=torch.randint(64_000, (res_len,), generator=g).tolist(), + ref_logprobs=torch.randn((256,), generator=g), + advantage=torch.randn((1,), generator=g), + ) + await replay_buffer.add.choose(e) + + print("Train step...") + inputs, targets = await replay_buffer.sample.choose(curr_policy_version=0) + outputs = await trainer.train_step.choose(inputs, targets) + print("Loss: ", outputs["loss"]) - print("shutting down...") + print("Shutting down...") await shutdown_service(trainer) await shutdown_service(replay_buffer) diff --git a/src/forge/actors/replay_buffer.py b/src/forge/actors/replay_buffer.py index f53bfd4ec..985fc4052 100644 --- a/src/forge/actors/replay_buffer.py +++ b/src/forge/actors/replay_buffer.py @@ -6,7 +6,7 @@ import random from dataclasses import dataclass -from typing import Any +from typing import Any, Callable from monarch.actor import endpoint @@ -21,6 +21,7 @@ class ReplayBuffer(ForgeActor): max_policy_age: int dp_size: int = 1 seed: int | None = None + collate: Callable = lambda batch: batch @endpoint async def setup(self) -> None: @@ -31,7 +32,7 @@ async def setup(self) -> None: self.sampler = random.sample @endpoint - async def add(self, episode) -> None: + async def add(self, episode: "Episode") -> None: self.buffer.append(episode) @endpoint @@ -55,7 +56,7 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None): if total_samples > len(self.buffer): return None - # TODO: Make this more efficient + # TODO: prefetch samples in advance idx_to_sample = self.sampler(range(len(self.buffer)), k=total_samples) # Pop episodes in descending order to avoid shifting issues popped = [self.buffer.pop(i) for i in sorted(idx_to_sample, reverse=True)] @@ -71,7 +72,7 @@ async def sample(self, curr_policy_version: int, batch_size: int | None = None): for dp_idx in range(self.dp_size) ] - return reshaped_episodes + return self.collate(reshaped_episodes) @endpoint async def evict(self, curr_policy_version: int) -> None: diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 062fabe8a..767a08cd2 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -10,9 +10,10 @@ import os 
from collections.abc import Mapping from dataclasses import dataclass, field, fields +from typing import Callable -import torch from monarch.actor import current_rank, current_size, endpoint +from torch import Tensor from torchtitan.config.job_config import ( ActivationCheckpoint, Checkpoint, @@ -31,6 +32,7 @@ from torchtitan.experiments.forge.job_config import ForgeJobConfig from forge.controller import ForgeActor +from forge.data.utils import batch_to_device logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -50,6 +52,7 @@ class RLTrainer(ForgeActor): compile: Compile = field(default_factory=Compile) float8: Float8 = field(default_factory=Float8) comm: Comm = field(default_factory=Comm) + loss: Callable = lambda logits, **targets: logits def __post_init__(self): """Initializes config types and env variables. @@ -92,20 +95,17 @@ def __post_init__(self): async def setup(self): # TODO: update ForgeEngine to not use ForgeJobConfig engine_config = {f.name: getattr(self, f.name) for f in fields(self)} + engine_config.pop("loss") # Not part of job config self.engine = ForgeEngine(ForgeJobConfig(**engine_config)) self.engine.checkpointer.load(step=self.current_step) self.engine.optimizers.zero_grad() - def forward_backward( - self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor - ) -> torch.Tensor: + def forward_backward(self, inputs: dict[Tensor], targets: dict[Tensor]) -> Tensor: model_parts = self.engine.model_parts parallel_dims = self.engine.parallel_dims # apply context parallelism if cp is enabled # ensure CP handles the separate freqs_cis buffer for each pp stage - inputs = input_dict["tokens"] - if getattr(self.engine.model_args, "use_flex_attn", False): cp_mesh = ( parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None @@ -155,30 +155,34 @@ def forward_backward( with self.engine.train_context(optional_context_parallel_ctx): assert len(model_parts) == 1 with self.engine.maybe_enable_amp: - pred = model_parts[0](inputs) - loss = self.engine.loss_fn(pred, labels) + logits = model_parts[0](**inputs) + loss = self.loss(logits, **targets) # need to free to before bwd to avoid peaking memory - del pred + del logits loss.backward() return loss @endpoint - def train_step(self, batch) -> None: - # Move tensors to the appropriate device - for k, v in batch.items(): - if isinstance(v, torch.Tensor): - batch[k] = v.to("cuda") # TODO: hardcoded for now - + def train_step( + self, inputs: list[dict[Tensor]], targets: list[dict[Tensor]] + ) -> None: + inputs = inputs[self.engine.dp_rank] + targets = targets[self.engine.dp_rank] + batch_to_device(inputs, self.engine.device) + batch_to_device(targets, self.engine.device) + + # compute policy logprobs # TODO implement gradient accumulation # with GradientAccumulation( # self.gradient_accumulation_steps, # self.model, # self.data_parallel_size, # ) as grad_acc: - # TODO: convert to GRPO Loss - labels = batch.pop("labels") - loss = self.forward_backward(batch, labels) + loss = self.forward_backward(inputs, targets) + + # # Gradient clipping (optional but recommended for stability) + # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) self.engine.optimizers.step() self.engine.optimizers.zero_grad() @@ -190,75 +194,7 @@ def train_step(self, batch) -> None: last_step=self.current_step == self.num_training_steps, ) - # TODO: integrate the grpo app step with the above step - # def train_step(self, self, batch: list(Episode)): - # total_loss = 0.0 - # num_groups_processed = 0 - # - # for episode in 
batch: - # groups = episode.groups - # - # # Collect all response texts and corresponding data - # response_texts = [] - # ref_logprobs_list = [] - # advantages_list = [] - # - # for group in groups: - # response_texts.append(group.response) - # ref_logprobs_list.append(group.ref_logprobs) - # advantages_list.append(group.advantage) - # - # # Tokenize all responses in batch - # tokenized = self.tokenizer( - # response_texts, - # padding=True, - # truncation=True, - # return_tensors="pt", - # max_length=512, # Adjust based on your needs - # ) - # - # input_ids = tokenized["input_ids"].to(self.device) - # attention_mask = tokenized["attention_mask"].to(self.device) - # - # # Compute current policy log probabilities using the model - # current_logprobs = compute_sequence_logprobs( - # self.model, input_ids, attention_mask, requires_grad=True - # ) - # - # # Convert ref_logprobs and advantages to tensors - # ref_logprobs_tensor = torch.stack(ref_logprobs_list).to(self.device) - # advantages_tensor = torch.tensor(advantages_list, dtype=torch.float32).to( - # self.device - # ) - # - # # Compute GRPO loss components - # # Ratio between current policy and reference policy - # ratio = torch.exp(current_logprobs - ref_logprobs_tensor) - # - # # Policy gradient loss weighted by advantages - # pg_loss = -torch.mean(ratio * advantages_tensor) - # - # # KL penalty to prevent policy from deviating too far from reference - # kl_penalty = self.beta * torch.mean( - # (current_logprobs - ref_logprobs_tensor) ** 2 - # ) - # - # # Total GRPO loss - # loss = pg_loss + kl_penalty - # total_loss += loss.item() - # num_groups_processed += len(groups) - # - # self.optimizer.zero_grad() - # loss.backward() - # - # # Gradient clipping (optional but recommended for stability) - # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) - # - # self.optimizer.step() - # - # avg_loss = total_loss / len(batch) if batch else 0.0 - # - # return {"loss": avg_loss, "groups_processed": num_groups_processed} + return {"loss": loss.item()} @endpoint def push_weights(self) -> None: @@ -270,9 +206,7 @@ async def cleanup(self) -> None: self.engine.checkpointer.close() -def _qwen3_hf_to_vllm( - sd: dict[str, torch.Tensor], num_layers: int -) -> dict[str, torch.Tensor]: +def _qwen3_hf_to_vllm(sd: dict[str, Tensor], num_layers: int) -> dict[str, Tensor]: """Convert transformers state dict to vLLM format. Specifically, this fuses QKV projection and MLP gate_up_proj layers.
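
Note: the snippet below is not part of the diff. It is a minimal, self-contained sketch of the tensor shapes the new simple_grpo_loss in apps/rl/main.py expects: logits cover the concatenated request+response positions (as produced in RLTrainer.forward_backward), while response, ref_logprobs, and padding_mask are response-length only and advantages is [batch, 1]. The two functions are copied (lightly condensed) from the diff above; the batch, sequence, and vocab sizes are made up purely for illustration.

import torch
from torch import Tensor


def compute_logprobs(
    logits: Tensor, input_ids: torch.Tensor, temperature: float = 1.0
) -> Tensor:
    # Drop the request positions and the final logit so each remaining logit
    # predicts the corresponding response token.
    context_length = logits.shape[1] - input_ids.shape[1]
    logits = logits[:, context_length - 1 : -1]
    logprobs = torch.log_softmax(logits / temperature, dim=-1)
    return torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)


def simple_grpo_loss(
    logits: Tensor,
    response: Tensor,
    ref_logprobs: Tensor,
    advantages: Tensor,
    padding_mask: Tensor,
    beta: float = 0.1,
) -> Tensor:
    logprobs = compute_logprobs(logits, response)
    per_token_kl = (
        torch.exp(ref_logprobs.detach() - logprobs)
        - (ref_logprobs.detach() - logprobs)
        - 1
    )
    per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
    per_token_loss = -(per_token_policy_loss - beta * per_token_kl)
    return (
        (per_token_loss * padding_mask).sum(dim=1) / (padding_mask.sum(dim=1) + 1e-8)
    ).mean()


if __name__ == "__main__":
    b, req_len, res_len, vocab = 2, 8, 6, 32  # illustrative sizes only
    # Model output spans request + response, matching the "tokens" input built by collate().
    logits = torch.randn(b, req_len + res_len, vocab, requires_grad=True)
    response = torch.randint(vocab, (b, res_len))
    ref_logprobs = torch.randn(b, res_len)
    advantages = torch.randn(b, 1)
    padding_mask = torch.ones(b, res_len, dtype=torch.bool)

    loss = simple_grpo_loss(logits, response, ref_logprobs, advantages, padding_mask)
    loss.backward()  # gradients flow back into the logits
    print("loss:", loss.item())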