
Commit ed225df

merge
2 parents: 061dbdc + 3303af5


41 files changed: +851 −565 lines

.github/packaging/pre_build_gpu.sh

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ echo "build dir is $BUILD_DIR"
 echo "wheel dir is $WHL_DIR"
 
 build_monarch() {
+  export MONARCH_PACKAGE_NAME="torchmonarch"
   # Get Rust build related pieces
   if ! command -v rustup &> /dev/null; then
     echo "getting rustup"

.github/workflows/build_vllm.yaml

Lines changed: 6 additions & 6 deletions
@@ -12,7 +12,7 @@ permissions:
 
 jobs:
   build:
-    name: forge-cu126-nightly
+    name: forge-cu129-nightly
     uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
     strategy:
       fail-fast: false
@@ -31,13 +31,13 @@ jobs:
           {
             "python_version": "3.10",
             "gpu_arch_type": "cpu",
-            "gpu_arch_version": "12.6",
-            "desired_cuda": "cu126",
-            "container_image": "pytorch/manylinux2_28-builder:cuda12.6",
+            "gpu_arch_version": "12.9",
+            "desired_cuda": "cu129",
+            "container_image": "pytorch/manylinux2_28-builder:cuda12.9",
             "package_type": "manywheel",
-            "build_name": "manywheel-py3_10-cuda12_6",
+            "build_name": "manywheel-py3_10-cuda12_9",
             "validation_runner": "linux.12xlarge.memory",
-            "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126",
+            "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129",
             "channel": "nightly",
             "upload_to_base_bucket": "no",
             "stable_version": "2.8.0",

.github/workflows/build_wheels.yaml

Lines changed: 6 additions & 6 deletions
@@ -12,7 +12,7 @@ permissions:
 
 jobs:
   build:
-    name: forge-cu126-nightly
+    name: forge-cu129-nightly
     uses: pytorch/test-infra/.github/workflows/build_wheels_linux.yml@main
     strategy:
       fail-fast: false
@@ -31,13 +31,13 @@ jobs:
           {
             "python_version": "3.10",
             "gpu_arch_type": "cuda",
-            "gpu_arch_version": "12.6",
-            "desired_cuda": "cu126",
-            "container_image": "pytorch/manylinux2_28-builder:cuda12.6",
+            "gpu_arch_version": "12.9",
+            "desired_cuda": "cu129",
+            "container_image": "pytorch/manylinux2_28-builder:cuda12.9",
             "package_type": "manywheel",
-            "build_name": "manywheel-py3_10-cuda12_6",
+            "build_name": "manywheel-py3_10-cuda12_9",
             "validation_runner": "linux.4xlarge.nvidia.gpu",
-            "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu126",
+            "installation": "pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu129",
             "channel": "nightly",
             "upload_to_base_bucket": "no",
             "stable_version": "2.8.0",

.github/workflows/docs.yml

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@ on:
 
 jobs:
   build-docs:
+    if: github.repository_owner == 'meta-pytorch'
     name: Build Documentation
     runs-on: linux.g5.4xlarge.nvidia.gpu
     timeout-minutes: 30

apps/grpo/qwen3_1_7b.yaml

Lines changed: 3 additions & 3 deletions
@@ -3,7 +3,7 @@
 
 # Global configuration
 group_size: 8
-batch_size: 16
+local_batch_size: 16 # per-device batch size
 max_req_tokens: 512
 max_res_tokens: 512
 model: "Qwen/Qwen3-1.7B"
@@ -56,7 +56,7 @@ trainer:
   lr_scheduler:
     warmup_steps: 1
   training:
-    local_batch_size: ${batch_size}
+    local_batch_size: ${local_batch_size}
     seq_len: 2048
     max_norm: 1.0
     steps: 1000000
@@ -85,7 +85,7 @@ trainer:
 
 # Replay buffer configuration
 replay_buffer:
-  batch_size: ${batch_size}
+  batch_size: ${local_batch_size}
   max_policy_age: ${off_by_n}
   dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
 
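
The rename from batch_size to local_batch_size makes explicit that the global value is a per-device batch size, and the ${...} references below pick it up; the same rename appears in the 32B and 8B configs that follow. A minimal sketch of how such references resolve, assuming an OmegaConf-style loader (the project's actual config machinery may differ):

    from omegaconf import OmegaConf

    # Assumes OmegaConf-style interpolation, as suggested by the ${...} syntax.
    cfg = OmegaConf.create(
        """
        local_batch_size: 16
        trainer:
          training:
            local_batch_size: ${local_batch_size}
        replay_buffer:
          batch_size: ${local_batch_size}
        """
    )
    assert cfg.trainer.training.local_batch_size == 16
    assert cfg.replay_buffer.batch_size == 16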

apps/grpo/qwen3_32b.yaml

Lines changed: 3 additions & 3 deletions
@@ -4,7 +4,7 @@
 
 # Global configuration
 group_size: 16
-batch_size: 32
+local_batch_size: 32 # per-device batch size
 max_req_tokens: 1024
 max_res_tokens: 1024
 model: "Qwen/Qwen3-32B"
@@ -59,7 +59,7 @@ trainer:
   lr_scheduler:
     warmup_steps: 1
   training:
-    local_batch_size: ${batch_size}
+    local_batch_size: ${local_batch_size}
     seq_len: 2048
     max_norm: 1.0
     steps: 1000000
@@ -87,7 +87,7 @@ trainer:
 
 # Replay buffer configuration
 replay_buffer:
-  batch_size: ${batch_size}
+  batch_size: ${local_batch_size}
   max_policy_age: ${off_by_n}
   # dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
   dp_size: 1

apps/grpo/qwen3_8b.yaml

Lines changed: 3 additions & 5 deletions
@@ -3,7 +3,7 @@
 
 # Global configuration
 group_size: 8
-batch_size: 16
+local_batch_size: 16 # per-device batch size
 max_req_tokens: 512
 max_res_tokens: 512
 model: "Qwen/Qwen3-8B"
@@ -28,7 +28,6 @@ dataset:
 
 # Policy configuration
 policy:
-  use_vllm_builtin_load: true
   engine_config:
     model: ${model}
     tensor_parallel_size: 2
@@ -43,7 +42,6 @@ policy:
 # Trainer configuration
 trainer:
   use_dcp: true
-  use_vllm_builtin_load: true
   model:
     name: qwen3
     flavor: 8B
@@ -55,7 +53,7 @@ trainer:
   lr_scheduler:
     warmup_steps: 1
   training:
-    local_batch_size: ${batch_size}
+    local_batch_size: ${local_batch_size}
     seq_len: 2048
     max_norm: 1.0
     steps: 1000000
@@ -84,7 +82,7 @@ trainer:
 
 # Replay buffer configuration
 replay_buffer:
-  batch_size: ${batch_size}
+  batch_size: ${local_batch_size}
   max_policy_age: ${off_by_n}
   # This should match the dp_size of TorchTitan
   # Here it's set explicitly to 2, because we've set

src/forge/actors/policy.py

Lines changed: 8 additions & 33 deletions
@@ -10,7 +10,6 @@
 import logging
 import os
 import sys
-import time
 from collections.abc import Mapping
 from copy import copy
 from dataclasses import asdict, dataclass, field, fields
@@ -140,7 +139,6 @@ def create_vllm_config(self) -> VllmConfig:
 class Policy(PolicyInterface):
     engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
     sampling_config: SamplingConfig | Mapping = field(default_factory=SamplingConfig)
-    use_vllm_builtin_load: bool = True
     available_devices: str | None = None
     use_dcp: bool = True
     # Gets set up by setup
@@ -246,7 +244,7 @@ async def setup(self):
 
         self.request_id = 0
         self.policy_version = 0
-        self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
+        self.requests: dict[str, tuple[ParentRequest | None, asyncio.Future]] = {}
 
         # TODO: Investigate whether this can be combined with `policy.running`
         # Whether this policy is accepting requests.
@@ -484,14 +482,11 @@ async def update_weights(self, policy_version: int):
         record_metric("policy/update_weights/count_weight_updates", 1, Reduce.SUM)
 
         logger.debug(f"Starting weight update on {self.__class__.__name__}")
-        if self.use_vllm_builtin_load:
-            await self.policy_worker.update.call(version=policy_version)
-        else:
-            await self.policy_worker.update_DEPRECATED.call(version=policy_version)
+        await self.policy_worker.update.call(version=policy_version)
         self.policy_version = policy_version
 
         # After updating the weights, we need to reset the KV cache
-        self.scheduler.kv_cache_manager.reset_prefix_cache()
+        self.scheduler.reset_prefix_cache()
 
         # Resume accepting requests and wake up any waiting generate() calls
         async with self.request_lock:
@@ -501,16 +496,8 @@ async def update_weights(self, policy_version: int):
         logger.info(f"Weight update completed (now v{self.policy_version})")
 
     @endpoint
-    async def update_weights_DEPRECATED(self, policy_version: int):  # noqa: N802
-        # TODO: If generating long sequences, this might be long and will block policy weight updates
-        curr_requests = [fut for _, fut in self.requests.values()]
-        if curr_requests:
-            logger.debug(f"Waiting for {len(curr_requests)} pending requests")
-            await asyncio.gather(*curr_requests)
-
-        await self.policy_worker.update_DEPRECATED.call(version=policy_version)
-        self.policy_version = policy_version
-        logger.info(f"Weight update completed (now v{self.policy_version})")
+    async def _reset_prefix_cache(self):
+        self.scheduler.reset_prefix_cache()
 
     @endpoint
     async def get_version(self) -> int:
@@ -550,6 +537,7 @@ def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
                     token_ids=torch.tensor(output.token_ids),
                     logprobs=self._extract_logprobs(output),
                     generator_version=self.policy_version,
+                    metadata={"num_cached_tokens": request_output.num_cached_tokens},
                 )
             )
 
@@ -587,8 +575,8 @@ def __post_init__(self):
 
     @endpoint
     async def setup(self):
-        # TODO: remove ["gpus"] when monarch implements a flat rank
-        self.rank = current_rank()["gpus"]
+        self.rank = current_rank().rank
+        os.environ["RANK"] = str(self.rank)
         self.worker = self.setup_worker()
 
     @endpoint
@@ -631,19 +619,6 @@ async def _load_tensor_parallel_state_dict(
             current_tensor,
         )
 
-    @endpoint
-    async def update_DEPRECATED(self, version: int):  # noqa: N802
-        """Update model weights by reading state dict from torchstore.
-        Deprecated. This uses manual sharding logic which is buggy."""
-        key = f"{self.state_dict_key}{DELIM}{version}"
-        model = self.worker.model_runner.model
-        current_state_dict = model.state_dict()
-        start = time.perf_counter()
-        await self._load_tensor_parallel_state_dict(current_state_dict, version)
-        logger.info(
-            f"Loaded state dict from {key} in {time.perf_counter() - start} seconds"
-        )
-
     @endpoint
     async def update(self, version: int):
         """Update model weights by reading state dict from torchstore"""

src/forge/actors/trainer.py

Lines changed: 2 additions & 69 deletions
@@ -21,7 +21,6 @@
 from monarch.actor import current_rank, current_size, endpoint
 from torch import Tensor
 from torch.distributed.checkpoint._nested_dict import flatten_state_dict
-from torchstore.state_dict_utils import DELIM
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
     Checkpoint,
@@ -114,8 +113,6 @@ class RLTrainer(ForgeActor):
     state_dict_key: str = "model_state_dict"
     use_dcp: bool = True
     dcp_path: str = "forge_dcp_tmp"
-    vllm_tp_DEPRECATED: int = 1  # noqa: N815
-    use_vllm_builtin_load: bool = True
 
     def __post_init__(self):
         """Initializes config types and env variables.
@@ -159,6 +156,8 @@ def __post_init__(self):
             "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
         }
         os.environ.update(env)
+        logger.info("Compiling loss")
+        self.loss = torch.compile(self.loss)
 
     @endpoint
     async def setup(self):
@@ -168,9 +167,7 @@ async def setup(self):
             "loss",
             "state_dict_key",
             "use_dcp",
-            "use_vllm_builtin_load",
             "dcp_path",
-            "vllm_tp_DEPRECATED",
         }:
             engine_config.pop(key)  # Not part of job config
         self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
@@ -302,76 +299,12 @@ async def train_step(
         t.stop()
         return loss
 
-    @endpoint
-    async def push_weights_DEPRECATED(  # noqa: N802
-        self, policy_version: int, vllm_tp_DEPRECATED: int = 1
-    ) -> None:  # noqa: N802
-        """[Deprecated] This method pushes weights to torchstore in the vllm format,
-        which is buggy and not scalable to other models.
-        Deprecated in favor of push_weights."""
-        return await self._push_weights_DEPRECATED(policy_version, vllm_tp_DEPRECATED)
-
-    async def _push_weights_DEPRECATED(  # noqa: N802
-        self, policy_version: int, vllm_tp_DEPRECATED: int
-    ) -> None:  # noqa: N802
-        # Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now.
-        # TODO:
-        # 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
-        #    May need to replicate the same in this code path.
-        # 2. Unify CheckpointManager and TorchStore weights save control path.
-        if "model" not in self.engine.checkpointer.states:
-            raise RuntimeError("Model state not found in checkpointer state")
-
-        sd = self.engine.checkpointer.states["model"].state_dict()
-        flattened_state_dict, _ = flatten_state_dict(sd)
-
-        if self.engine.checkpointer.sd_adapter is None:
-            raise RuntimeError(
-                "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
-            )
-        hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
-
-        # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
-        vllm_ready_hf_sd = _qwen3_hf_to_vllm(
-            sd=hf_state_dict,
-            num_layers=self.engine.model_args.n_layers,
-            vllm_tp=vllm_tp_DEPRECATED,
-        )
-
-        key = f"{self.state_dict_key}{DELIM}{policy_version}"
-        if self.use_dcp:
-            # TODO - DCP should probably be being saved to NFS explicitly?
-            # Right now it will only save everything locally
-            storage_writer = torch.distributed.checkpoint.FileSystemWriter(
-                key, single_file_per_rank=False, thread_count=8
-            )
-            metadata = dcp.save(
-                storage_writer=storage_writer, state_dict=vllm_ready_hf_sd
-            )
-            await ts.put(key, metadata)
-
-            # Delete old weight versions if they exist
-            if self.rank == 0:
-                cleanup_old_weight_versions(
-                    state_dict_key=self.state_dict_key,
-                    delim=DELIM,
-                    current_policy_version=policy_version,
-                )
-        else:
-            await ts.put_state_dict(vllm_ready_hf_sd, key)
-
     @endpoint
     async def push_weights(self, policy_version: int) -> None:
         """Push weights to torchstore in HF format."""
         t = Tracer("rl_trainer_perf/push_weights", timer="gpu", track_memory=True)
         t.start()
         logger.info(f"Pushing weights for policy version {policy_version}")
-        if not self.use_vllm_builtin_load:
-            result = await self._push_weights_DEPRECATED(
-                policy_version, self.vllm_tp_DEPRECATED
-            )
-            t.step("push_weights_DEPRECATED")
-            return result
 
         start_time = time.perf_counter()
         if "model" not in self.engine.checkpointer.states:
