12 changes: 10 additions & 2 deletions apps/grpo/main.py
@@ -11,6 +11,7 @@
 from dataclasses import dataclass
 from typing import Any, Callable

+from forge.controller import provisioner
 import torch
 import torch.nn.functional as F
 import torchstore as ts
@@ -237,15 +238,17 @@ async def pad_token(self):
         return self._tokenizer.pad_token_id


+from forge.controller.provisioner import _get_provisioner
+
 async def main(cfg: DictConfig):
     """Main GRPO training loop with rollout and training processes."""
     group_size = cfg.group_size
     max_req_tokens = cfg.max_req_tokens
     max_res_tokens = cfg.max_res_tokens
     mlogger = get_metric_logger(
-        "wandb",
+        # "wandb",
         freq=1,
-        project="grpo-training",
+        # project="grpo-training",
     )

     # ---- Setup services ---- #
@@ -387,6 +390,11 @@ async def continuous_training():

 @parse
 def _main(cfg):
+    # import pickle
+    # with open("../qwen3_multinode.pkl", "wb") as fp:
+    #     pickle.dump(cfg, fp)
+
+    # return
     asyncio.run(main(cfg))

 _main()  # @parse grabs the cfg from CLI
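A note on the commented-out lines in _main above: they would pickle the parsed config so a run's exact settings can be inspected or replayed offline. A minimal sketch of the reload side, assuming the file was written from the OmegaConf DictConfig that @parse supplies (path as in the comment):

```python
import pickle

from omegaconf import DictConfig

# Assumes "../qwen3_multinode.pkl" holds the DictConfig pickled in _main.
with open("../qwen3_multinode.pkl", "rb") as fp:
    cfg: DictConfig = pickle.load(fp)

print(cfg.group_size)  # the global knobs from the YAML are all here
```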
19 changes: 11 additions & 8 deletions apps/grpo/qwen3_1_7b.yaml
@@ -19,8 +19,10 @@ dataset:

 # Policy configuration
 policy:
+  checkpoint_path: checkpoints
   engine_config:
-    model: ${model}
+    tokenizer: /home/lpasqualin/titan_manifold/tree/forge/qwen3-1-7b-tokenizer
+    model: /home/lpasqualin/titan_manifold/tree/qwen3/Qwen3-1-7B
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
     enforce_eager: false
@@ -35,7 +37,7 @@ trainer:
   model:
     name: qwen3
     flavor: 1.7B
-    hf_assets_path: hf://${model}
+    hf_assets_path: /home/lpasqualin/titan_manifold/tree/qwen3/Qwen3-1-7B/
   optimizer:
     name: AdamW
     lr: 1e-5
@@ -47,7 +49,7 @@ trainer:
     seq_len: 2048
     max_norm: 1.0
     steps: 1000000
-    dtype: bfloat16
+    # dtype: bfloat16
   compile:
     enable: false
   parallelism:
@@ -60,11 +62,12 @@ trainer:
     disable_loss_parallel: true
   checkpoint:
     enable: true
-    initial_load_path: hf://${model}
+    initial_load_path: /home/lpasqualin/titan_manifold/tree/qwen3/Qwen3-1-7B/
     initial_load_in_hf: true
     last_save_in_hf: true
     interval: 500
     async_mode: "disabled"
+    folder: checkpoints
   activation_checkpoint:
     mode: selective
     selective_ac_option: op
@@ -80,9 +83,9 @@ ref_model:
   model:
     name: qwen3
     flavor: 1.7B
-    hf_assets_path: hf://${model}
-  training:
-    dtype: bfloat16
+    hf_assets_path: /home/lpasqualin/titan_manifold/tree/qwen3/Qwen3-1-7B/
+  # training:
+  #   dtype: bfloat16
   compile:
     enable: false
   parallelism:
@@ -93,7 +96,7 @@ ref_model:
     context_parallel_degree: 1
     expert_parallel_degree: 1
   checkpoint:
-    initial_load_path: hf://${model}
+    initial_load_path: /home/lpasqualin/titan_manifold/tree/qwen3/Qwen3-1-7B/
     initial_load_in_hf: true

 # All resource allocations
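Worth flagging the pairing this file introduces: policy.checkpoint_path and trainer.checkpoint.folder are both set to "checkpoints", and the Python diffs below key off them (the trainer writes weights under checkpoint.folder, PolicyWorker reads under checkpoint_path). A hedged sanity check one could run against the YAML, assuming the two settings are indeed meant to name the same location:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("apps/grpo/qwen3_1_7b.yaml")

# Assumption: the policy loads weights from where the trainer saves them,
# so these two folders must agree ("checkpoints" on both sides in this diff).
assert cfg.policy.checkpoint_path == cfg.trainer.checkpoint.folder
```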
88 changes: 71 additions & 17 deletions apps/grpo/qwen3_multinode.yaml
@@ -1,17 +1,17 @@
-# GRPO Training Configuration
-# Currently a fork of the main yaml, this just shows
-# placement of trainer and inference servers on separate hosts.
+# Grouped Relative Policy Optimization (GRPO)
+# >>> python -m apps.grpo.qwen3_1_7b --config apps/grpo/qwen3_1_7b.yaml

 # Global configuration
 group_size: 8
 batch_size: 16
 max_req_tokens: 512
 max_res_tokens: 512
 model: "Qwen/Qwen3-1.7B"
+off_by_n: 1 # Off by one by default

 # Dataset configuration
 dataset:
-  path: "openai/gsm8k"
+  path: "openai/gsm8k" # add manifold path here
   revision: "main"
   data_split: "train"
   streaming: true
@@ -20,7 +20,8 @@ dataset:
 # Policy configuration
 policy:
   engine_config:
-    model: ${model}
+    tokenizer: /mnt/mffuse/forge/qwen3-1-7b-tokenizer
+    model: /mnt/mffuse/qwen3/Qwen3-1-7B/
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
     enforce_eager: false
@@ -32,46 +33,99 @@ policy:

 # Trainer configuration
 trainer:
-  model_name: ${model}
-  learning_rate: 1e-5
+  model:
+    name: qwen3
+    flavor: 1.7B
+    hf_assets_path: /mnt/mffuse/qwen3/Qwen3-1-7B/
+  optimizer:
+    name: AdamW
+    lr: 1e-5
+    eps: 1e-8
+  lr_scheduler:
+    warmup_steps: 1
+  training:
+    local_batch_size: ${batch_size}
+    seq_len: 2048
+    max_norm: 1.0
+    steps: 1000000
+    # dtype: bfloat16
+  compile:
+    enable: false
+  parallelism:
+    data_parallel_replicate_degree: 1
+    data_parallel_shard_degree: 1
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    context_parallel_degree: 1
+    expert_parallel_degree: 1
+    disable_loss_parallel: true
+  checkpoint:
+    enable: true
+    initial_load_path: /mnt/mffuse/qwen3/Qwen3-1-7B/
+    initial_load_in_hf: true
+    last_save_in_hf: false
+    interval: 500
+    async_mode: "disabled"
+  activation_checkpoint:
+    mode: selective
+    selective_ac_option: op

 # Replay buffer configuration
 replay_buffer:
   batch_size: ${batch_size}
-  max_policy_age: 1 # Async by 1
-  dp_size: 1
+  max_policy_age: ${off_by_n}
+  dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree

 # Reference model configuration
 ref_model:
-  model_name: ${model}
+  model:
+    name: qwen3
+    flavor: 1.7B
+    hf_assets_path: /mnt/mffuse/qwen3/Qwen3-1-7B
+  # training:
+  #   dtype: bfloat16
+  compile:
+    enable: false
+  parallelism:
+    data_parallel_replicate_degree: 1
+    data_parallel_shard_degree: 1
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    context_parallel_degree: 1
+    expert_parallel_degree: 1
+  checkpoint:
+    initial_load_path: /mnt/mffuse/qwen3/Qwen3-1-7B/
+    initial_load_in_hf: true

 # All resource allocations
 services:
   dataset:
     procs: 1
     num_replicas: 1
     with_gpus: false
   policy:
-    procs: 1
-    hosts: 1
+    procs: ${policy.engine_config.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
+    hosts: 1
   trainer:
     procs: 1
-    hosts: 1
     num_replicas: 1
     with_gpus: true
+    hosts: 1
   replay_buffer:
     procs: 1
     num_replicas: 1
     with_gpus: false
-  compute_advantages:
-    procs: 1
-    num_replicas: 1
-    with_gpus: false
   ref_model:
     procs: 1
     num_replicas: 1
     with_gpus: true
+    hosts: 1
+  compute_advantages:
+    procs: 1
+    num_replicas: 1
+    with_gpus: false
   reward_actor:
     procs: 1
     num_replicas: 1
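The multinode config now derives values instead of hard-coding them: max_policy_age follows off_by_n, dp_size follows the trainer's data_parallel_shard_degree, and services.policy.procs follows tensor_parallel_size, so changing one knob cannot silently desync its consumers. A small, self-contained sketch of how OmegaConf resolves such references (toy values, not the full config):

```python
from omegaconf import OmegaConf

# Toy config mirroring the interpolations used in the YAML above.
cfg = OmegaConf.create(
    """
    off_by_n: 1
    trainer:
      parallelism:
        data_parallel_shard_degree: 1
    policy:
      engine_config:
        tensor_parallel_size: 1
    replay_buffer:
      max_policy_age: ${off_by_n}
      dp_size: ${trainer.parallelism.data_parallel_shard_degree}
    services:
      policy:
        procs: ${policy.engine_config.tensor_parallel_size}
    """
)

# References resolve on access, so all consumers track the single source knob.
assert cfg.replay_buffer.max_policy_age == 1
assert cfg.services.policy.procs == cfg.policy.engine_config.tensor_parallel_size
```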
14 changes: 10 additions & 4 deletions src/forge/actors/policy.py
@@ -12,6 +12,7 @@
 import time
 from collections.abc import Mapping
 from copy import copy
+from typing import Optional
 from dataclasses import asdict, dataclass, field, fields

 import torch
@@ -166,10 +167,10 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
         if isinstance(engine_config, Mapping):
             engine_config = EngineConfig.from_dict(engine_config)

-        vllm_config = engine_config.create_vllm_config()
         workers = await worker_procs.spawn(
-            "vllm_worker", PolicyWorker, vllm_config=vllm_config
+            "vllm_worker", PolicyWorker,
         )
+        await workers.create_config.call(engine_config=engine_config)

         if isinstance(sampling_config, Mapping):
             sampling_config = SamplingConfig(**sampling_config)
@@ -399,10 +400,15 @@

 @dataclass
 class PolicyWorker(ForgeActor):
-    vllm_config: VllmConfig
+    vllm_config: Optional[VllmConfig] = None
     state_dict_key: str = "model_state_dict"
+    checkpoint_path: str = ""
     use_dcp: bool = True

+    @endpoint
+    def create_config(self, engine_config) -> None:
+        self.vllm_config = engine_config.create_vllm_config()
+
     @endpoint
     async def setup(self):
         # TODO: remove ["gpus"] when monarch implements a flat rank
@@ -423,7 +429,7 @@ async def _load_tensor_parallel_state_dict(
             self.vllm_config.parallel_config.tensor_parallel_size, self.rank
         )

-        checkpoint_id = f"{self.state_dict_key}{DELIM}{version}"
+        checkpoint_id = f"{self.checkpoint_path}/{self.state_dict_key}{DELIM}{version}"
         dcp_metadata = None
         if self.use_dcp:
             dcp_metadata = await ts.get(checkpoint_id)
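The structural change in launch is worth spelling out: instead of building the VllmConfig in the controller and shipping it to every worker at spawn time, workers are now spawned bare and asked to build their own config through the new create_config endpoint, so only the small EngineConfig crosses the process boundary. A minimal sketch of that deferred-construction pattern in plain asyncio, not Monarch's actual actor API (Worker, EngineConfig, and the gather-based fan-out here are stand-ins):

```python
import asyncio
from dataclasses import dataclass


@dataclass
class EngineConfig:  # stand-in for forge's EngineConfig
    model: str

    def create_vllm_config(self) -> dict:
        # The real method builds a full VllmConfig; a dict suffices here.
        return {"model": self.model}


class Worker:  # stand-in for PolicyWorker
    def __init__(self) -> None:
        self.vllm_config: dict | None = None  # populated later, as in the PR

    async def create_config(self, engine_config: EngineConfig) -> None:
        # The heavyweight config is constructed inside the worker itself.
        self.vllm_config = engine_config.create_vllm_config()


async def main() -> None:
    workers = [Worker() for _ in range(2)]  # "spawn" bare workers
    ec = EngineConfig(model="Qwen/Qwen3-1.7B")
    await asyncio.gather(*(w.create_config(ec) for w in workers))
    assert all(w.vllm_config is not None for w in workers)


asyncio.run(main())
```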
18 changes: 17 additions & 1 deletion src/forge/actors/trainer.py
@@ -95,6 +95,22 @@ def __post_init__(self):

     @endpoint
     async def setup(self):
+
+        import socket
+        print(f"HOST= {socket.gethostname()=}")
+
+        # Print all files in the specified directory
+        import glob
+        snapshot_dir = "/mnt/mffuse/qwen3/Qwen3-1-7B/snapshots/70d244cc86ccca08cf5af4e1e306ecf908b1ad5e"
+        files = glob.glob(f"{snapshot_dir}/**", recursive=True)
+        print("Files in snapshot directory:")
+        for f in files:
+            print(f"FILE= {f=}")
+
+        if not files:
+            raise RuntimeError("No files found in snapshot directory")
+
+
         # TODO: update ForgeEngine to not use ForgeJobConfig
         engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
         for key in {"loss", "state_dict_key", "use_dcp"}:
@@ -220,7 +236,7 @@ async def push_weights(self, policy_version: int) -> None:
         # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
         vllm_ready_hf_sd = _qwen3_hf_to_vllm(sd=hf_state_dict, num_layers=28)

-        key = f"{self.state_dict_key}{DELIM}{policy_version}"
+        key = f"{self.checkpoint.folder}/{self.state_dict_key}{DELIM}{policy_version}"
         start_time = time.time()
         if self.use_dcp:
             metadata = dcp.save(checkpoint_id=key, state_dict=vllm_ready_hf_sd)
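With the two key changes in this PR (trainer.push_weights writes under checkpoint.folder, PolicyWorker._load_tensor_parallel_state_dict reads under checkpoint_path), both sides must compose exactly the same torchstore key or weight sync silently breaks. A hedged sketch of the shared scheme; DELIM is a placeholder for forge's real delimiter constant:

```python
DELIM = "/"  # assumption: stand-in for forge's actual DELIM constant

def weight_key(folder: str, state_dict_key: str, version: int) -> str:
    # Mirrors both f-strings in this diff: the trainer writes this key,
    # the policy worker reads it back for the given policy version.
    return f"{folder}/{state_dict_key}{DELIM}{version}"

# The handoff only works while both configs name the same folder:
trainer_side = weight_key("checkpoints", "model_state_dict", 3)  # checkpoint.folder
policy_side = weight_key("checkpoints", "model_state_dict", 3)   # checkpoint_path
assert trainer_side == policy_side
```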