diff --git a/.gitignore b/.gitignore
index 14e5f66e1..9d1759aef 100644
--- a/.gitignore
+++ b/.gitignore
@@ -41,6 +41,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
+.rsyncignore
# Django stuff
*.log
diff --git a/apps/grpo/main_no_reward.py b/apps/grpo/main_no_reward.py
new file mode 100644
index 000000000..93b69e03a
--- /dev/null
+++ b/apps/grpo/main_no_reward.py
@@ -0,0 +1,389 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Usage: python -m apps.grpo.main_no_reward --config apps/grpo/qwen3_1_7b.yaml
+
+import asyncio
+import uuid
+from dataclasses import dataclass
+from typing import Any, Callable
+
+import torch
+import torch.nn.functional as F
+import torchstore as ts
+from datasets import load_dataset
+from forge.actors.policy import Policy
+from forge.actors.replay_buffer import ReplayBuffer
+from forge.actors.trainer import RLTrainer
+from forge.cli.config import parse
+from forge.controller.actor import ForgeActor
+from forge.controller.provisioner import shutdown
+from forge.data.rewards import MathReward, ThinkingReward
+from forge.util.metric_logging import get_metric_logger
+from monarch.actor import endpoint
+from omegaconf import DictConfig
+from vllm.transformers_utils.tokenizer import get_tokenizer
+
+
+@dataclass
+class Episode:
+    # TODO: add additional layer for multi-turn
+ episode_id: str
+ request: str
+ policy_version: int
+ pad_id: int
+ request_len: int
+ response_len: int
+ target: Any | None = None
+ # processed data
+ response: str | None = None
+ request_tokens: list[int] | None = None
+ response_tokens: list[int] | None = None
+ ref_logprobs: torch.Tensor | None = None
+ reward: float | None = None
+ advantage: float | None = None
+
+ @property
+ def request_tensor(self):
+ tensor = torch.tensor(self.request_tokens, dtype=torch.long)
+ if tensor.shape[0] < self.request_len: # left pad
+ diff = self.request_len - tensor.shape[0]
+ tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
+ return tensor
+
+ @property
+ def response_tensor(self):
+ tensor = torch.tensor(self.response_tokens, dtype=torch.long)
+ if tensor.shape[0] < self.response_len: # right pad
+ diff = self.response_len - tensor.shape[0]
+ tensor = F.pad(tensor, (0, diff), value=self.pad_id)
+ return tensor
+
+
+@dataclass
+class Group:
+ group_id: str
+ episodes: list[Episode]
+
+ @classmethod
+ def new_group(
+ cls,
+ group_id: int,
+ group_size: int,
+ request: str,
+ policy_version: int,
+ pad_id: int,
+ request_len: int,
+ response_len: int,
+ target: Any = None,
+ ):
+ episodes = []
+ for _ in range(group_size):
+ episodes.append(
+ Episode(
+ episode_id=str(uuid.uuid4()),
+ request=request,
+ policy_version=policy_version,
+ pad_id=pad_id,
+ request_len=request_len,
+ response_len=response_len,
+ target=target,
+ )
+ )
+ return cls(str(group_id), episodes)
+
+
+def collate(batches: list[list[Episode]]):
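+    # Pack each sampled batch of episodes into padded token tensors plus the
+    # targets (response, ref logprobs, advantages, padding mask) used by the loss.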
+ inputs = []
+ targets = []
+ for batch in batches:
+ request = [e.request_tensor for e in batch]
+ request = torch.stack(request) # [b x s]
+
+ response = [e.response_tensor for e in batch]
+ response = torch.stack(response) # [b x s]
+
+ # mock out the ref logprobs for now
+ ref_logprobs = torch.zeros(len(batch), batch[0].response_len) # [b x s]
+
+ advantages = [e.advantage for e in batch]
+ advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]
+
+ pad_id = batch[0].pad_id
+ mask = response != pad_id
+
+ input = {"tokens": torch.cat([request, response], dim=1)}
+ target = {
+ "response": response,
+ "ref_logprobs": ref_logprobs,
+ "advantages": advantages,
+ "padding_mask": mask,
+ }
+ inputs.append(input)
+ targets.append(target)
+ return inputs, targets
+
+
+def compute_logprobs(
+ logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
+) -> torch.Tensor:
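+    # Logits at position t predict token t + 1, so drop the prompt region and
+    # shift by one before gathering per-token log-probs for the response.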
+ context_length = logits.shape[1] - input_ids.shape[1]
+ logits = logits[:, context_length - 1 : -1]
+ logprobs = torch.log_softmax(logits / temperature, dim=-1).to(input_ids.device)
+ logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
+ return logprobs
+
+
+def simple_grpo_loss(
+ logits: torch.Tensor,
+ response: torch.Tensor,
+ ref_logprobs: torch.Tensor,
+ advantages: torch.Tensor,
+ padding_mask: torch.Tensor,
+ beta: float = 0.1,
+) -> torch.Tensor:
+ logprobs = compute_logprobs(logits, response)
+ # kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
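+    # With the reference model disabled, this is a REINFORCE-style objective:
+    # exp(logprobs - logprobs.detach()) evaluates to 1 in the forward pass, but
+    # its gradient is d(logprobs), so each token's gradient is scaled by its advantage.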
+ per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
+ per_token_loss = -(per_token_policy_loss)
+ loss = (
+ ((per_token_loss * padding_mask).sum(dim=1))
+ / (padding_mask.sum(dim=1).clamp(min=1.0))
+ ).mean()
+ return loss
+
+
+@dataclass
+class RewardActor(ForgeActor):
+ """Reward actor that uses a list of scoring functions."""
+
+ reward_functions: list[Callable]
+
+ @endpoint
+ async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
+ total_rewards = 0.0
+ for reward_fn in self.reward_functions:
+ reward = reward_fn(prompt, response, target)
+ total_rewards += reward
+ return total_rewards / len(self.reward_functions)
+
+
+class ComputeAdvantages(ForgeActor):
+ """Compute advantages for GRPO using reward signals."""
+
+ @endpoint
+ async def compute(self, group: Group) -> list[float]:
+ # TODO: add batch processing
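+        # Group-relative advantage: z-score each episode's reward against its
+        # group's mean/std; the epsilon guards against zero variance.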
+ rewards = torch.tensor([[e.reward for e in group.episodes]])
+ mean = rewards.mean(1, keepdim=True)
+ std = rewards.std(1, keepdim=True)
+ advantages = (rewards - mean) / (std + 1e-4)
+ return advantages.squeeze(0).tolist()
+
+
+@dataclass
+class DatasetActor(ForgeActor):
+ """Actor wrapper for HuggingFace dataset to provide async interface."""
+
+ path: str = "openai/gsm8k"
+ revision: str = "main"
+ data_split: str = "train"
+ streaming: bool = True
+ model: str = "Qwen/Qwen3-1.7B"
+
+ @endpoint
+ def setup(self):
+ self._tokenizer = get_tokenizer(self.model)
+
+ def gsm8k_transform(sample):
+ system_prompt = """
+            Put all your scratchpad work between <think> and </think> tags.
+            Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
+ """
+ request: str = sample["question"]
+ as_chat = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": request},
+ ]
+ formatted_request = self._tokenizer.apply_chat_template(
+ as_chat,
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ target: str = sample["answer"]
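+            # GSM8K answers end with "#### <final answer>"; keep only that part.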
+ formatted_target = target.split("#### ")[1]
+ return {"request": formatted_request, "target": formatted_target}
+
+ ds = load_dataset(
+ self.path, self.revision, split=self.data_split, streaming=self.streaming
+ )
+ ds = ds.map(gsm8k_transform)
+ ds = ds.shuffle()
+ self._iterator = iter(ds)
+
+ @endpoint
+ async def sample(self) -> dict[str, str] | None:
+ try:
+ return next(self._iterator)
+ except StopIteration:
+ return None
+
+ @endpoint
+ async def pad_token(self):
+ return self._tokenizer.pad_token_id
+
+
+async def main(cfg: DictConfig):
+ """Main GRPO training loop with rollout and training processes."""
+ group_size = cfg.group_size
+ max_req_tokens = cfg.max_req_tokens
+ max_res_tokens = cfg.max_res_tokens
+ mlogger = get_metric_logger(
+ "wandb",
+ freq=1,
+ project="grpo-training",
+ )
+
+ # ---- Setup services ---- #
+ await ts.initialize(strategy=ts.ControllerStorageVolumes())
+ (
+ dataloader,
+ policy,
+ trainer,
+ replay_buffer,
+ compute_advantages,
+ # ref_model,
+ reward_actor,
+ ) = await asyncio.gather(
+ DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset),
+ Policy.options(**cfg.services.policy).as_service(**cfg.policy),
+ RLTrainer.options(**cfg.services.trainer).as_service(
+ **cfg.trainer, loss=simple_grpo_loss
+ ),
+ ReplayBuffer.options(**cfg.services.replay_buffer).as_service(
+ **cfg.replay_buffer, collate=collate
+ ),
+ ComputeAdvantages.options(**cfg.services.compute_advantages).as_service(),
+ # ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
+ RewardActor.options(**cfg.services.reward_actor).as_service(
+ reward_functions=[MathReward(), ThinkingReward()]
+ ),
+ )
+ print("All services initialized successfully!")
+
+ # ---- Core RL loops ---- #
+ async def continuous_rollouts():
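+        # Generate a group of responses per prompt, score them, compute
+        # group-relative advantages, and push finished episodes to the replay buffer.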
+ rollout_count = 0
+ pad_id = await dataloader.pad_token.choose()
+ while True:
+ sample = await dataloader.sample.choose()
+ if sample is None:
+ print("Dataloader is empty, exiting continuous rollout")
+ return
+ prompt, target = sample["request"], sample["target"]
+ responses = await policy.generate.choose(prompt)
+            # TODO: this should be part of the response metadata instead of a separate call
+ version = await policy.get_version.choose()
+ group = Group.new_group(
+ group_id=rollout_count,
+ group_size=group_size,
+ request=prompt,
+ policy_version=version,
+ pad_id=pad_id,
+ request_len=max_req_tokens,
+ response_len=max_res_tokens,
+ target=target,
+ )
+
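+            # Kept for the (currently disabled) reference-model forward pass below.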
+ input_ids = torch.ones(
+                (group_size, max_req_tokens + max_res_tokens),
+ dtype=torch.long,
+ device="cuda",
+ )
+ # Populate episode info and calculate rewards
+ for i, (episode, response) in enumerate(zip(group.episodes, responses)):
+ episode.request_tokens = response.prompt_ids
+ episode.response_tokens = response.token_ids
+ episode.response = response.text
+ input_ids[i, :max_req_tokens] = episode.request_tensor
+ input_ids[i, max_req_tokens:] = episode.response_tensor
+ episode.reward = await reward_actor.evaluate_response.choose(
+ prompt=prompt, response=response.text, target=target
+ )
+
+ # Calculate reference logprobs
+ # ref_logits = await ref_model.forward.choose(input_ids)
+ # ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
+ # for i, episode in enumerate(group.episodes):
+ # episode.ref_logprobs = ref_logprobs[i]
+ # del ref_logits, ref_logprobs, input_ids
+
+ # Calculate advantages and add to replay buffer
+ advantages = await compute_advantages.compute.choose(group)
+ for episode, advantage in zip(group.episodes, advantages):
+ episode.advantage = advantage
+ await replay_buffer.add.choose(episode)
+
+ # Log metrics
+ avg_response_len = (
+ sum(len(e.response_tokens) for e in group.episodes) / group_size
+ )
+ mlogger.log("avg_response_len/rollout", avg_response_len, rollout_count)
+ buffer_size = await replay_buffer._numel.choose()
+ mlogger.log("buffer_size/rollout", buffer_size, rollout_count)
+ avg_reward = sum(e.reward for e in group.episodes) / group_size
+ mlogger.log("avg_reward/rollout", avg_reward, rollout_count)
+
+ rollout_count += 1
+
+ async def continuous_training():
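+        # Pull batches from the replay buffer as they become ready and push
+        # updated weights to the policy after every training step.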
+ training_step = 0
+ while True:
+ batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
+ if batch is None:
+ await asyncio.sleep(0.1)
+ else:
+ inputs, targets = batch
+ loss = await trainer.train_step.choose(inputs, targets)
+ training_step += 1
+ mlogger.log("loss/training_step", loss, training_step)
+ await trainer.push_weights.call(training_step)
+ await policy.update_weights.call(training_step)
+
+ print("Starting GRPO training loops...")
+    # TODO: Start multiple rollouts once all services support it
+ rollout_task = asyncio.create_task(continuous_rollouts())
+ training_task = asyncio.create_task(continuous_training())
+
+ try:
+ await asyncio.gather(rollout_task, training_task)
+ except KeyboardInterrupt:
+ print("Training interrupted by user")
+ rollout_task.cancel()
+ training_task.cancel()
+ finally:
+ print("Shutting down...")
+ await asyncio.gather(
+ dataloader.shutdown(),
+ policy.shutdown(),
+ trainer.shutdown(),
+ replay_buffer.shutdown(),
+ compute_advantages.shutdown(),
+ # ref_model.shutdown(),
+ reward_actor.shutdown(),
+ )
+ # TODO - add a global shutdown that implicitly shuts down all services
+ # and remote allocations
+ await shutdown()
+
+
+if __name__ == "__main__":
+
+ @parse
+ def _main(cfg):
+ asyncio.run(main(cfg))
+
+ _main() # @parse grabs the cfg from CLI
diff --git a/apps/grpo/qwen3_30b_moe.yaml b/apps/grpo/qwen3_30b_moe.yaml
new file mode 100644
index 000000000..413a8b1dc
--- /dev/null
+++ b/apps/grpo/qwen3_30b_moe.yaml
@@ -0,0 +1,107 @@
+# Group Relative Policy Optimization (GRPO)
+# >>> python -m apps.grpo.main --config apps/grpo/qwen3_30b_moe.yaml
+
+# Global configuration
+group_size: 2
+batch_size: 4
+max_req_tokens: 512
+max_res_tokens: 512
+model: "Qwen/Qwen3-30B-A3B"
+off_by_n: 1 # Off by one by default
+
+# Dataset configuration
+dataset:
+ path: "openai/gsm8k"
+ revision: "main"
+ data_split: "train"
+ streaming: true
+ model: ${model}
+
+# Policy configuration
+policy:
+ engine_config:
+ model: ${model}
+ tensor_parallel_size: 8
+ pipeline_parallel_size: 1
+ enable_expert_parallel: true
+ enforce_eager: false
+ sampling_config:
+ n: ${group_size}
+ max_tokens: ${max_res_tokens}
+ temperature: 1.0
+ top_p: 1.0
+
+# Trainer configuration
+trainer:
+ model:
+ name: qwen3
+ # TODO: check titan trainer
+ flavor: 30B-A3B
+ hf_assets_path: hf://${model}
+ optimizer:
+ name: AdamW
+ lr: 1e-5
+ eps: 1e-8
+ lr_scheduler:
+ warmup_steps: 1
+ training:
+ local_batch_size: ${batch_size}
+ seq_len: 2048
+ max_norm: 1.0
+ steps: 1000000
+ dtype: bfloat16
+ compile:
+ enable: false
+ parallelism:
+ data_parallel_replicate_degree: 1
+ data_parallel_shard_degree: -1
+ tensor_parallel_degree: 1
+ pipeline_parallel_degree: 1
+ context_parallel_degree: 1
+ expert_parallel_degree: 8
+ disable_loss_parallel: true
+ checkpoint:
+ enable: true
+ initial_load_path: hf://${model}
+ initial_load_in_hf: true
+ last_save_in_hf: true
+ interval: 500
+ async_mode: "disabled"
+ activation_checkpoint:
+ mode: selective
+ selective_ac_option: op
+
+# Replay buffer configuration
+replay_buffer:
+ batch_size: ${batch_size}
+ max_policy_age: ${off_by_n}
+ # TODO: check if we need to change this
+ dp_size: 8
+
+# All resource allocations
+services:
+ dataset:
+ procs: 1
+ num_replicas: 1
+ with_gpus: false
+ policy:
+ procs: ${policy.engine_config.tensor_parallel_size}
+ num_replicas: 1
+ with_gpus: true
+ trainer:
+ procs: 8
+ hosts: 1
+ num_replicas: 1
+ with_gpus: true
+ replay_buffer:
+ procs: 1
+ num_replicas: 1
+ with_gpus: false
+ compute_advantages:
+ procs: 1
+ num_replicas: 1
+ with_gpus: false
+ reward_actor:
+ procs: 1
+ num_replicas: 1
+ with_gpus: false
diff --git a/apps/vllm/main.py b/apps/vllm/main.py
index 6ba1bbbaf..e4d482f60 100644
--- a/apps/vllm/main.py
+++ b/apps/vllm/main.py
@@ -36,7 +36,7 @@ async def run(cfg: DictConfig):
import time
print("Requesting generation...")
- n = 100
+ n = 5
start = time.time()
response_outputs: list[Completion] = await asyncio.gather(
*[policy.generate.choose(prompt=prompt) for _ in range(n)]
diff --git a/apps/vllm/qwen2_5_32b.yaml b/apps/vllm/qwen2_5_32b.yaml
index a7f799bce..8cc9efd79 100644
--- a/apps/vllm/qwen2_5_32b.yaml
+++ b/apps/vllm/qwen2_5_32b.yaml
@@ -12,7 +12,7 @@ policy:
services:
policy:
procs: 4
- hosts: 1
+ # hosts: 1
num_replicas: 1
with_gpus: true
diff --git a/apps/vllm/qwen3_30b_moe.yaml b/apps/vllm/qwen3_30b_moe.yaml
new file mode 100644
index 000000000..3c03aa3ac
--- /dev/null
+++ b/apps/vllm/qwen3_30b_moe.yaml
@@ -0,0 +1,21 @@
+policy:
+ engine_config:
+ model: "Qwen/Qwen3-30B-A3B"
+ tensor_parallel_size: 8
+ pipeline_parallel_size: 1
+ enable_expert_parallel: true
+ enforce_eager: true
+ sampling_config:
+ n: 2
+ guided_decoding: false
+ max_tokens: 512
+
+services:
+ policy:
+ procs: ${policy.engine_config.tensor_parallel_size}
+ num_replicas: 1
+ with_gpus: true
+
+
+# Optional; if omitted, the argparse fallback is used
+prompt: "Tell me a joke"
diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py
index 3cc1e6a48..fb8466c28 100644
--- a/src/forge/controller/actor.py
+++ b/src/forge/controller/actor.py
@@ -206,6 +206,7 @@ async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T:
"""
logger.info("Spawning single actor %s", cls.__name__)
actor = await cls.launch(*args, **actor_kwargs)
+ logger.info("Successfully spawned single actor %s", cls.__name__)
return actor
@classmethod
diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py
index 26d51ea5c..7691fd444 100644
--- a/src/forge/controller/provisioner.py
+++ b/src/forge/controller/provisioner.py
@@ -192,6 +192,10 @@ def bootstrap(gpu_ids: list[str]):
os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"
+ os.environ["VLLM_LOG_LEVEL"] = "DEBUG"
+ os.environ["NCCL_DEBUG"] = "INFO"
+ os.environ["NCCL_DEBUG_SUBSYS"] = "INIT"
+
gpu_ids = gpu_manager.get_gpus(num_procs)
procs = host_mesh.spawn_procs(
per_host={"gpus": num_procs},