From 001a6b6ad11adcbe3e7e805beea4f273c7472ce3 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 4 Oct 2025 18:20:28 -0700 Subject: [PATCH 01/20] enable rdma for weight sync --- apps/grpo/qwen3_8b.yaml | 2 +- src/forge/actors/trainer.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index c46ee0620..d05fd18f9 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -42,7 +42,7 @@ policy: # Trainer configuration trainer: - use_dcp: true + use_dcp: false use_vllm_builtin_load: true model: name: qwen3 diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 4ffc63001..b8f9fa97e 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -403,7 +403,8 @@ async def push_weights(self, policy_version: int) -> None: else: for name, param in hf_state_dict.items(): key = get_param_key(policy_version, name) - await ts.put(key, param) + # RDMA is still broken on GPU, so we need to copy to CPU + await ts.put(key, param.detach().cpu()) t.step("ts_save") t.stop() end_time = time.perf_counter() From 4f73bb3c8f31d21a58dc28e8ae1ce3cd42a6f527 Mon Sep 17 00:00:00 2001 From: yuxuanh Date: Wed, 8 Oct 2025 17:38:36 +0000 Subject: [PATCH 02/20] stash --- apps/grpo/qwen3_32b.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index 3d1b80852..9a71cf4d4 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -45,6 +45,7 @@ policy: # Trainer configuration trainer: + use_dcp: false model: name: qwen3 flavor: 32B From 0e2fef6a14f66feaffe1b106bd43ef5149d99d1e Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Wed, 8 Oct 2025 23:22:15 +0000 Subject: [PATCH 03/20] inherit env variables in provisioner.py --- src/forge/controller/provisioner.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 5ca331f32..32569f7bb 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -240,6 +240,13 @@ def bootstrap(env: dict[str, str]): # Shows detailed logs for Monarch rust failures env_vars["RUST_BACKTRACE"] = "1" + + env_vars_to_inherit = ["TORCHSTORE_RDMA_ENABLED"] + + for name in env_vars_to_inherit: + val = os.environ.get(name) + if val is not None: + env_vars[name] = val procs = host_mesh.spawn_procs( per_host={"gpus": num_procs}, From 64421c4e0159e5ed486b886ed3770f68c0998409 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 10 Oct 2025 02:06:01 +0000 Subject: [PATCH 04/20] modified config --- apps/grpo/qwen3_32b.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index 7d10c4d0c..69cbcc7ab 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -6,7 +6,7 @@ group_size: 2 batch_size: 8 max_req_tokens: 512 -max_res_tokens: 512 +max_res_tokens: 1536 model: "Qwen/Qwen3-32B" off_by_n: 1 # Off by one by default @@ -14,7 +14,7 @@ provisioner: launcher: slurm # Main loop configuration -rollout_threads: 1 # Recommended to set equal to policy.num_replicas +rollout_threads: 8 # Recommended to set equal to policy.num_replicas # Observability configuration metric_logging: From 197d89e6bdf8ee283627db588c1e1f1ff80eb16a Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 10 Oct 2025 02:34:29 +0000 Subject: [PATCH 05/20] change config --- apps/grpo/qwen3_32b.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index 69cbcc7ab..dd461d2dd 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -3,8 +3,8 @@ # NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability # Global configuration -group_size: 2 -batch_size: 8 +group_size: 16 +batch_size: 16 max_req_tokens: 512 max_res_tokens: 1536 model: "Qwen/Qwen3-32B" @@ -14,7 +14,7 @@ provisioner: launcher: slurm # Main loop configuration -rollout_threads: 8 # Recommended to set equal to policy.num_replicas +rollout_threads: 16 # Recommended to set equal to policy.num_replicas # Observability configuration metric_logging: @@ -70,8 +70,8 @@ trainer: enable: false parallelism: data_parallel_replicate_degree: 1 - data_parallel_shard_degree: -1 - tensor_parallel_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 8 pipeline_parallel_degree: 1 context_parallel_degree: 1 expert_parallel_degree: 1 @@ -91,7 +91,7 @@ replay_buffer: batch_size: ${batch_size} max_policy_age: ${off_by_n} # dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree - dp_size: 8 + dp_size: 1 # Reference model configuration ref_model: @@ -120,7 +120,7 @@ ref_model: services: policy: procs: ${policy.engine_config.tensor_parallel_size} - num_replicas: 1 + num_replicas: 4 hosts: 1 with_gpus: true ref_model: From b783cebadeaf8b6f78cfd49f5018b0c36c84a6be Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Thu, 9 Oct 2025 20:37:10 -0700 Subject: [PATCH 06/20] add profiling --- src/forge/actors/policy.py | 88 ++++++++++++++++++++++---------------- 1 file changed, 52 insertions(+), 36 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 212f8831a..b7c9405c5 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -19,6 +19,7 @@ import torch.distributed.checkpoint as dcp import torchstore as ts from monarch.actor import current_rank, endpoint, ProcMesh +from torch.profiler import profile, ProfilerActivity, record_function from torchstore.state_dict_utils import DELIM from vllm.config import VllmConfig @@ -63,6 +64,12 @@ logger.setLevel(logging.INFO) +def trace_handler(rank, p): + p.export_chrome_trace( + "/mnt/data/yuxuanh/trace_rank_{rank}_" + str(p.step_num) + ".json" + ) + + @dataclass class SamplingConfig: """ @@ -647,42 +654,51 @@ async def update_DEPRECATED(self, version: int): # noqa: N802 @endpoint async def update(self, version: int): """Update model weights by reading state dict from torchstore""" - logger.info( - f"[PolicyWorker::update] start updating weights to version {version}" - ) - model = self.worker.model_runner.model - prefix = get_param_prefix(version) - logger.debug(f"{prefix=}") - matching_keys = await ts.keys(prefix) - logger.debug(f"{matching_keys=}") - dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) - loaded_weights = set() - t = Tracer("policy_worker_perf/update", timer="gpu") - t.start() - # Entire state dict is stored in a single DCP handle - if dcp_whole_state_dict_key in matching_keys: - logger.info( - f"Loading {dcp_whole_state_dict_key} from DCP with handle {dcp_whole_state_dict_key}" - ) - dcp_handle = await ts.get(dcp_whole_state_dict_key) - hf_param_names = dcp_handle.param_names - for name in hf_param_names: - param = load_tensor_from_dcp(dcp_handle, name) - loaded = model.load_weights([(name, param)]) - del param - loaded_weights.update(loaded) - else: # Load each parameter from torchstore directly without DCP - hf_param_names = [extract_param_name(key) for key in matching_keys] - # We can't pass a generator since vllm load_weights is not async. - # Instead, we just call load_weights with one parameter at a time. - for name in hf_param_names: - param_key = get_param_key(version, name) - param = await ts.get(param_key) - loaded = model.load_weights([(name, param)]) - del param - loaded_weights.update(loaded) - t.stop() - logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}") + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.GPU], + record_shapes=True, + on_trace_ready=lambda p: trace_handler(self.rank, p), + with_stack=True, + profile_memory=True, + ) as prof: + with record_function("policy_worker_perf/update"): + logger.info( + f"[PolicyWorker::update] start updating weights to version {version}" + ) + model = self.worker.model_runner.model + prefix = get_param_prefix(version) + logger.debug(f"{prefix=}") + matching_keys = await ts.keys(prefix) + logger.debug(f"{matching_keys=}") + dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) + loaded_weights = set() + t = Tracer("policy_worker_perf/update", timer="gpu") + t.start() + # Entire state dict is stored in a single DCP handle + if dcp_whole_state_dict_key in matching_keys: + logger.info( + f"Loading {dcp_whole_state_dict_key} from DCP with handle {dcp_whole_state_dict_key}" + ) + dcp_handle = await ts.get(dcp_whole_state_dict_key) + hf_param_names = dcp_handle.param_names + for name in hf_param_names: + param = load_tensor_from_dcp(dcp_handle, name) + loaded = model.load_weights([(name, param)]) + del param + loaded_weights.update(loaded) + else: # Load each parameter from torchstore directly without DCP + hf_param_names = [extract_param_name(key) for key in matching_keys] + # We can't pass a generator since vllm load_weights is not async. + # Instead, we just call load_weights with one parameter at a time. + for name in hf_param_names: + param_key = get_param_key(version, name) + param = await ts.get(param_key) + loaded = model.load_weights([(name, param)]) + del param + loaded_weights.update(loaded) + t.stop() + logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}") + prof.step() @endpoint async def setup_kv_cache(self): From 6ff0297c329d56393cfa7335b30a91ecf829dbfa Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Thu, 9 Oct 2025 20:46:14 -0700 Subject: [PATCH 07/20] GPU -> CUDA --- src/forge/actors/policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index b7c9405c5..028740d55 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -655,7 +655,7 @@ async def update_DEPRECATED(self, version: int): # noqa: N802 async def update(self, version: int): """Update model weights by reading state dict from torchstore""" with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.GPU], + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, on_trace_ready=lambda p: trace_handler(self.rank, p), with_stack=True, From 9beab76e8bd948c9ba6debd868beea8979a1a403 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Thu, 9 Oct 2025 20:59:17 -0700 Subject: [PATCH 08/20] fix profiler --- src/forge/actors/policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 028740d55..b6e796ab1 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -66,7 +66,7 @@ def trace_handler(rank, p): p.export_chrome_trace( - "/mnt/data/yuxuanh/trace_rank_{rank}_" + str(p.step_num) + ".json" + f"/mnt/data/yuxuanh/profiler/trace_rank_{rank}_" + str(p.step_num) + ".json" ) From 10053143b8504d4d5071b7930cd1e14630858793 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Thu, 9 Oct 2025 21:24:45 -0700 Subject: [PATCH 09/20] fix profiler path --- src/forge/actors/policy.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index b6e796ab1..2b821bb97 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -9,6 +9,7 @@ import asyncio import logging import os +import socket import sys import time from collections.abc import Mapping @@ -63,10 +64,14 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +hostname = socket.gethostname() + def trace_handler(rank, p): p.export_chrome_trace( - f"/mnt/data/yuxuanh/profiler/trace_rank_{rank}_" + str(p.step_num) + ".json" + f"/mnt/data/yuxuanh/profiler/{hostname}_trace_rank_{rank}_" + + str(p.step_num) + + ".json" ) From f9491db05ceddfd293bc28871ded644301543d04 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 10 Oct 2025 21:36:21 -0700 Subject: [PATCH 10/20] parallelize --- apps/grpo/qwen3_1_7b.yaml | 1 + src/forge/actors/policy.py | 19 +++++++++------- src/forge/actors/trainer.py | 44 ++++++++++++++++++++++++++++--------- 3 files changed, 46 insertions(+), 18 deletions(-) diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 53eec5cfb..ffdc1ff42 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -45,6 +45,7 @@ policy: # Trainer configuration trainer: + use_dcp: false model: name: qwen3 flavor: 1.7B diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 2b821bb97..9bae17361 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -75,6 +75,11 @@ def trace_handler(rank, p): ) +async def _ts_parallel_get(keys: list[str]) -> list[torch.Tensor]: + coros = [ts.get(key) for key in keys] + return await asyncio.gather(*coros) + + @dataclass class SamplingConfig: """ @@ -693,14 +698,12 @@ async def update(self, version: int): loaded_weights.update(loaded) else: # Load each parameter from torchstore directly without DCP hf_param_names = [extract_param_name(key) for key in matching_keys] - # We can't pass a generator since vllm load_weights is not async. - # Instead, we just call load_weights with one parameter at a time. - for name in hf_param_names: - param_key = get_param_key(version, name) - param = await ts.get(param_key) - loaded = model.load_weights([(name, param)]) - del param - loaded_weights.update(loaded) + param_keys = [ + get_param_key(version, name) for name in hf_param_names + ] + new_params = await _ts_parallel_get(param_keys) + loaded = model.load_weights(zip(hf_param_names, new_params)) + loaded_weights.update(loaded) t.stop() logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}") prof.step() diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index b8f9fa97e..ef0dcc713 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import asyncio import logging import math import os @@ -12,7 +13,7 @@ import time from collections.abc import Mapping from dataclasses import dataclass, field, fields -from typing import Callable +from typing import Callable, Iterable import torch import torch.distributed.checkpoint as dcp @@ -39,11 +40,7 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.actors._torchstore_utils import ( - DcpHandle, - get_dcp_whole_state_dict_key, - get_param_key, -) +from forge.actors._torchstore_utils import DcpHandle, get_dcp_whole_state_dict_key from forge.controller import ForgeActor from forge.data.utils import batch_to_device @@ -93,6 +90,36 @@ def cleanup_old_weight_versions( logger.debug(f"Error deleting {item_path}: {e}") +async def _parallel_to_cpu(tensors: list[Tensor]): + num_streams = min(4, len(tensors)) # GPU typically supports 2-4 parallel copies + streams = [torch.cuda.Stream() for _ in range(num_streams)] + results = [] + + for i, tensor in enumerate(tensors): + stream = streams[i % num_streams] + with torch.cuda.stream(stream): + # Non-blocking copy happens in this stream + cpu_tensor = tensor.detach().to("cpu", non_blocking=True) + results.append(cpu_tensor) + + # Yield to event loop periodically + if i % 10 == 0: + await asyncio.sleep(0) + + # Wait for all streams to complete + for stream in streams: + stream.synchronize() + + return results + + +async def _parallel_put(kv_pairs: Iterable[tuple[str, Tensor]]): + keys, tensors = zip(*kv_pairs) + cpu_tensors = await _parallel_to_cpu(tensors) + coros = [ts.put(key, cpu_tensor) for key, cpu_tensor in zip(keys, cpu_tensors)] + await asyncio.gather(*coros) + + @dataclass class RLTrainer(ForgeActor): job: Job = field(default_factory=Job) @@ -401,10 +428,7 @@ async def push_weights(self, policy_version: int) -> None: await ts.put(key, dcp_handle) t.step("dcp_save") else: - for name, param in hf_state_dict.items(): - key = get_param_key(policy_version, name) - # RDMA is still broken on GPU, so we need to copy to CPU - await ts.put(key, param.detach().cpu()) + await _parallel_put(hf_state_dict.items()) t.step("ts_save") t.stop() end_time = time.perf_counter() From b5f307b021af58d13f28f35c8b2c6ba156f6a33f Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Fri, 10 Oct 2025 23:07:53 -0700 Subject: [PATCH 11/20] fix --- src/forge/actors/policy.py | 1 - src/forge/actors/trainer.py | 26 +------------------------- 2 files changed, 1 insertion(+), 26 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 9bae17361..d872a25b9 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -706,7 +706,6 @@ async def update(self, version: int): loaded_weights.update(loaded) t.stop() logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}") - prof.step() @endpoint async def setup_kv_cache(self): diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index ef0dcc713..7d99f8d24 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -90,33 +90,9 @@ def cleanup_old_weight_versions( logger.debug(f"Error deleting {item_path}: {e}") -async def _parallel_to_cpu(tensors: list[Tensor]): - num_streams = min(4, len(tensors)) # GPU typically supports 2-4 parallel copies - streams = [torch.cuda.Stream() for _ in range(num_streams)] - results = [] - - for i, tensor in enumerate(tensors): - stream = streams[i % num_streams] - with torch.cuda.stream(stream): - # Non-blocking copy happens in this stream - cpu_tensor = tensor.detach().to("cpu", non_blocking=True) - results.append(cpu_tensor) - - # Yield to event loop periodically - if i % 10 == 0: - await asyncio.sleep(0) - - # Wait for all streams to complete - for stream in streams: - stream.synchronize() - - return results - - async def _parallel_put(kv_pairs: Iterable[tuple[str, Tensor]]): keys, tensors = zip(*kv_pairs) - cpu_tensors = await _parallel_to_cpu(tensors) - coros = [ts.put(key, cpu_tensor) for key, cpu_tensor in zip(keys, cpu_tensors)] + coros = [ts.put(key, tensor.detach().cpu()) for key, tensor in zip(keys, tensors)] await asyncio.gather(*coros) From 82d8ecd1f6c1077800d82ddbbdf060ffcd1b00e9 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 11 Oct 2025 02:10:44 -0700 Subject: [PATCH 12/20] env var for cpu-bound operation perf --- src/forge/env.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/forge/env.py b/src/forge/env.py index 1699ecc90..071f7a432 100644 --- a/src/forge/env.py +++ b/src/forge/env.py @@ -105,6 +105,12 @@ def get_value(self) -> Any: description="Sets the maximum frame length for Monarch's actor message delivery in bytes.", ) +OMP_NUM_THREADS = EnvVar( + name="OMP_NUM_THREADS", + default=16, # Recommended <= # cores / # of gpus since we are using 1 gpu per process + description="Sets the number of OpenMP threads to use. This is used for CPU-bound operations in PyTorch.", +) + def all_env_vars() -> list[EnvVar]: """Retrieves all registered environment variable names.""" From 0f414ddd2e73cb0273a50e1eb4127d73c3d6a4f1 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 11 Oct 2025 02:15:15 -0700 Subject: [PATCH 13/20] fix bad merge --- apps/grpo/qwen3_32b.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index 34205558e..dc098df08 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -4,7 +4,7 @@ # Global configuration group_size: 16 -batch_size: 16 +local_batch_size: 16 max_req_tokens: 512 max_res_tokens: 1536 model: "Qwen/Qwen3-32B" From 8e3174f315bf31b410cb92a606dd3db06862640a Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 11 Oct 2025 02:17:20 -0700 Subject: [PATCH 14/20] fix bad merge --- apps/grpo/qwen3_32b.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index dc098df08..3e6600910 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -119,7 +119,7 @@ ref_model: # All resource allocations services: policy: - procs: ${policy.engine_config.tensor_parallel_size} + procs: ${policy.engine_args.tensor_parallel_size} num_replicas: 4 hosts: 1 with_gpus: true From 05de3e73a04250ef65eedf86c62e552a888a0549 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 11 Oct 2025 02:58:22 -0700 Subject: [PATCH 15/20] temp fix --- apps/grpo/main.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index c64f00bc2..88e6c3cce 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -462,7 +462,8 @@ async def continuous_training(): t.step("update_weights") if training_step >= 2: - await drop_weights(training_step - 1) + # TODO: figure out why setting to training_step - 1 will trigger error + await drop_weights(training_step - 2) t.step("drop_weights") t.stop() From 55ecef31d0c03077a8533d6f5dfe979dd48f2d59 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 11 Oct 2025 03:39:18 -0700 Subject: [PATCH 16/20] disable prof --- src/forge/actors/policy.py | 83 +++++++++++++++----------------------- 1 file changed, 32 insertions(+), 51 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index aa66c25fa..9a6705ea9 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -19,7 +19,6 @@ import torch.distributed.checkpoint as dcp import torchstore as ts from monarch.actor import current_rank, endpoint, ProcMesh -from torch.profiler import profile, ProfilerActivity, record_function from torchstore.state_dict_utils import DELIM from vllm.config import VllmConfig @@ -66,14 +65,6 @@ hostname = socket.gethostname() -def trace_handler(rank, p): - p.export_chrome_trace( - f"/mnt/data/yuxuanh/profiler/{hostname}_trace_rank_{rank}_" - + str(p.step_num) - + ".json" - ) - - async def _ts_parallel_get(keys: list[str]) -> list[torch.Tensor]: coros = [ts.get(key) for key in keys] return await asyncio.gather(*coros) @@ -569,48 +560,38 @@ async def _load_tensor_parallel_state_dict( @endpoint async def update(self, version: int): """Update model weights by reading state dict from torchstore""" - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], - record_shapes=True, - on_trace_ready=lambda p: trace_handler(self.rank, p), - with_stack=True, - profile_memory=True, - ) as prof: - with record_function("policy_worker_perf/update"): - logger.info( - f"[PolicyWorker::update] start updating weights to version {version}" - ) - model = self.worker.model_runner.model - prefix = get_param_prefix(version) - logger.debug(f"{prefix=}") - matching_keys = await ts.keys(prefix) - logger.debug(f"{matching_keys=}") - dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) - loaded_weights = set() - t = Tracer("policy_worker_perf/update", timer="gpu") - t.start() - # Entire state dict is stored in a single DCP handle - if dcp_whole_state_dict_key in matching_keys: - logger.info( - f"Loading {dcp_whole_state_dict_key} from DCP with handle {dcp_whole_state_dict_key}" - ) - dcp_handle = await ts.get(dcp_whole_state_dict_key) - hf_param_names = dcp_handle.param_names - for name in hf_param_names: - param = load_tensor_from_dcp(dcp_handle, name) - loaded = model.load_weights([(name, param)]) - del param - loaded_weights.update(loaded) - else: # Load each parameter from torchstore directly without DCP - hf_param_names = [extract_param_name(key) for key in matching_keys] - param_keys = [ - get_param_key(version, name) for name in hf_param_names - ] - new_params = await _ts_parallel_get(param_keys) - loaded = model.load_weights(zip(hf_param_names, new_params)) - loaded_weights.update(loaded) - t.stop() - logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}") + logger.info( + f"[PolicyWorker::update] start updating weights to version {version}" + ) + model = self.worker.model_runner.model + prefix = get_param_prefix(version) + logger.debug(f"{prefix=}") + matching_keys = await ts.keys(prefix) + logger.debug(f"{matching_keys=}") + dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version) + loaded_weights = set() + t = Tracer("policy_worker_perf/update", timer="gpu") + t.start() + # Entire state dict is stored in a single DCP handle + if dcp_whole_state_dict_key in matching_keys: + logger.info( + f"Loading {dcp_whole_state_dict_key} from DCP with handle {dcp_whole_state_dict_key}" + ) + dcp_handle = await ts.get(dcp_whole_state_dict_key) + hf_param_names = dcp_handle.param_names + for name in hf_param_names: + param = load_tensor_from_dcp(dcp_handle, name) + loaded = model.load_weights([(name, param)]) + del param + loaded_weights.update(loaded) + else: # Load each parameter from torchstore directly without DCP + hf_param_names = [extract_param_name(key) for key in matching_keys] + param_keys = [get_param_key(version, name) for name in hf_param_names] + new_params = await _ts_parallel_get(param_keys) + loaded = model.load_weights(zip(hf_param_names, new_params)) + loaded_weights.update(loaded) + t.stop() + logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}") @endpoint async def setup_kv_cache(self): From 6d27c863f015a86e25a3c09bf9fc29a3af353039 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 11 Oct 2025 12:23:47 -0700 Subject: [PATCH 17/20] fix bad merge --- apps/grpo/qwen3_8b.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index 64a45de2f..be87aa87f 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -53,7 +53,7 @@ trainer: lr_scheduler: warmup_steps: 1 training: - local_local_batch_size: ${local_batch_size} + local_batch_size: ${local_batch_size} seq_len: 2048 max_norm: 1.0 steps: 1000000 From e47493c401f573238af28a6c3cd5ae4d6752ed83 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Sat, 11 Oct 2025 12:39:30 -0700 Subject: [PATCH 18/20] sequential put --- src/forge/actors/trainer.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index 3e19efc30..a135c3a70 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -19,6 +19,17 @@ import torch.distributed.checkpoint as dcp import torchstore as ts +from forge.actors._torchstore_utils import ( + DcpHandle, + get_dcp_whole_state_dict_key, + get_param_key, +) + +from forge.controller import ForgeActor +from forge.data.utils import batch_to_device +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer + from monarch.actor import current_rank, current_size, endpoint from torch import Tensor from torch.distributed.checkpoint._nested_dict import flatten_state_dict @@ -39,13 +50,6 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.actors._torchstore_utils import DcpHandle, get_dcp_whole_state_dict_key - -from forge.controller import ForgeActor -from forge.data.utils import batch_to_device -from forge.observability.metrics import record_metric, Reduce -from forge.observability.perf_tracker import Tracer - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -89,12 +93,6 @@ def cleanup_old_weight_versions( logger.debug(f"Error deleting {item_path}: {e}") -async def _parallel_put(kv_pairs: Iterable[tuple[str, Tensor]]): - keys, tensors = zip(*kv_pairs) - coros = [ts.put(key, tensor.detach().cpu()) for key, tensor in zip(keys, tensors)] - await asyncio.gather(*coros) - - @dataclass class RLTrainer(ForgeActor): job: Job = field(default_factory=Job) @@ -337,7 +335,10 @@ async def push_weights(self, policy_version: int) -> None: await ts.put(key, dcp_handle) t.step("dcp_save") else: - await _parallel_put(hf_state_dict.items()) + for name, param in hf_state_dict.items(): + key = get_param_key(policy_version, name) + await ts.put(key, param.detach().cpu()) + t.step("ts_save") t.step("ts_save") t.stop() end_time = time.perf_counter() From 8e792c1f010719db079c42db10c6a2fb639fe2ce Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 13 Oct 2025 00:28:27 -0700 Subject: [PATCH 19/20] revert --- src/forge/actors/policy.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 9a6705ea9..03050cc58 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -586,10 +586,14 @@ async def update(self, version: int): loaded_weights.update(loaded) else: # Load each parameter from torchstore directly without DCP hf_param_names = [extract_param_name(key) for key in matching_keys] - param_keys = [get_param_key(version, name) for name in hf_param_names] - new_params = await _ts_parallel_get(param_keys) - loaded = model.load_weights(zip(hf_param_names, new_params)) - loaded_weights.update(loaded) + # We can't pass a generator since vllm load_weights is not async. + # Instead, we just call load_weights with one parameter at a time. + for name in hf_param_names: + param_key = get_param_key(version, name) + param = await ts.get(param_key) + loaded = model.load_weights([(name, param)]) + del param + loaded_weights.update(loaded) t.stop() logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}") From 40148c5e8c9c45ca681be30db47bcbc0dccca9f5 Mon Sep 17 00:00:00 2001 From: Yuxuan Hu Date: Mon, 13 Oct 2025 19:06:46 +0000 Subject: [PATCH 20/20] change default to False for dcp --- src/forge/actors/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index a135c3a70..2086d9408 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -112,7 +112,7 @@ class RLTrainer(ForgeActor): # Non JobConfig-related fields loss: Callable = lambda logits, **targets: logits state_dict_key: str = "model_state_dict" - use_dcp: bool = True + use_dcp: bool = False dcp_path: str = "forge_dcp_tmp" def __post_init__(self):