Commit d8d775a

Merge branch 'meta-pytorch:main' into main
2 parents f79beee + d4011ea

38 files changed: +3303 -546 lines

.github/workflows/unit_test.yaml

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ jobs:
       - name: Install pytorch
         run: python -m pip install torch==2.9.0.dev20250826 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
       - name: Install monarch
-        run: python -m pip install monarch-no-torch==0.1.0.dev20250826 --find-links assets/wheels
+        run: python -m pip install monarch-no-torch==0.1.0.dev20250826 --find-links assets/ci
       - name: Install dependencies
         run: python -m pip install --no-build-isolation -e ".[dev]"
       - name: Run unit tests with coverage

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -187,3 +187,5 @@ cover/
 
 # wandb
 wandb/
+
+assets/wheels/vllm*.whl

README.md

Lines changed: 26 additions & 8 deletions
@@ -6,20 +6,38 @@
 
 ## Installation
 
-### Basic (Broken)
+### Basic
+
+Forge requires the latest PyTorch nightly with Monarch, vLLM, and torchtitan. For convenience,
+we have pre-packaged these dependencies as wheels in assets/wheels. (Note that the basic install script
+uses [DNF](https://docs.fedoraproject.org/en-US/quick-docs/dnf/), but can easily be extended to other Linux distributions.)
+
+Forge requires the GitHub CLI (gh) to download a compatible vLLM package. See [here](https://github.com/cli/cli#installation) for gh install instructions before continuing.
 
 ```bash
-pip install uv
-git clone https://github.com/pytorch-labs/forge
-cd forge
-uv sync
+conda create -n forge python=3.10
+conda activate forge
+./scripts/install.sh
+```
 
-# Or for dev install:
-uv sync --all-extras
+After installing, run the following command; you should see output confirming that GRPO training is running.
+```
+python -m apps.grpo.main
+```
+
+If you need to rebuild the wheels for any reason, you can do so with:
+```bash
+conda create -n forge python=3.10
+conda activate forge
+./scripts/build_wheels.sh
 ```
 
+Since the vLLM wheel is too large for GitHub, we uploaded it as a release:
+```
+$ gh release create v0.0.0 assets/wheels/vllm-*.whl --title "Forge Wheels v0.0.0"
+```
 
-### Internal Machine
+### Meta Internal Build
 
 1. Build uv package
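The Basic install above leans on two external tools: DNF for system packages and the GitHub CLI for fetching the vLLM wheel release. As a minimal sketch of the reader-side flow (assuming a Fedora-like, DNF-based system; the v0.0.0 tag and wheel pattern come from the README above, and the download step is an illustrative guess at how the published wheel gets fetched, not a command from this commit):

```bash
# Install and authenticate the GitHub CLI (gh ships in the Fedora repos).
sudo dnf install gh
gh auth login

# Hypothetical counterpart to the `gh release create` command shown above:
# pull the published vLLM wheel back into assets/wheels before installing.
gh release download v0.0.0 --pattern "vllm-*.whl" --dir assets/wheels
```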

apps/grpo/main.py

Lines changed: 85 additions & 140 deletions
@@ -5,51 +5,26 @@
 # LICENSE file in the root directory of this source tree.
 
 import asyncio
+import logging
 import time
 from dataclasses import dataclass
 from typing import Callable
 
 import torch
 from datasets import load_dataset
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.reference_actor import compute_sequence_logprobs, TitanRefModel
 from forge.actors.replay_buffer import ReplayBuffer
-from forge.controller import ServiceConfig, spawn_service
 from forge.controller.actor import ForgeActor
+from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
+from torchtitan.config.job_config import Model as TitanJobModelConfig
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-
-def compute_sequence_logprobs(
-    model: torch.nn.Module,
-    input_ids: torch.Tensor,
-    attention_mask: torch.Tensor,
-    requires_grad: bool = True,
-) -> torch.Tensor:
-    context_manager = torch.enable_grad() if requires_grad else torch.no_grad()
-
-    with context_manager:
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-        logits = outputs.logits
-
-        # Apply log softmax to get log probabilities
-        log_probs = torch.log_softmax(logits, dim=-1)
-
-        # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
-        shifted_input_ids = input_ids[:, 1:]  # Remove first token
-        shifted_log_probs = log_probs[:, :-1, :]  # Remove last logit
-
-        # Gather log probabilities for actual tokens
-        token_log_probs = torch.gather(
-            shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
-        ).squeeze(-1)
-
-        # Sum log probabilities across sequence (masked by attention)
-        shifted_attention_mask = attention_mask[:, 1:]
-        sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)
-
-    return sequence_log_probs
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.DEBUG)
 
 
 @dataclass
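With this change, `compute_sequence_logprobs` is no longer defined in this file; it is imported from `forge.actors.reference_actor` instead (see the new import above). As a toy illustration of the shift-and-gather step the helper performs (dummy tensors only, not Forge code), per-sequence log probabilities of the realized tokens can be computed like so:

```python
import torch

# Dummy shapes: batch of 1, sequence of 5 tokens, vocabulary of 32.
logits = torch.randn(1, 5, 32)            # model output [batch, seq, vocab]
input_ids = torch.randint(0, 32, (1, 5))  # realized tokens [batch, seq]

log_probs = torch.log_softmax(logits, dim=-1)
# Next-token prediction: position t predicts token t+1, so drop the last
# logit and the first token before gathering.
token_log_probs = torch.gather(
    log_probs[:, :-1, :], dim=-1, index=input_ids[:, 1:].unsqueeze(-1)
).squeeze(-1)                             # [batch, seq - 1]
sequence_log_prob = token_log_probs.sum(dim=-1)  # one scalar per sequence
print(sequence_log_prob.shape)            # torch.Size([1])
```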
@@ -269,63 +244,21 @@ async def __call__(self, groups: list[Group]) -> list[float]:
         return advantages
 
 
-class RefModel(ForgeActor):
-    def __init__(self, model_name, device: torch.device | None = None):
-        super().__init__()
-        self.model_name = model_name
-
-        # Set device
-        if device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = device
-
-        # Initialize model and tokenizer
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to(self.device)
-
-        # Set model to eval mode for reference computations
-        self.model.eval()
-
-        self.logger.info(f"Model initialized on {self.device}")
-
-    @endpoint
-    async def forward(self, token_ids: list[int]) -> torch.Tensor:
-        # Use provided token_ids directly
-        input_ids = (
-            torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
-        )
-        # Create attention mask of all 1s since we have actual tokens (no padding)
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
-        # Compute log probabilities using shared utility function
-        sequence_log_probs = compute_sequence_logprobs(
-            self.model, input_ids, attention_mask, requires_grad=False
-        )
-
-        return (
-            sequence_log_probs.squeeze()
-        )  # Remove batch dimension for single response
-
-
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""
 
-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self, path: str, config_name: str, split: str, streaming: bool, **kwargs
+    ):
         super().__init__()
-        self._setup_dataset(*args, **kwargs)
 
-    def _setup_dataset(self, *args, **kwargs):
         def gsm8k_to_messages(sample):
             question = sample["question"]
             full_answer: str = sample["answer"]
             answer = full_answer.split("#### ")[1]
             return {"question": question, "answer": answer}
 
-        ds = load_dataset(*args, **kwargs)
+        ds = load_dataset(path, config_name, split=split, streaming=streaming)
         ds = ds.map(gsm8k_to_messages)
         ds = ds.shuffle()
         self._iterator = iter(ds)
@@ -341,7 +274,8 @@ async def __next__(self) -> dict[str, str] | None:
 async def main():
     """Main GRPO training loop with rollout and training processes."""
     group_size = 1
-    model = "Qwen/Qwen3-1.7B"
+    model = "Qwen/Qwen3-0.6B"
+    titan_model = TitanJobModelConfig(name="qwen3", flavor="0.6B")
 
     # ---- Setup WandB Logger ---- #
     logger = get_metric_logger(
@@ -351,74 +285,69 @@ async def main():
     )
 
     # ---- Setup services ---- #
-    default_service_cfg = ServiceConfig(
-        procs_per_replica=1,
-        num_replicas=1,
-    )
-
-    policy = await spawn_service(
-        default_service_cfg,
-        Policy,
-        PolicyConfig(
-            num_workers=1,
-            worker_params=WorkerConfig(model=model),
-            sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
-            available_devices="3",
+    (
+        dataloader,
+        policy,
+        trainer,
+        replay_buffer,
+        compute_advantages,
+        ref_model,
+        reward_actor,
+    ) = await asyncio.gather(
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            DatasetActor,
+            path="openai/gsm8k",
+            config_name="main",
+            split="train",
+            streaming=True,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            Policy,
+            config=PolicyConfig(
+                worker_params=WorkerConfig(model=model),
+                sampling_params=SamplingOverrides(
+                    num_samples=group_size, max_tokens=16
+                ),
+            ),
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+            Trainer,
+            learning_rate=1e-5,
+            beta=0.1,
+            model_name=model,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ReplayBuffer,
+            batch_size=4,
+            max_policy_age=1,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            ComputeAdvantages,
+            gamma=0.99,
+            lambda_=0.95,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
+            TitanRefModel,
+            model=titan_model,
+        ),
+        spawn_service(
+            ServiceConfig(procs_per_replica=1, num_replicas=1),
+            RewardActor,
+            reward_functions=[MathReward(), ThinkingReward()],
         ),
-    )
-
-    trainer = await spawn_service(
-        default_service_cfg,
-        Trainer,
-        learning_rate=1e-5,
-        beta=0.1,
-        model_name=model,
-        device=torch.device("cuda:1"),
-    )
-
-    replay_buffer = await spawn_service(
-        default_service_cfg,
-        ReplayBuffer,
-        batch_size=4,
-        max_policy_age=1,
-    )
-
-    dataloader = await spawn_service(
-        default_service_cfg,
-        DatasetActor,
-        "openai/gsm8k",
-        "main",
-        split="train",
-        streaming=True,
-    )
-
-    compute_advantages = await spawn_service(
-        default_service_cfg,
-        ComputeAdvantages,
-        gamma=0.99,
-        lambda_=0.95,
-    )
-
-    ref_model = await spawn_service(
-        default_service_cfg,
-        RefModel,
-        model_name=model,
-        device=torch.device("cuda:2"),
-    )
-
-    reward_actor = await spawn_service(
-        default_service_cfg,
-        RewardActor,
-        reward_functions=[MathReward(), ThinkingReward()],
     )
 
     print("All services initialized successfully!")
 
     # ---- Core RL loops ---- #
     async def continuous_rollouts():
         rollout_count = 0
-        # TODO: Move this into setup
-        asyncio.create_task(policy.run_processing.call())
         while True:
             sample = await dataloader.__next__.choose()
             if sample is None:
@@ -432,9 +361,14 @@ async def continuous_rollouts():
                 target=target,
                 policy_version=version,
             )
-            actions = await policy.generate.choose(prompt)
+            responses = await policy.generate.choose(prompt)
+            actions = responses.outputs
             for action in actions:
-                ref_logprobs = await ref_model.forward.choose(action.token_ids)
+                request_tokens = responses.prompt_token_ids
+                response_tokens = action.token_ids
+                ref_logprobs = await ref_model.forward.choose(
+                    request=request_tokens, response=response_tokens
+                )
                 reward = await reward_actor.evaluate_response.choose(
                     prompt=prompt, response=action.text, target=target
                 )
@@ -489,6 +423,17 @@ async def continuous_training():
         print("Training interrupted by user")
         rollout_task.cancel()
         training_task.cancel()
+    finally:
+        print("Shutting down...")
+        await asyncio.gather(
+            shutdown_service(policy),
+            shutdown_service(trainer),
+            shutdown_service(replay_buffer),
+            shutdown_service(dataloader),
+            shutdown_service(compute_advantages),
+            shutdown_service(ref_model),
+            shutdown_service(reward_actor),
+        )
 
 
 if __name__ == "__main__":
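The refactor above replaces seven sequential `await spawn_service(...)` calls with a single `asyncio.gather`, so all services start concurrently, and the new `finally` block tears them down the same way. A minimal, self-contained sketch of that pattern with stand-in coroutines (not the Forge API):

```python
import asyncio


async def spawn(name: str) -> str:
    """Stand-in for spawn_service: pretend startup takes a moment."""
    await asyncio.sleep(0.1)
    return f"{name}-handle"


async def main() -> None:
    # All three start together, so the total wait is the slowest single
    # startup rather than the sum of all three.
    dataloader, policy, trainer = await asyncio.gather(
        spawn("dataloader"), spawn("policy"), spawn("trainer")
    )
    print(dataloader, policy, trainer)


asyncio.run(main())
```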

apps/rl/llama3_8b.yaml

Lines changed: 5 additions & 1 deletion
@@ -18,6 +18,7 @@ trainer:
   processes:
     scheduler: local # local | mast (not supported yet)
     num_hosts: 1
+    with_gpus: True
     num_procs: 4
 
   optimizer:

@@ -33,9 +34,11 @@ trainer:
   seq_len: 2048
   max_norm: 1.0
   steps: 5
-  compile: false
   dataset: "c4"
 
+  compile:
+    enable: false
+
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: -1

@@ -65,6 +68,7 @@ replay_buffer:
   processes:
     scheduler: local # local | mast (not supported yet)
     num_hosts: 1
+    with_gpus: False
     num_procs: 1
 
 # policy:
