diff --git a/.gitignore b/.gitignore
index 14e5f66e1..9d1759aef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,6 +41,7 @@ share/python-wheels/
 .installed.cfg
 *.egg
 MANIFEST
+.rsyncignore
 
 # Django stuff
 *.log
diff --git a/apps/grpo/main_no_reward.py b/apps/grpo/main_no_reward.py
new file mode 100644
index 000000000..93b69e03a
--- /dev/null
+++ b/apps/grpo/main_no_reward.py
@@ -0,0 +1,389 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Usage: python -m apps.grpo.main_no_reward --config apps/grpo/qwen3_1_7b.yaml
+
+import asyncio
+import uuid
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import torch
+import torch.nn.functional as F
+import torchstore as ts
+from datasets import load_dataset
+from forge.actors.policy import Policy
+from forge.actors.replay_buffer import ReplayBuffer
+from forge.actors.trainer import RLTrainer
+from forge.cli.config import parse
+from forge.controller.actor import ForgeActor
+from forge.controller.provisioner import shutdown
+from forge.data.rewards import MathReward, ThinkingReward
+from forge.util.metric_logging import get_metric_logger
+from monarch.actor import endpoint
+from omegaconf import DictConfig
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+@dataclass
+class Episode:
+    # TODO: add additional layer for multi-turn
+    episode_id: str
+    request: str
+    policy_version: int
+    pad_id: int
+    request_len: int
+    response_len: int
+    target: Any | None = None
+    # processed data
+    response: str | None = None
+    request_tokens: list[int] | None = None
+    response_tokens: list[int] | None = None
+    ref_logprobs: torch.Tensor | None = None
+    reward: float | None = None
+    advantage: float | None = None
+
+    @property
+    def request_tensor(self):
+        tensor = torch.tensor(self.request_tokens, dtype=torch.long)
+        if tensor.shape[0] < self.request_len:  # left pad
+            diff = self.request_len - tensor.shape[0]
+            tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
+        return tensor
+
+    @property
+    def response_tensor(self):
+        tensor = torch.tensor(self.response_tokens, dtype=torch.long)
+        if tensor.shape[0] < self.response_len:  # right pad
+            diff = self.response_len - tensor.shape[0]
+            tensor = F.pad(tensor, (0, diff), value=self.pad_id)
+        return tensor
+
+
+@dataclass
+class Group:
+    group_id: str
+    episodes: list[Episode]
+
+    @classmethod
+    def new_group(
+        cls,
+        group_id: int,
+        group_size: int,
+        request: str,
+        policy_version: int,
+        pad_id: int,
+        request_len: int,
+        response_len: int,
+        target: Any = None,
+    ):
+        episodes = []
+        for _ in range(group_size):
+            episodes.append(
+                Episode(
+                    episode_id=str(uuid.uuid4()),
+                    request=request,
+                    policy_version=policy_version,
+                    pad_id=pad_id,
+                    request_len=request_len,
+                    response_len=response_len,
+                    target=target,
+                )
+            )
+        return cls(str(group_id), episodes)
+
+
+def collate(batches: list[list[Episode]]):
+    inputs = []
+    targets = []
+    for batch in batches:
+        request = [e.request_tensor for e in batch]
+        request = torch.stack(request)  # [b x s]
+
+        response = [e.response_tensor for e in batch]
+        response = torch.stack(response)  # [b x s]
+
+        # mock out the ref logprobs for now
+        ref_logprobs = torch.zeros(len(batch), batch[0].response_len)  # [b x s]
+
+        advantages = [e.advantage for e in batch]
+        advantages = torch.tensor(advantages).unsqueeze(-1)  # [b x 1]
+
+        pad_id = batch[0].pad_id
+        mask = response != pad_id
+
+        input = {"tokens": torch.cat([request, response], dim=1)}
{"tokens": torch.cat([request, response], dim=1)} + target = { + "response": response, + "ref_logprobs": ref_logprobs, + "advantages": advantages, + "padding_mask": mask, + } + inputs.append(input) + targets.append(target) + return inputs, targets + + +def compute_logprobs( + logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0 +) -> torch.Tensor: + context_length = logits.shape[1] - input_ids.shape[1] + logits = logits[:, context_length - 1 : -1] + logprobs = torch.log_softmax(logits / temperature, dim=-1).to(input_ids.device) + logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1) + return logprobs + + +def simple_grpo_loss( + logits: torch.Tensor, + response: torch.Tensor, + ref_logprobs: torch.Tensor, + advantages: torch.Tensor, + padding_mask: torch.Tensor, + beta: float = 0.1, +) -> torch.Tensor: + logprobs = compute_logprobs(logits, response) + # kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1 + per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages + per_token_loss = -(per_token_policy_loss) + loss = ( + ((per_token_loss * padding_mask).sum(dim=1)) + / (padding_mask.sum(dim=1).clamp(min=1.0)) + ).mean() + return loss + + +@dataclass +class RewardActor(ForgeActor): + """Reward actor that uses a list of scoring functions.""" + + reward_functions: list[Callable] + + @endpoint + async def evaluate_response(self, prompt: str, response: str, target: str) -> float: + total_rewards = 0.0 + for reward_fn in self.reward_functions: + reward = reward_fn(prompt, response, target) + total_rewards += reward + return total_rewards / len(self.reward_functions) + + +class ComputeAdvantages(ForgeActor): + """Compute advantages for GRPO using reward signals.""" + + @endpoint + async def compute(self, group: Group) -> list[float]: + # TODO: add batch processing + rewards = torch.tensor([[e.reward for e in group.episodes]]) + mean = rewards.mean(1, keepdim=True) + std = rewards.std(1, keepdim=True) + advantages = (rewards - mean) / (std + 1e-4) + return advantages.squeeze(0).tolist() + + +@dataclass +class DatasetActor(ForgeActor): + """Actor wrapper for HuggingFace dataset to provide async interface.""" + + path: str = "openai/gsm8k" + revision: str = "main" + data_split: str = "train" + streaming: bool = True + model: str = "Qwen/Qwen3-1.7B" + + @endpoint + def setup(self): + self._tokenizer = get_tokenizer(self.model) + + def gsm8k_transform(sample): + system_prompt = """ + Put all your scratchpad work between and tags. + Your final answer should be between and tags otherwise it will not be scored. 
+ """ + request: str = sample["question"] + as_chat = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": request}, + ] + formatted_request = self._tokenizer.apply_chat_template( + as_chat, + tokenize=False, + add_generation_prompt=True, + ) + target: str = sample["answer"] + formatted_target = target.split("#### ")[1] + return {"request": formatted_request, "target": formatted_target} + + ds = load_dataset( + self.path, self.revision, split=self.data_split, streaming=self.streaming + ) + ds = ds.map(gsm8k_transform) + ds = ds.shuffle() + self._iterator = iter(ds) + + @endpoint + async def sample(self) -> dict[str, str] | None: + try: + return next(self._iterator) + except StopIteration: + return None + + @endpoint + async def pad_token(self): + return self._tokenizer.pad_token_id + + +async def main(cfg: DictConfig): + """Main GRPO training loop with rollout and training processes.""" + group_size = cfg.group_size + max_req_tokens = cfg.max_req_tokens + max_res_tokens = cfg.max_res_tokens + mlogger = get_metric_logger( + "wandb", + freq=1, + project="grpo-training", + ) + + # ---- Setup services ---- # + await ts.initialize(strategy=ts.ControllerStorageVolumes()) + ( + dataloader, + policy, + trainer, + replay_buffer, + compute_advantages, + # ref_model, + reward_actor, + ) = await asyncio.gather( + DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset), + Policy.options(**cfg.services.policy).as_service(**cfg.policy), + RLTrainer.options(**cfg.services.trainer).as_service( + **cfg.trainer, loss=simple_grpo_loss + ), + ReplayBuffer.options(**cfg.services.replay_buffer).as_service( + **cfg.replay_buffer, collate=collate + ), + ComputeAdvantages.options(**cfg.services.compute_advantages).as_service(), + # ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model), + RewardActor.options(**cfg.services.reward_actor).as_service( + reward_functions=[MathReward(), ThinkingReward()] + ), + ) + print("All services initialized successfully!") + + # ---- Core RL loops ---- # + async def continuous_rollouts(): + rollout_count = 0 + pad_id = await dataloader.pad_token.choose() + while True: + sample = await dataloader.sample.choose() + if sample is None: + print("Dataloader is empty, exiting continuous rollout") + return + prompt, target = sample["request"], sample["target"] + responses = await policy.generate.choose(prompt) + # TODO: this shall be part of the responses metadata instead of a separate call + version = await policy.get_version.choose() + group = Group.new_group( + group_id=rollout_count, + group_size=group_size, + request=prompt, + policy_version=version, + pad_id=pad_id, + request_len=max_req_tokens, + response_len=max_res_tokens, + target=target, + ) + + input_ids = torch.ones( + (group_size, max_req_tokens + max_req_tokens), + dtype=torch.long, + device="cuda", + ) + # Populate episode info and calculate rewards + for i, (episode, response) in enumerate(zip(group.episodes, responses)): + episode.request_tokens = response.prompt_ids + episode.response_tokens = response.token_ids + episode.response = response.text + input_ids[i, :max_req_tokens] = episode.request_tensor + input_ids[i, max_req_tokens:] = episode.response_tensor + episode.reward = await reward_actor.evaluate_response.choose( + prompt=prompt, response=response.text, target=target + ) + + # Calculate reference logprobs + # ref_logits = await ref_model.forward.choose(input_ids) + # ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:]) + # for 
+            #     episode.ref_logprobs = ref_logprobs[i]
+            # del ref_logits, ref_logprobs, input_ids
+
+            # Calculate advantages and add to replay buffer
+            advantages = await compute_advantages.compute.choose(group)
+            for episode, advantage in zip(group.episodes, advantages):
+                episode.advantage = advantage
+                await replay_buffer.add.choose(episode)
+
+            # Log metrics
+            avg_response_len = (
+                sum(len(e.response_tokens) for e in group.episodes) / group_size
+            )
+            mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count)
+            buffer_size = await replay_buffer._numel.choose()
+            mlogger.log("buffer_size/rollout", buffer_size, rollout_count)
+            avg_reward = sum(e.reward for e in group.episodes) / group_size
+            mlogger.log("avg_reward/rollout", avg_reward, rollout_count)
+
+            rollout_count += 1
+
+    async def continuous_training():
+        training_step = 0
+        while True:
+            batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
+            if batch is None:
+                await asyncio.sleep(0.1)
+            else:
+                inputs, targets = batch
+                loss = await trainer.train_step.choose(inputs, targets)
+                training_step += 1
+                mlogger.log("loss/training_step", loss, training_step)
+                await trainer.push_weights.call(training_step)
+                await policy.update_weights.call(training_step)
+
+    print("Starting GRPO training loops...")
+    # TODO: Start multiple rollouts once all services support it
+    rollout_task = asyncio.create_task(continuous_rollouts())
+    training_task = asyncio.create_task(continuous_training())
+
+    try:
+        await asyncio.gather(rollout_task, training_task)
+    except KeyboardInterrupt:
+        print("Training interrupted by user")
+        rollout_task.cancel()
+        training_task.cancel()
+    finally:
+        print("Shutting down...")
+        await asyncio.gather(
+            dataloader.shutdown(),
+            policy.shutdown(),
+            trainer.shutdown(),
+            replay_buffer.shutdown(),
+            compute_advantages.shutdown(),
+            # ref_model.shutdown(),
+            reward_actor.shutdown(),
+        )
+        # TODO - add a global shutdown that implicitly shuts down all services
+        # and remote allocations
+        await shutdown()
+
+
+if __name__ == "__main__":
+
+    @parse
+    def _main(cfg):
+        asyncio.run(main(cfg))
+
+    _main()  # @parse grabs the cfg from CLI
diff --git a/apps/grpo/qwen3_30b_moe.yaml b/apps/grpo/qwen3_30b_moe.yaml
new file mode 100644
index 000000000..413a8b1dc
--- /dev/null
+++ b/apps/grpo/qwen3_30b_moe.yaml
@@ -0,0 +1,107 @@
+# Grouped Relative Policy Optimization (GRPO)
+# >>> python -m apps.grpo.main --config apps/grpo/qwen3_30b_moe.yaml
+
+# Global configuration
+group_size: 2
+batch_size: 4
+max_req_tokens: 512
+max_res_tokens: 512
+model: "Qwen/Qwen3-30B-A3B"
+off_by_n: 1 # Off by one by default
+
+# Dataset configuration
+dataset:
+  path: "openai/gsm8k"
+  revision: "main"
+  data_split: "train"
+  streaming: true
+  model: ${model}
+
+# Policy configuration
+policy:
+  engine_config:
+    model: ${model}
+    tensor_parallel_size: 8
+    pipeline_parallel_size: 1
+    enable_expert_parallel: true
+    enforce_eager: false
+  sampling_config:
+    n: ${group_size}
+    max_tokens: ${max_res_tokens}
+    temperature: 1.0
+    top_p: 1.0
+
+# Trainer configuration
+trainer:
+  model:
+    name: qwen3
+    # TODO: check titan trainer
+    flavor: 30B-A3B
+    hf_assets_path: hf://${model}
+  optimizer:
+    name: AdamW
+    lr: 1e-5
+    eps: 1e-8
+  lr_scheduler:
+    warmup_steps: 1
+  training:
+    local_batch_size: ${batch_size}
+    seq_len: 2048
+    max_norm: 1.0
+    steps: 1000000
+    dtype: bfloat16
+  compile:
+    enable: false
+  parallelism:
+    data_parallel_replicate_degree: 1
+    data_parallel_shard_degree: -1
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    context_parallel_degree: 1
+    expert_parallel_degree: 8
+    disable_loss_parallel: true
+  checkpoint:
+    enable: true
+    initial_load_path: hf://${model}
+    initial_load_in_hf: true
+    last_save_in_hf: true
+    interval: 500
+    async_mode: "disabled"
+  activation_checkpoint:
+    mode: selective
+    selective_ac_option: op
+
+# Replay buffer configuration
+replay_buffer:
+  batch_size: ${batch_size}
+  max_policy_age: ${off_by_n}
+  # TODO: check if we need to change this
+  dp_size: 8
+
+# All resource allocations
+services:
+  dataset:
+    procs: 1
+    num_replicas: 1
+    with_gpus: false
+  policy:
+    procs: ${policy.engine_config.tensor_parallel_size}
+    num_replicas: 1
+    with_gpus: true
+  trainer:
+    procs: 8
+    hosts: 1
+    num_replicas: 1
+    with_gpus: true
+  replay_buffer:
+    procs: 1
+    num_replicas: 1
+    with_gpus: false
+  compute_advantages:
+    procs: 1
+    num_replicas: 1
+    with_gpus: false
+  reward_actor:
+    procs: 1
+    num_replicas: 1
+    with_gpus: false
diff --git a/apps/vllm/main.py b/apps/vllm/main.py
index 6ba1bbbaf..e4d482f60 100644
--- a/apps/vllm/main.py
+++ b/apps/vllm/main.py
@@ -36,7 +36,7 @@ async def run(cfg: DictConfig):
     import time
 
     print("Requesting generation...")
-    n = 100
+    n = 5
     start = time.time()
     response_outputs: list[Completion] = await asyncio.gather(
         *[policy.generate.choose(prompt=prompt) for _ in range(n)]
diff --git a/apps/vllm/qwen2_5_32b.yaml b/apps/vllm/qwen2_5_32b.yaml
index a7f799bce..8cc9efd79 100644
--- a/apps/vllm/qwen2_5_32b.yaml
+++ b/apps/vllm/qwen2_5_32b.yaml
@@ -12,7 +12,7 @@ policy:
 services:
   policy:
     procs: 4
-    hosts: 1
+    # hosts: 1
     num_replicas: 1
     with_gpus: true
 
diff --git a/apps/vllm/qwen3_30b_moe.yaml b/apps/vllm/qwen3_30b_moe.yaml
new file mode 100644
index 000000000..3c03aa3ac
--- /dev/null
+++ b/apps/vllm/qwen3_30b_moe.yaml
@@ -0,0 +1,21 @@
+policy:
+  engine_config:
+    model: "Qwen/Qwen3-30B-A3B"
+    tensor_parallel_size: 8
+    pipeline_parallel_size: 1
+    enable_expert_parallel: true
+    enforce_eager: true
+  sampling_config:
+    n: 2
+    guided_decoding: false
+    max_tokens: 512
+
+services:
+  policy:
+    procs: ${policy.engine_config.tensor_parallel_size}
+    num_replicas: 1
+    with_gpus: true
+
+
+# Optional, otherwise argparse fallback kicks in
+prompt: "Tell me a joke"
diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py
index 3cc1e6a48..fb8466c28 100644
--- a/src/forge/controller/actor.py
+++ b/src/forge/controller/actor.py
@@ -206,6 +206,7 @@ async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T:
         """
         logger.info("Spawning single actor %s", cls.__name__)
         actor = await cls.launch(*args, **actor_kwargs)
+        logger.info("Successfully spawned single actor %s", cls.__name__)
         return actor
 
     @classmethod
diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index 26d51ea5c..7691fd444 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -192,6 +192,10 @@ def bootstrap(gpu_ids: list[str]):
             os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
             os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"
 
+            os.environ["VLLM_LOG_LEVEL"] = "DEBUG"
+            os.environ["NCCL_DEBUG"] = "INFO"
+            os.environ["NCCL_DEBUG_SUBSYS"] = "INIT"
+
         gpu_ids = gpu_manager.get_gpus(num_procs)
         procs = host_mesh.spawn_procs(
             per_host={"gpus": num_procs},