Skip to content

Commit 11ea544

Browse files
author
Felipe Mello
committed
Merge branch 'main' of https://github.com/pytorch-labs/forge into metric_logging
2 parents e27d451 + d5ae6c7 commit 11ea544

File tree

7 files changed

+301
-141
lines changed

7 files changed

+301
-141
lines changed

.github/workflows/unit_test.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,10 @@ jobs:
3232
eval "$(ssh-agent -s)"
3333
ssh-add - <<< '${{ secrets.FORGE_GITHUB_CI_FOR_TORCHSTORE }}'
3434
python -m pip install git+ssh://[email protected]/meta-pytorch/torchstore.git
35+
- name: Install torchtitan
36+
run: |
37+
pip install --pre torchtitan==0.1.0.dev20250826+cpu --extra-index-url https://download.pytorch.org/whl/nightly/cpu
38+
pip install tyro
3539
- name: Install dependencies
3640
run: python -m pip install --no-build-isolation -e ".[dev]"
3741
- name: Run unit tests with coverage

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,6 @@ cover/
193193
wandb/
194194

195195
assets/wheels/vllm*.whl
196+
197+
# DCP artifacts
198+
model_state_dict/

apps/grpo/qwen3_8b.yaml

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
# Grouped Relative Policy Optimization (GRPO)
2+
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_8b.yaml
3+
4+
# Global configuration
5+
group_size: 8
6+
batch_size: 16
7+
max_req_tokens: 512
8+
max_res_tokens: 512
9+
model: "Qwen/Qwen3-8B"
10+
off_by_n: 1 # Off by one by default
11+
12+
# Dataset configuration
13+
dataset:
14+
path: "openai/gsm8k"
15+
revision: "main"
16+
data_split: "train"
17+
streaming: true
18+
model: ${model}
19+
20+
# Policy configuration
21+
policy:
22+
engine_config:
23+
model: ${model}
24+
tensor_parallel_size: 2
25+
pipeline_parallel_size: 1
26+
enforce_eager: false
27+
sampling_config:
28+
n: ${group_size}
29+
max_tokens: ${max_res_tokens}
30+
temperature: 1.0
31+
top_p: 1.0
32+
33+
# Trainer configuration
34+
trainer:
35+
model:
36+
name: qwen3
37+
flavor: 8B
38+
hf_assets_path: hf://${model}
39+
optimizer:
40+
name: AdamW
41+
lr: 1e-5
42+
eps: 1e-8
43+
lr_scheduler:
44+
warmup_steps: 1
45+
training:
46+
local_batch_size: ${batch_size}
47+
seq_len: 2048
48+
max_norm: 1.0
49+
steps: 1000000
50+
dtype: bfloat16
51+
compile:
52+
enable: false
53+
parallelism:
54+
data_parallel_replicate_degree: 1
55+
data_parallel_shard_degree: -1
56+
tensor_parallel_degree: 1
57+
pipeline_parallel_degree: 1
58+
context_parallel_degree: 1
59+
expert_parallel_degree: 1
60+
disable_loss_parallel: true
61+
checkpoint:
62+
enable: true
63+
initial_load_path: hf://${model}
64+
initial_load_in_hf: true
65+
last_save_in_hf: true
66+
interval: 500
67+
async_mode: "disabled"
68+
activation_checkpoint:
69+
mode: selective
70+
selective_ac_option: op
71+
72+
# Replay buffer configuration
73+
replay_buffer:
74+
batch_size: ${batch_size}
75+
max_policy_age: ${off_by_n}
76+
# This should match the dp_size of TorchTitan
77+
# Here it's set explicitly to 2, because we've set
78+
# 2 GPUs for the trainer and we're using full FSDP.
79+
dp_size: 2
80+
81+
# Reference model configuration
82+
ref_model:
83+
model:
84+
name: qwen3
85+
flavor: 8B
86+
hf_assets_path: hf://${model}
87+
training:
88+
dtype: bfloat16
89+
compile:
90+
enable: false
91+
parallelism:
92+
data_parallel_replicate_degree: 1
93+
data_parallel_shard_degree: 1
94+
tensor_parallel_degree: 1
95+
pipeline_parallel_degree: 1
96+
context_parallel_degree: 1
97+
expert_parallel_degree: 1
98+
checkpoint:
99+
initial_load_path: hf://${model}
100+
initial_load_in_hf: true
101+
102+
# All resource allocations
103+
services:
104+
dataset:
105+
procs: 1
106+
num_replicas: 1
107+
with_gpus: false
108+
policy:
109+
procs: ${policy.engine_config.tensor_parallel_size}
110+
num_replicas: 1
111+
with_gpus: true
112+
trainer:
113+
procs: 2
114+
num_replicas: 1
115+
with_gpus: true
116+
replay_buffer:
117+
procs: 1
118+
num_replicas: 1
119+
with_gpus: false
120+
ref_model:
121+
procs: 1
122+
num_replicas: 1
123+
with_gpus: true
124+
compute_advantages:
125+
procs: 1
126+
num_replicas: 1
127+
with_gpus: false
128+
reward_actor:
129+
procs: 1
130+
num_replicas: 1
131+
with_gpus: false

src/forge/actors/trainer.py

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import logging
88
import math
99
import os
10+
import shutil
1011
import time
1112
from collections.abc import Mapping
1213
from dataclasses import dataclass, field, fields
@@ -39,7 +40,46 @@
3940
from forge.data.utils import batch_to_device
4041

4142
logger = logging.getLogger(__name__)
42-
logger.setLevel(logging.INFO)
43+
logger.setLevel(logging.DEBUG)
44+
45+
46+
def cleanup_old_weight_versions(
47+
state_dict_key: str,
48+
delim: str,
49+
current_policy_version: int,
50+
) -> None:
51+
"""Delete old weight versions, keeping only current and N-1 versions.
52+
53+
TODO - issues/194: provide a more robust way to handle eviction.
54+
55+
Args:
56+
state_dict_key: The base key for state dict storage
57+
delim: The delimiter used between key and version
58+
current_policy_version: The current policy version to keep
59+
"""
60+
if current_policy_version <= 1:
61+
return # No cleanup needed for versions 0 or 1
62+
63+
prefix = f"{state_dict_key}{delim}"
64+
current_weights = f"{prefix}{current_policy_version}"
65+
previous_weights = f"{prefix}{current_policy_version - 1}"
66+
67+
# Find all weight directories that match our pattern
68+
parent_dir = os.path.dirname(prefix) or "."
69+
if os.path.exists(parent_dir):
70+
for item in os.listdir(parent_dir):
71+
item_path = os.path.join(parent_dir, item)
72+
if (
73+
item.startswith(os.path.basename(prefix))
74+
and item != os.path.basename(current_weights)
75+
and item != os.path.basename(previous_weights)
76+
and os.path.isdir(item_path)
77+
):
78+
try:
79+
shutil.rmtree(item_path, ignore_errors=True)
80+
logger.debug(f"Removed old weights at {item_path}")
81+
except OSError as e:
82+
logger.debug(f"Error deleting {item_path}: {e}")
4383

4484

4585
@dataclass
@@ -67,6 +107,7 @@ def __post_init__(self):
67107
in monarch for now.
68108
69109
"""
110+
super().__init__()
70111
# Instantiate dict fields
71112
for f in fields(self):
72113
attr = getattr(self, f.name)
@@ -223,13 +264,26 @@ async def push_weights(self, policy_version: int) -> None:
223264
)
224265
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
225266
# TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
226-
vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict, num_layers=28)
267+
vllm_ready_hf_sd = _qwen3_hf_to_vllm(
268+
sd=hf_state_dict, num_layers=self.engine.model_args.n_layers
269+
)
227270

228271
key = f"{self.state_dict_key}{DELIM}{policy_version}"
229272
start_time = time.time()
230273
if self.use_dcp:
274+
275+
# TODO - DCP should probably be being saved to NFS explicitly?
276+
# Right now it will only save everything locally
231277
metadata = dcp.save(checkpoint_id=key, state_dict=vllm_ready_hf_sd)
232278
await ts.put(key, metadata)
279+
280+
# Delete old weight versions if they exist
281+
if self.rank == 0:
282+
cleanup_old_weight_versions(
283+
state_dict_key=self.state_dict_key,
284+
delim=DELIM,
285+
current_policy_version=policy_version,
286+
)
233287
else:
234288
await ts.put_state_dict(vllm_ready_hf_sd, key)
235289
end_time = time.time()

src/forge/controller/provisioner.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,15 +202,18 @@ async def get_proc_mesh(
202202
# We can't currently do this because HostMesh only supports single
203203
# proc_mesh creation at the moment. This will be possible once
204204
# we have "proper HostMesh support".
205-
def bootstrap(gpu_ids: int):
205+
def bootstrap(gpu_ids: list[str]):
206206
# This works for single host, needed for vLLM currently.
207207
import os
208208

209209
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(gpu_ids)
210210
os.environ["MASTER_ADDR"] = socket.gethostname()
211211
# Multiple actors trying to call _get_port doesn't work
212212
# os.environ["MASTER_PORT"] = _get_port()
213-
os.environ["MASTER_PORT"] = "12345"
213+
214+
# Setting the last digit to the first GPU id allows us to i.e.
215+
# create multiple vLLM instances on the same local host.
216+
os.environ["MASTER_PORT"] = f"1234{gpu_ids[0]}"
214217
os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
215218
os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"
216219

0 commit comments

Comments
 (0)