Merged

79 commits
10d336a
shared tensor util
casteryh Oct 14, 2025
68395f9
refactor, fix test
casteryh Oct 14, 2025
3789d49
set up titan distributed through monarch utils, colocate policy with …
allenwang28 Oct 14, 2025
23c8fcb
add SharedTensorHandle class
casteryh Oct 14, 2025
f7ed526
merge
allenwang28 Oct 14, 2025
15d2784
shared memory weight loading
casteryh Oct 14, 2025
c7ca738
oopsie
casteryh Oct 14, 2025
cbd529e
end -> stop
casteryh Oct 14, 2025
20e162f
typo
casteryh Oct 14, 2025
eadf3a5
make policy_version optional
casteryh Oct 14, 2025
2b675f8
fix
casteryh Oct 14, 2025
8b488b7
no leak
casteryh Oct 14, 2025
5e7528c
disable dcp in 8b
casteryh Oct 14, 2025
2b23e50
undo the colocation
allenwang28 Oct 14, 2025
4b850ad
debug info
casteryh Oct 14, 2025
f7cbcb4
typo
casteryh Oct 14, 2025
09836f3
refactor
casteryh Oct 14, 2025
e5f984e
Merge branch 'main' into yhu/shared-tensor
casteryh Oct 14, 2025
9fa7395
temp: reduce num_replicas to 2
casteryh Oct 14, 2025
7744f4c
fix bad merge
casteryh Oct 14, 2025
8970ff4
revert to 4
casteryh Oct 14, 2025
48821de
move _fetch_weights to policy worker
casteryh Oct 15, 2025
2798f2e
typo
casteryh Oct 15, 2025
ddf8d26
endpoint
casteryh Oct 15, 2025
71b89c1
clean up
casteryh Oct 15, 2025
1971a4f
fix
casteryh Oct 15, 2025
dc301aa
fix
casteryh Oct 15, 2025
c879753
rearrange
casteryh Oct 15, 2025
c462911
log
casteryh Oct 15, 2025
571750f
vllm colocation works
allenwang28 Oct 15, 2025
a155a4c
caching allocation for weight updates
casteryh Oct 15, 2025
c2ab4e1
Merge branch 'titan_setup' into yhu/shared-tensor
casteryh Oct 15, 2025
9fa3297
test
casteryh Oct 15, 2025
68be7a7
multi processing for fetch
casteryh Oct 15, 2025
f30e666
endpoint
casteryh Oct 15, 2025
a190623
cleanup
casteryh Oct 15, 2025
af3f35c
fix
casteryh Oct 15, 2025
8fb6dde
fix
casteryh Oct 15, 2025
35c0052
fix queue pair checked out
casteryh Oct 15, 2025
555966f
typo
casteryh Oct 15, 2025
16b060b
spawn on worker procs
casteryh Oct 15, 2025
1e4110c
add back generator_proc
casteryh Oct 15, 2025
efe91e3
fix slice
casteryh Oct 15, 2025
d738f5e
fix fetcher call
casteryh Oct 15, 2025
957b4cd
fix arguments
casteryh Oct 15, 2025
be85a94
remove await
casteryh Oct 15, 2025
002d68b
make handle droppable
casteryh Oct 15, 2025
090b6ec
spawn in setup
casteryh Oct 15, 2025
1f9dc13
fix
casteryh Oct 15, 2025
289ca08
add t.stop()
casteryh Oct 15, 2025
e854ddd
create tasks
casteryh Oct 15, 2025
bee85df
pray
casteryh Oct 15, 2025
de3b41d
Revert "pray"
casteryh Oct 15, 2025
2b62964
Revert "create tasks"
casteryh Oct 15, 2025
aa57ba3
more procs
casteryh Oct 15, 2025
fab7cd4
procs=1
casteryh Oct 15, 2025
2d14c40
cleanup, guard with flag
casteryh Oct 15, 2025
acf72b8
Merge branch 'main' into yhu/shared-tensor-mp
casteryh Oct 15, 2025
d62b44f
clean up unused
casteryh Oct 15, 2025
6cf22d7
fork
casteryh Oct 15, 2025
bd921ca
Merge branch 'main' into yhu/shared-tensor-mp
casteryh Oct 15, 2025
7d56eef
remove cleanup which no longer exists
casteryh Oct 15, 2025
fc37e27
fix bad merge
casteryh Oct 15, 2025
bda4798
multiprocessing -> multiprocess
casteryh Oct 16, 2025
ab7dae5
separate spawn_fetchers method
casteryh Oct 16, 2025
e2785f7
rename flag
casteryh Oct 17, 2025
125b039
rename flag
casteryh Oct 17, 2025
b36eb9a
Merge branch 'main' into yhu/shared-tensor-mp
casteryh Oct 17, 2025
c9df35f
tp=8
casteryh Oct 17, 2025
dfd8656
Fix SharedTensor memory leaks with explicit lifecycle management
casteryh Oct 17, 2025
101b7a8
Use context manager pattern for SharedTensor in generator
casteryh Oct 17, 2025
ac9240c
Enable shared memory weight prefetching by default
casteryh Oct 17, 2025
9c35c31
Remove experimental config file
casteryh Oct 17, 2025
4c981bd
Merge remote-tracking branch 'origin/main' into yhu/shared-tensor-mp
casteryh Oct 17, 2025
a556064
Update src/forge/actors/generator.py
casteryh Oct 17, 2025
777a411
debug
casteryh Oct 17, 2025
db1394e
Add missing del param in _WeightFetcher.fetch()
casteryh Oct 17, 2025
856c070
Improve error handling in SharedTensor.close() and drop()
casteryh Oct 17, 2025
b3732de
Fix SharedTensor anti-pattern in multiprocess tests
casteryh Oct 17, 2025
158 changes: 158 additions & 0 deletions apps/grpo/qwen3_32b_experimental.yaml
@@ -0,0 +1,158 @@
# This is an experimental fork of qwen3_32b.yaml that enables the following:
# - shared memory based weight prefetching for weight updates
#
# Grouped Relative Policy Optimization (GRPO)
# >>> python -m apps.grpo.main --config apps/grpo/qwen3_32b_experimental.yaml
# NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability

# Global configuration
group_size: 16
local_batch_size: 32 # per-device batch size
max_req_tokens: 1024
max_res_tokens: 1024
model: "Qwen/Qwen3-32B"
off_by_n: 1 # Off by one by default

provisioner:
  launcher: slurm

# Main loop configuration
rollout_threads: 32 # making this 4x the number of policy replicas seems to work well

# Observability configuration
metric_logging:
  wandb:
    project: "grpo-training"
    group: "grpo_exp_${oc.env:USER}"
    reduce_across_ranks: True
  console:
    reduce_across_ranks: True

# Dataset configuration
dataset:
  path: "openai/gsm8k"
  revision: "main"
  data_split: "train"
  streaming: true
  model: ${model}

# Policy configuration
policy:
  prefetch_weights: true
  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
    model: ${model}
    tensor_parallel_size: 4
    pipeline_parallel_size: 1
    enforce_eager: false
  sampling_params: # https://docs.vllm.ai/en/v0.10.0/api/vllm/sampling_params.html#vllm.sampling_params.SamplingParams
    n: ${group_size}
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0

# Trainer configuration
trainer:
  model:
    name: qwen3
    flavor: 32B
    hf_assets_path: hf://${model}
  optimizer:
    name: AdamW
    lr: 1e-5
    eps: 1e-8
  lr_scheduler:
    warmup_steps: 1
  training:
    local_batch_size: ${local_batch_size}
    seq_len: 2048
    max_norm: 1.0
    steps: 1000000
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1
    tensor_parallel_degree: 8
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
    disable_loss_parallel: true
  checkpoint:
    enable: true
    initial_load_path: hf://${model}
    initial_load_in_hf: true
    last_save_in_hf: true
    interval: 500
    async_mode: "disabled"
  activation_checkpoint:
    mode: full

# Replay buffer configuration
replay_buffer:
  batch_size: ${local_batch_size}
  max_policy_age: ${off_by_n}
  # dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
  dp_size: 1

# Reference model configuration
ref_model:
  model:
    name: qwen3
    flavor: 32B
    hf_assets_path: hf://${model}
  training:
    dtype: bfloat16
    gc_freq: 1
  compile:
    enable: false
  parallelism:
    data_parallel_replicate_degree: 1
    data_parallel_shard_degree: 1
    tensor_parallel_degree: 4
    pipeline_parallel_degree: 1
    context_parallel_degree: 1
    expert_parallel_degree: 1
  checkpoint:
    enable: true
    initial_load_path: hf://${model}
    initial_load_in_hf: true

# All resource allocations
services:
  policy:
    procs: ${policy.engine_args.tensor_parallel_size}
    num_replicas: 4
    hosts: 1
    with_gpus: true
    mesh_name: policy
  ref_model:
    procs: ${ref_model.parallelism.tensor_parallel_degree}
    num_replicas: 1
    with_gpus: true
    mesh_name: ref_model
  reward_actor:
    procs: 1
    num_replicas: 1
    with_gpus: false
    mesh_name: reward_actor

actors:
  dataset:
    procs: 1
    with_gpus: false
    mesh_name: dataset
  trainer:
    procs: 8
    hosts: 1
    with_gpus: true
    mesh_name: trainer
  replay_buffer:
    procs: 1
    with_gpus: false
    mesh_name: replay_buffer
  compute_advantages:
    procs: 1
    with_gpus: false
    mesh_name: compute_advantages
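Tallying the GPU processes this config requests (assuming one GPU per proc): policy = 4 (tensor_parallel_size) x 4 (num_replicas) = 16, trainer = 8, ref_model = 4 x 1 = 4, for 28 GPU procs in total; dataset, replay_buffer, compute_advantages, and reward_actor run on CPU, as do the weight-fetcher procs spawned by the generator below.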
120 changes: 115 additions & 5 deletions src/forge/actors/generator.py
@@ -13,10 +13,12 @@
from collections.abc import Mapping
from copy import copy
from dataclasses import dataclass, field
from typing import Optional

import torch
import torchstore as ts
from monarch.actor import current_rank, endpoint, ProcMesh
from monarch.actor import current_rank, endpoint, ProcMesh, this_host

from vllm.config import VllmConfig

from vllm.engine.arg_utils import EngineArgs
@@ -60,6 +62,7 @@
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.types import ProcessConfig
from forge.util._shared_tensor import SharedTensor, SharedTensorHandle

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -92,6 +95,8 @@ class Generator(ForgeActor):
    engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
    sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
    use_dcp_for_weight_sync: bool | None = None
    prefetch_weights: bool = False
    n_fetcher_procs: int = 8

    def __post_init__(self):
        super().__init__()
@@ -226,11 +231,53 @@ async def setup(self):
            log_stats=None,
        )
        self._start_processing()
        fetcher_procs = this_host().spawn_procs(
            per_host={"procs": self.n_fetcher_procs}
        )
        self._fetcher_procs = fetcher_procs
        self.weight_fetchers = fetcher_procs.spawn("weight_fetcher", _WeightFetcher)

Review thread on `fetcher_procs = this_host().spawn_procs(`:

Contributor: Can we guard this with `if prefetch_weights`? It may also be a bit cleaner to put this init in another function w/ some documentation for what the fetchers here do.

Contributor Author (casteryh, Oct 16, 2025): Will move to a separate function.

Review thread on `per_host={"procs": self.n_fetcher_procs}`:

Contributor: Not that we need to address now, but I think we need to spawn these fetcher procs across all generator nodes, right?

Contributor Author: Yes, but I assume the setup call is broadcast and every Generator node will spawn its own fetcher_procs?

Contributor: That is correct; I meant if a generator's workers span 2 nodes, i.e. DeepSeek. In that case we would probably want to spin up the fetchers on the worker nodes, right?

Contributor Author: Correct.

Contributor: I am a bit confused -- shouldn't a vLLM worker be scoped to a single node?

Contributor: I also don't follow why you need more than 1. Is it to allow you to parallelize torchstore requests?
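The commit history ("separate spawn_fetchers method", "cleanup, guard with flag") suggests the author did split this out; a minimal sketch of what such a guarded helper could look like, reusing only the `this_host().spawn_procs(...)` / `.spawn(...)` calls shown above (the `_spawn_fetchers` name and the early-return behavior are assumptions, not the PR's final code):

    def _spawn_fetchers(self) -> None:
        # Fetchers pull weights from torchstore into shared memory in parallel,
        # so the GPU workers can map them without going back through the store.
        # Only spawn them when shared-memory prefetching is enabled.
        if not self.prefetch_weights:
            self._fetcher_procs = None
            self.weight_fetchers = None
            return
        self._fetcher_procs = this_host().spawn_procs(
            per_host={"procs": self.n_fetcher_procs}
        )
        self.weight_fetchers = self._fetcher_procs.spawn(
            "weight_fetcher", _WeightFetcher
        )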

    def _start_processing(self):
        if self._run_task is None or self._run_task.done():
            self._run_task = asyncio.create_task(self.run())

    async def _drop_shared_memory(self, state_dict: dict[str, SharedTensorHandle]):
        for handle in state_dict.values():
            handle.drop()

    async def _fetch_weights(
        self,
        version: int,
    ) -> dict[str, SharedTensorHandle]:
        """Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
        t = Tracer("generator_perf/_fetch_weights")
        t.start()
        prefix = get_param_prefix(version)
        matching_keys = await ts.keys(prefix)
        hf_param_names = [extract_param_name(key) for key in matching_keys]

        n_fetchers = self.weight_fetchers.size()

        def split_keys(keys):
            return [keys[i::n_fetchers] for i in range(n_fetchers)]

        futures = []
        for i, names in enumerate(split_keys(hf_param_names)):
            fut = self.weight_fetchers.slice(procs=i).fetch.call_one(
                version=version, param_names=names
            )
            futures.append(fut)

        sub_state_dicts = [await fut for fut in futures]

        state_dict = {}
        for sd in sub_state_dicts:
            state_dict.update(sd)

        t.stop()

        return state_dict
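To make the striding in `split_keys` concrete, a quick illustration with hypothetical parameter names -- each fetcher gets every n-th key, so the work is dealt out round-robin:

    names = ["w0", "w1", "w2", "w3", "w4"]
    n_fetchers = 2
    shards = [names[i::n_fetchers] for i in range(n_fetchers)]
    # shards == [["w0", "w2", "w4"], ["w1", "w3"]]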

    @endpoint
    async def generate(self, prompt: str, *, priority: int = 0) -> list[Completion]:
        """Generate a response for the given prompt
@@ -384,6 +431,12 @@ async def update_weights(self, version: int) -> None:
        >>> await trainer.push_weights()
        >>> generator.update_weights(version)
        """
        # Prefetch only if we are using RDMA
        if self.prefetch_weights and not self.use_dcp_for_weight_sync:
            logger.info(f"[Generator] Fetching weights for v{version} to shared memory")
            fetch_fut = asyncio.create_task(self._fetch_weights(version))
        else:
            fetch_fut = None
        # Serialize updates (only one update at a time)
        async with self.update_lock:
            # Grab the lock to stop accepting requests and wait on pending requests
@@ -415,8 +468,19 @@ async def update_weights(self, version: int) -> None:
            )

            logger.debug(f"Starting weight update on {self.__class__.__name__}")
            # Call update_weights on every generator worker
            await self.worker.update_weights.call(version=version)

            if fetch_fut is not None:
                t = Tracer("generator_perf/waiting_for_fetch_weights")
                t.start()
                fetched_weights = await fetch_fut
                t.stop()
                # Call update_weights on every policy_worker
                await self.worker.update_weights.call(
                    shared_memory_state_dict=fetched_weights
                )
                await self._drop_shared_memory(fetched_weights)
            else:
                await self.worker.update_weights.call(version=version)
            self.generator_version = version

            # After updating the weights, we need to reset the KV cache
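Note the ordering in this hunk: the fetch task is created before the update lock is taken, so staging weights into shared memory overlaps with draining in-flight requests instead of running after it. A condensed sketch of that flow (simplified from the method above, not a drop-in replacement):

    fetch_fut = asyncio.create_task(self._fetch_weights(version))  # staging starts immediately
    async with self.update_lock:
        # ... stop accepting requests and wait on pending ones; the fetch keeps running ...
        fetched_weights = await fetch_fut  # often already complete by this point
        await self.worker.update_weights.call(shared_memory_state_dict=fetched_weights)
        await self._drop_shared_memory(fetched_weights)  # release the shared-memory segments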
@@ -490,6 +554,7 @@ async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
        await actor.stop.call()
        await stop_proc_mesh(actor._worker_procs)
        await stop_proc_mesh(actor._generator_proc)
        await stop_proc_mesh(actor._fetcher_procs)

    @endpoint
    async def _test_save_model_params(self):
@@ -573,14 +638,39 @@ async def execute_model(self, schedule: SchedulerOutput) -> ModelRunnerOutput:
        return self.worker.execute_model(schedule)

    @endpoint
    async def update_weights(self, version: int) -> None:
    async def update_weights(
        self,
        version: Optional[int] = None,
        *,
        shared_memory_state_dict: Optional[dict[str, SharedTensorHandle]] = None,
    ) -> None:
        model = self.worker.model_runner.model
        if shared_memory_state_dict is not None:
            logger.info("[PolicyWorker] update weights from shared memory.")
            t = Tracer(
                "generator_worker_perf/update_weights_from_shared_memory", timer="gpu"
            )
            t.start()
            loaded_weights = set()
            for name, param_handle in shared_memory_state_dict.items():
                param = param_handle.to_shared_tensor().tensor
                loaded = model.load_weights([(name, param)])
                loaded_weights.update(loaded)
            logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
            t.stop()
            return
        # normal update_weights without shared memory prefetching
        if version is None:
            raise ValueError(
                "version must be provided if not using shared_memory_state_dict"
            )
        logger.info("[PolicyWorker] update weights from torchstore.")
        prefix = get_param_prefix(version)
        matching_keys = await ts.keys(prefix)
        dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
        use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
        loaded_weights = set()
        t = Tracer("worker_perf/update_weights", timer="gpu")
        t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu")
        t.start()

        if use_dcp_for_weight_sync:
@@ -622,3 +712,23 @@ async def _test_validate_model_params(self, validate_fn):
        return validate_fn(
            self._test_prev_params, self.worker.model_runner.model, logger
        )


class _WeightFetcher(ForgeActor):
    """Fetches weights from torchstore and loads them into shared memory.
    This has to be colocated with the GeneratorWorker."""

    @endpoint
    async def fetch(
        self,
        *,
        version: int,
        param_names: list[str],
    ) -> dict[str, SharedTensorHandle]:
        """Fetch weights from torchstore and load them into shared memory."""
        sd = {}
        for name in param_names:
            param_key = get_param_key(version, name)
            param = await ts.get(param_key)
            sd[name] = SharedTensor(tensor=param).get_handle()
        return sd

Review thread on `class _WeightFetcher(ForgeActor):`:

Contributor: I think this could be a method on the generator that gets called from main, so prefetch is controlled and visible from the main loop. I am curious if this has to actually be a separate process, since this is an async method and I would think most of the time it's waiting on ts.get.

Contributor Author: This has to be a separate actor because it has to be launched in a separate process.
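To make the handle round-trip concrete, here is a minimal sketch of the producer/consumer lifecycle these pieces rely on, using only the `SharedTensor` surface visible in this diff (`SharedTensor(tensor=...)`, `.get_handle()`, `.to_shared_tensor().tensor`, `.drop()`); the helper names are hypothetical and the exact semantics of `forge.util._shared_tensor` may differ:

    import torch

    from forge.util._shared_tensor import SharedTensor, SharedTensorHandle


    def stage_weight(param: torch.Tensor) -> SharedTensorHandle:
        # Fetcher process: copy the parameter into shared memory and return a
        # small, picklable handle that other processes can use to map it.
        return SharedTensor(tensor=param).get_handle()


    def load_weight(name: str, handle: SharedTensorHandle, model) -> None:
        # Worker process: re-attach to the shared segment and hand the tensor
        # to the model's weight loader (no copy back through torchstore).
        param = handle.to_shared_tensor().tensor
        model.load_weights([(name, param)])


    def release(handles: dict[str, SharedTensorHandle]) -> None:
        # Owner (the Generator in this PR): drop every handle once all workers
        # have loaded, so the shared-memory segments are freed and nothing leaks.
        for handle in handles.values():
            handle.drop()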