diff --git a/apps/grpo/main.py b/apps/grpo/main.py
index 15642272c..519798e8b 100644
--- a/apps/grpo/main.py
+++ b/apps/grpo/main.py
@@ -24,6 +24,7 @@
 from forge.controller.provisioner import shutdown
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
+from forge.util.ops import selective_log_softmax
 from monarch.actor import endpoint
 from omegaconf import DictConfig
 from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -43,7 +44,7 @@ class Episode:
     response: str | None = None
     request_tokens: list[int] | None = None
     response_tokens: list[int] | None = None
-    ref_logprobs: torch.Tensor | None = None
+    ref_logits: torch.Tensor | None = None
     reward: float | None = None
     advantage: float | None = None

@@ -107,8 +108,8 @@ def collate(batches: list[list[Episode]]):
         response = [e.response_tensor for e in batch]
         response = torch.stack(response)  # [b x s]

-        ref_logprobs = [e.ref_logprobs for e in batch]
-        ref_logprobs = torch.stack(ref_logprobs).squeeze()  # [b x s]
+        ref_logits = [e.ref_logits for e in batch]
+        ref_logits = torch.stack(ref_logits).squeeze()  # [b x s x v]

         advantages = [e.advantage for e in batch]
         advantages = torch.tensor(advantages).unsqueeze(-1)  # [b x 1]
@@ -119,7 +120,7 @@ def collate(batches: list[list[Episode]]):
         input = {"tokens": torch.cat([request, response], dim=1)}
         target = {
             "response": response,
-            "ref_logprobs": ref_logprobs,
+            "ref_logits": ref_logits,
             "advantages": advantages,
             "padding_mask": mask,
         }
@@ -129,30 +130,35 @@ def collate(batches: list[list[Episode]]):


 def compute_logprobs(
-    logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
+    logits: torch.Tensor, target_ids: torch.Tensor, temperature: float = 1.0
 ) -> torch.Tensor:
-    context_length = logits.shape[1] - input_ids.shape[1]
-    logits = logits[:, context_length - 1 : -1]
-    logprobs = torch.log_softmax(logits / temperature, dim=-1).to(input_ids.device)
-    logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
-    return logprobs
+    logits = logits[:, -target_ids.size(1) - 1 : -1, :].float()
+    scaled_logits = logits / temperature
+    logprobs = selective_log_softmax(scaled_logits, target_ids)
+    return logprobs.to(target_ids.device)


 def simple_grpo_loss(
     logits: torch.Tensor,
     response: torch.Tensor,
-    ref_logprobs: torch.Tensor,
+    ref_logits: torch.Tensor,
     advantages: torch.Tensor,
     padding_mask: torch.Tensor,
     beta: float = 0.1,
 ) -> torch.Tensor:
+    print(f"num of padding: {padding_mask.sum(dim=1)}")
+    # assert ref_logits.dtype == torch.long
+    # assert logits.dtype == torch.long
     logprobs = compute_logprobs(logits, response)
+    ref_logprobs = compute_logprobs(ref_logits, response)
     kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
+    print(f"kl (no padding): {(kl * padding_mask).mean(dim=1)}")
+    # Mask out padded positions via the padding mask
     per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
     per_token_loss = -(per_token_policy_loss - beta * kl)
     loss = (
-        ((per_token_loss * padding_mask).sum(dim=1))
-        / (padding_mask.sum(dim=1).clamp(min=1.0))
+        (per_token_loss * padding_mask).sum(dim=1)
+        / padding_mask.sum(dim=1).clamp(min=1.0)
     ).mean()
     return loss
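Note on the compute_logprobs rewrite above (not part of the patch): it assumes forge.util.ops.selective_log_softmax gathers the log-probability of each target id along the vocabulary axis, as TRL's helper of the same name does. Because the logits at position t score the token at position t + 1, a slice feeding R response tokens has to start one step earlier, at -(R + 1). The sketch below spells out both assumptions with toy sizes.

```python
import torch


def selective_log_softmax_sketch(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # [b, s, vocab] logits + [b, s] token ids -> [b, s] per-token logprobs.
    logprobs = torch.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)


# Toy shapes (illustrative only): 4-token prompt, 3-token response, 11-word vocab.
b, req, res, vocab = 2, 4, 3, 11
full_logits = torch.randn(b, req + res, vocab)
response = torch.randint(vocab, (b, res))
sliced = full_logits[:, -(res + 1) : -1, :]  # R logit rows, shifted by one
assert selective_log_softmax_sketch(sliced, response).shape == (b, res)
```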
@@ -299,28 +305,31 @@ async def continuous_rollouts():
             target=target,
         )

-        input_ids = torch.ones(
-            (group_size, max_req_tokens + max_req_tokens),
-            dtype=torch.long,
-            device="cuda",
-        )
-        # Populate episode info and calculate rewards
-        for i, (episode, response) in enumerate(zip(group.episodes, responses)):
+        # Populate episode info, compute ref logits, and calculate rewards
+        for episode, response in zip(group.episodes, responses):
             episode.request_tokens = response.prompt_ids
             episode.response_tokens = response.token_ids
             episode.response = response.text
-            input_ids[i, :max_req_tokens] = episode.request_tensor
-            input_ids[i, max_req_tokens:] = episode.response_tensor
+            episode.ref_logits = await ref_model.forward.choose(
+                torch.cat(
+                    [episode.request_tensor, episode.response_tensor]
+                ).unsqueeze(0)
+            )
             episode.reward = await reward_actor.evaluate_response.choose(
                 prompt=prompt, response=response.text, target=target
             )

-        # Calculate reference logprobs
-        ref_logits = await ref_model.forward.choose(input_ids)
-        ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
-        for i, episode in enumerate(group.episodes):
-            episode.ref_logprobs = ref_logprobs[i]
-        del ref_logits, ref_logprobs, input_ids
+        # # Calculate reference logprobs
+        # print(f" input ids dtype: {input_ids.dtype}")
+        # ref_logits = await ref_model.forward.choose(input_ids)
+        # # ref_logits = ref_logits[:, :-1, :]  # Exclude the last token
+        # # ref_logits = ref_logits[:, -max_res_tokens:, :]
+        # print(f" ref logits dtype: {ref_logits.dtype}")
+        # print("Computed ref logits")
+        # # ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
+        # for i, episode in enumerate(group.episodes):
+        #     episode.ref_logits = ref_logits[i]
+        # del ref_logits, input_ids

         # Calculate advantages and add to replay buffer
         advantages = await compute_advantages.compute.choose(group)
@@ -342,15 +351,22 @@

 async def continuous_training():
     training_step = 0
+    _tokenizer = get_tokenizer("Qwen/Qwen3-1.7B")
     while True:
         batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
         if batch is None:
             await asyncio.sleep(0.1)
         else:
             inputs, targets = batch
-            loss = await trainer.train_step.choose(inputs, targets)
+            tokens = inputs[0]["tokens"]
+            print(f"Training input: {_tokenizer.batch_decode(tokens)}")
+            print(f"Num of padding tokens: {targets[0]['padding_mask'].sum(dim=1)}")
+            metrics = await trainer.train_step.choose(inputs, targets)
             training_step += 1
-            mlogger.log("loss/training_step", loss, training_step)
+            mlogger.log("loss/training_step", metrics["loss"], training_step)
+            mlogger.log(
+                "grad_norm/training_step", metrics["grad_norm"], training_step
+            )

             await trainer.push_weights.call(training_step)
             await policy.update_weights.call(training_step)
diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index 5c0481528..5c7a5aff2 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -3,11 +3,11 @@

 # Global configuration
 group_size: 8
-batch_size: 16
+batch_size: 8
 max_req_tokens: 512
 max_res_tokens: 512
 model: "Qwen/Qwen3-1.7B"
-off_by_n: 1 # Off by one by default
+off_by_n: 0 # Strictly on-policy (no policy-version lag)

 # Dataset configuration
 dataset:
@@ -24,11 +24,14 @@ policy:
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
     enforce_eager: false
+    dtype: "float32"
+    gpu_memory_utilization: 0.9
   sampling_config:
     n: ${group_size}
     max_tokens: ${max_res_tokens}
     temperature: 1.0
     top_p: 1.0
+    seed: 42

 # Trainer configuration
 trainer:
@@ -47,7 +50,7 @@ trainer:
     seq_len: 2048
     max_norm: 1.0
     steps: 1000000
-    dtype: bfloat16
+    dtype: float32
     gc_freq: 1
   compile:
     enable: false
@@ -83,7 +86,7 @@ ref_model:
     flavor: 1.7B
     hf_assets_path: hf://${model}
   training:
-    dtype: bfloat16
+    dtype: float32
     gc_freq: 1
   compile:
     enable: false
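A note on the shapes implied by the rollout change above (not part of the patch): each episode.ref_logits returned by ref_model.forward.choose is a [1, req+res, vocab] tensor, so the stack-and-squeeze in collate now yields a [b, req+res, vocab] tensor rather than a [b, s] logprob matrix. A toy check with made-up sizes:

```python
import torch

b, req, res, vocab = 4, 5, 3, 7  # illustrative sizes only
ref_logits = [torch.randn(1, req + res, vocab) for _ in range(b)]  # one tensor per episode
stacked = torch.stack(ref_logits).squeeze()  # [b, 1, s, v] -> [b, s, v]
assert stacked.shape == (b, req + res, vocab)
```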
diff --git a/apps/rl/main.py b/apps/rl/main.py
index 9f7314f53..97dff5053 100644
--- a/apps/rl/main.py
+++ b/apps/rl/main.py
@@ -167,6 +167,7 @@ async def run(cfg: DictConfig):
     inputs, targets = await replay_buffer.sample.choose(curr_policy_version=0)
     outputs = await trainer.train_step.choose(inputs, targets)
     print("Loss: ", outputs["loss"])
+    print("Gradient Norm: ", outputs["grad_norm"])

     print("Shutting down...")
     await trainer.shutdown()
diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py
index de21b3e9e..aa3352e0c 100644
--- a/src/forge/actors/policy.py
+++ b/src/forge/actors/policy.py
@@ -19,6 +19,15 @@
 import torch
 import torch.distributed.checkpoint as dcp
 import torchstore as ts
+
+from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
+
+from forge.data.sharding import VLLMSharding
+from forge.data_models.completion import Completion
+from forge.data_models.prompt import to_prompt
+
+from forge.interfaces import Policy as PolicyInterface
+from forge.types import ProcessConfig
 from monarch.actor import current_rank, endpoint, ProcMesh
 from torchstore.state_dict_utils import DELIM
 from vllm.config import VllmConfig
@@ -43,15 +52,6 @@
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.worker.worker_base import WorkerWrapperBase

-from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
-
-from forge.data.sharding import VLLMSharding
-from forge.data_models.completion import Completion
-from forge.data_models.prompt import to_prompt
-
-from forge.interfaces import Policy as PolicyInterface
-from forge.types import ProcessConfig
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

@@ -77,6 +77,7 @@ class SamplingConfig:
     temperature: float = 1.0
     top_p: float = 1.0
     logprobs: int = 1
+    seed: int | None = None

     def __post_init__(self):
         super().__init__()
diff --git a/src/forge/actors/reference_model.py b/src/forge/actors/reference_model.py
index 3cee6cae5..89ed91686 100644
--- a/src/forge/actors/reference_model.py
+++ b/src/forge/actors/reference_model.py
@@ -13,6 +13,8 @@
 from dataclasses import dataclass, field, fields

 import torch
+
+from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
 from torch.distributed.tensor import DTensor

@@ -26,8 +28,6 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig

-from forge.controller import ForgeActor
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -86,13 +86,15 @@ def __post_init__(self):
     async def setup(self):
         engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
+        self.model = self.engine.model_parts[0]  # Currently not using PP
+        self.model.eval()

     @endpoint
     async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         self.engine.gc_handler.run(self.step)
-        model_parts = self.engine.model_parts
         parallel_dims = self.engine.parallel_dims
         input_ids = input_ids.to("cuda")
+        # print(f"Ref model input_ids: {input_ids}")
         # optional_context_parallel_ctx = (
         #     dist_utils.create_context_parallel_ctx(
         #         cp_mesh=parallel_dims.world_mesh["cp"],
@@ -112,7 +114,7 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
         with self.engine.train_context(optional_context_parallel_ctx):
             with self.engine.maybe_enable_amp:
                 with torch.inference_mode():
-                    logits = model_parts[0](input_ids)
+                    logits = self.model(input_ids)
         self.step += 1
         if isinstance(logits, DTensor):
             logits = logits.full_tensor()
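A note on the KL term that consumes these reference logits (not part of the patch): simple_grpo_loss uses the non-negative k3 estimator kl = exp(Δ) − Δ − 1 with Δ = ref_logprobs − logprobs, which is zero exactly when Δ = 0, i.e. when the two models agree on the sampled token. A quick numerical check with hand-picked probabilities:

```python
import torch

logprobs = torch.log(torch.tensor([0.5, 0.25, 0.25]))    # policy logprobs for 3 tokens
ref_logprobs = torch.log(torch.tensor([0.4, 0.4, 0.2]))  # reference logprobs

delta = ref_logprobs - logprobs
kl = torch.exp(delta) - delta - 1   # k3 estimator, per token
assert torch.all(kl >= 0)           # never negative

zero = torch.zeros_like(delta)      # ref == policy -> delta == 0
assert torch.allclose(torch.exp(zero) - zero - 1, zero)  # estimate is exactly zero
```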
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index f9a5cd962..a37b48e8c 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -15,7 +15,8 @@
 import torch
 import torch.distributed.checkpoint as dcp
 import torchstore as ts
-
+from forge.controller import ForgeActor
+from forge.data.utils import batch_to_device
 from monarch.actor import current_rank, current_size, endpoint
 from torch import Tensor
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
@@ -32,11 +33,10 @@
     Parallelism,
     Training,
 )
+from torchtitan.distributed import utils as dist_utils
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
-
-from forge.controller import ForgeActor
-from forge.data.utils import batch_to_device
+from transformers import AutoModelForCausalLM

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
@@ -166,6 +166,10 @@ def forward_backward(
         assert len(model_parts) == 1
         with self.engine.maybe_enable_amp:
             logits = model_parts[0](**inputs)
+            # hf_logits = self.hf_model(input_ids=inputs["tokens"]).logits.to(
+            #     "cpu"
+            # )
+            # assert torch.allclose(logits, hf_logits)
             loss = self.loss(logits, **targets)
         # need to free to before bwd to avoid peaking memory
         del logits
@@ -176,7 +180,7 @@
     @endpoint
     def train_step(
         self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
-    ) -> float:
+    ) -> dict[str, float]:
         self.engine.gc_handler.run(self.step)
         local_inputs = inputs[self.engine.dp_rank]
         local_targets = targets[self.engine.dp_rank]
@@ -193,6 +197,17 @@
         loss = self.forward_backward(local_inputs, local_targets)
         torch.distributed.all_reduce(loss)

+        grad_norm = dist_utils.clip_grad_norm_(
+            [p for m in self.engine.model_parts for p in m.parameters()],
+            self.training.max_norm,
+            foreach=True,
+            pp_mesh=None,
+            # (
+            #     self.engine.parallel_dims.world_mesh["pp"] if self.engine.parallel_dims.pp_enabled else None
+            # ),
+            ep_enabled=False,  # parallel_dims.ep_enabled,
+        )
+
         self.engine.optimizers.step()
         self.engine.optimizers.zero_grad()
         self.engine.lr_schedulers.step()
@@ -203,7 +218,7 @@
             last_step=self.step == self.num_training_steps,
         )

-        return loss.item()
+        return {"loss": loss.item(), "grad_norm": grad_norm}

     @endpoint
     async def push_weights(self, policy_version: int) -> None:
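A note on the new train_step return value (not part of the patch): the torchtitan clip_grad_norm_ call appears to return the total norm as a tensor (as torch.nn.utils.clip_grad_norm_ does), so the value sent back under "grad_norm" is a tensor even though the annotation promises dict[str, float]. If plain floats are preferred on the logging side, a coercion along these lines would do; the float() call is an assumption about the desired contract, not something the patch adds.

```python
import torch


def as_float_metrics(loss: torch.Tensor, grad_norm: torch.Tensor) -> dict[str, float]:
    # Coerce zero-dim tensors to plain Python floats before handing them to a logger.
    return {"loss": float(loss), "grad_norm": float(grad_norm)}


print(as_float_metrics(torch.tensor(0.42), torch.tensor(1.3)))
```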