
Commit bdd03a8

Seg fault
1 parent 8fa4451 commit bdd03a8

4 files changed: +69 additions, -66 deletions


apps/grpo/main.py

Lines changed: 12 additions & 51 deletions
@@ -14,9 +14,11 @@
 
 import torch
 import torch.nn.functional as F
+import torchstore as ts
 from datasets import load_dataset
 from forge.actors.policy import Policy
 from forge.actors.replay_buffer import ReplayBuffer
+from forge.actors.trainer import _qwen3_hf_to_vllm
 from forge.cli.config import parse
 from forge.controller.actor import ForgeActor
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
@@ -26,8 +28,7 @@
 from omegaconf import DictConfig
 from src.forge.data.utils import exclude_service
 from torch import nn
-from torchstore import MultiProcessStore
-from torchstore._state_dict_utils import DELIM, push_state_dict
+from torchstore.state_dict_utils import DELIM, put_state_dict
 from transformers import AutoModelForCausalLM
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
@@ -144,12 +145,11 @@ class Trainer(ForgeActor):
     learning_rate: float = 1e-5
     beta: float = 0.1
     device: torch.device | None = None
-    store: MultiProcessStore | None = None
     state_dict_key: str = "model_state_dict"
     dp_rank: int = 0  # TODO: support data parallelism, hard code it for now
 
     @endpoint
-    def setup(self):
+    async def setup(self):
         if self.device is None:
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -167,45 +167,9 @@ def setup(self):
 
         self.loss = SimpleGRPOLoss(self.beta)
 
-        self.logger.info(f"Trainer model initialized on {self.device}")
+        self.store = await ts.initialize()
 
-    def _qwen3_hf_to_vllm(self, saved_sd):
-        """Convert transformers state dict to vLLM format."""
-        load_sd = {}
-        num_layers = 28  # For Qwen3-1.7B
-
-        # Copy over directly mapped keys
-        for k in saved_sd:
-            if any(
-                x in k
-                for x in [
-                    "down_proj",
-                    "input_layernorm",
-                    "post_attention_layernorm",
-                    "o_proj",
-                    "norm.weight",
-                    "embed_tokens.weight",
-                    "lm_head.weight",
-                ]
-            ):
-                load_sd[k] = saved_sd[k]
-
-        # Fuse QKV and gate_up_proj
-        for i in range(num_layers):
-            prefix = f"model.layers.{i}."
-
-            # QKV fusion
-            q = saved_sd[prefix + "self_attn.q_proj.weight"]
-            k = saved_sd[prefix + "self_attn.k_proj.weight"]
-            v = saved_sd[prefix + "self_attn.v_proj.weight"]
-            load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
-
-            # MLP gate_up_proj fusion
-            gate = saved_sd[prefix + "mlp.gate_proj.weight"]
-            up = saved_sd[prefix + "mlp.up_proj.weight"]
-            load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
-
-        return load_sd
+        self.logger.info(f"Trainer model initialized on {self.device}")
 
     @endpoint
     async def train_step(self, batch: list[list[Episode]]):
@@ -238,16 +202,16 @@ async def train_step(self, batch: list[list[Episode]]):
         loss.backward()
         self.optimizer.step()
 
-        return loss.detach()
+        return loss.item()
 
     @endpoint
     async def push_weights(self, version: int):
         """Update policy model weights with trainer's current weights."""
         start_time = time.time()
-        assert self.store is not None, "Store must be provided to save weights"
+        assert self.store is not None, "Store must be initialized to save weights"
         key = f"{self.state_dict_key}{DELIM}{version}"  # Use version as unique id
-        new_sd = self._qwen3_hf_to_vllm(self.model.state_dict())
-        await push_state_dict(self.store, new_sd, key)
+        new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28)
+        await put_state_dict(self.store, new_sd, key)
         end_time = time.time()
         self.logger.debug(
             f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
@@ -322,11 +286,11 @@ class DatasetActor(ForgeActor):
     revision: str = "main"
     data_split: str = "train"
    streaming: bool = True
-    tokenizer: str = "Qwen/Qwen3-1.7B"
+    model: str = "Qwen/Qwen3-1.7B"
 
     @endpoint
     def setup(self):
-        self._tokenizer = get_tokenizer(self.tokenizer)
+        self._tokenizer = get_tokenizer(self.model)
 
         def gsm8k_transform(sample):
             system_prompt = """
@@ -380,7 +344,6 @@ async def main(cfg: DictConfig):
     )
 
     # ---- Setup services ---- #
-    store = await MultiProcessStore.create_store()
     (
         dataloader,
         policy,
@@ -399,13 +362,11 @@
             ServiceConfig(**cfg.policy.service),
             Policy,
             **exclude_service(cfg.policy),
-            store=store,
         ),
         spawn_service(
             ServiceConfig(**cfg.trainer.service),
             Trainer,
             **exclude_service(cfg.trainer),
-            store=store,
        ),
        spawn_service(
            ServiceConfig(**cfg.replay_buffer.service),

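Taken together, the main.py changes replace the shared MultiProcessStore object with per-actor torchstore handles and a versioned key scheme. Below is a minimal sketch of that flow, assuming ts.initialize(), put_state_dict, store.get, and DELIM behave exactly as they are used in this diff; push_side, read_side, and the trainer_model argument are hypothetical names, and the real training loop and tensor-parallel sharding logic are omitted.

import torchstore as ts
from torchstore.state_dict_utils import DELIM, put_state_dict

from forge.actors.trainer import _qwen3_hf_to_vllm

STATE_DICT_KEY = "model_state_dict"  # mirrors Trainer.state_dict_key in the diff


async def push_side(trainer_model, version: int) -> None:
    # Trainer side: each actor now creates its own store handle in setup()
    # instead of receiving a MultiProcessStore through its constructor.
    store = await ts.initialize()
    key = f"{STATE_DICT_KEY}{DELIM}{version}"  # version acts as the unique id
    vllm_sd = _qwen3_hf_to_vllm(trainer_model.state_dict(), num_layers=28)
    await put_state_dict(store, vllm_sd, key)


async def read_side(param_name: str, version: int):
    # PolicyWorker side: tensors are read back one parameter at a time
    # under the same versioned prefix.
    store = await ts.initialize()
    return await store.get(f"{STATE_DICT_KEY}{DELIM}{version}{DELIM}{param_name}")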
apps/grpo/qwen3_1_7b.yaml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ dataset:
   revision: "main"
   data_split: "train"
   streaming: true
-  tokenizer: ${model}
+  model: ${model}
   service:
     procs_per_replica: 1
     num_replicas: 1

src/forge/actors/policy.py

Lines changed: 10 additions & 14 deletions
@@ -13,9 +13,9 @@
 from dataclasses import asdict, dataclass, field, fields
 
 import torch
+import torchstore as ts
 from monarch.actor import current_rank, endpoint, ProcMesh
-from torchstore import MultiProcessStore
-from torchstore._state_dict_utils import DELIM
+from torchstore.state_dict_utils import DELIM
 
 from vllm.engine.arg_utils import EngineArgs
 from vllm.entrypoints.utils import _validate_truncation_size
@@ -107,14 +107,13 @@ class Policy(PolicyInterface):
     lora_request: LoRARequest | None = None
     tokenization_kwargs: dict = field(default_factory=dict)
     policy_worker: "PolicyWorker" = None
-    store: MultiProcessStore | None = None
 
     def __post_init__(self):
         self._run_task: asyncio.Task | None = None
         self._policy_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
         self.weights_version: int = 0
-        self.running: bool = False
+        self.running = False
         if isinstance(self.engine_config, Mapping):
             self.engine_config = EngineConfig.from_dict(self.engine_config)
         if isinstance(self.sampling_config, Mapping):
@@ -128,7 +127,6 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         engine_config: EngineConfig | Mapping = EngineConfig(),
         sampling_config: SamplingConfig | Mapping = SamplingConfig(),
         available_devices: str | None = None,
-        store: MultiProcessStore | None = None,
         **kwargs,
     ) -> "Policy":
         # Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
@@ -161,7 +159,6 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
             sampling_config=sampling_config,
             available_devices=available_devices,
             policy_worker=workers,
-            store=store,
         )
         policy._policy_proc = policy_proc
         policy._worker_procs = worker_procs
@@ -189,7 +186,7 @@ async def shutdown(  # pyright: ignore[reportIncompatibleMethodOverride]
     async def setup(self):
         # Set up policy_worker
         assert self.policy_worker is not None, "Policy worker should not be None"
-        await self.policy_worker.setup.call(store=self.store)
+        await self.policy_worker.setup.call()
 
         self.request_id = 0
         self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
@@ -343,9 +340,8 @@ async def run(self):
 
             for request_output in processed_outputs.request_outputs:
                 if request_output.finished:
-                    if request_output.request_id in self.requests:
-                        _, fut = self.requests.pop(request_output.request_id)
-                        fut.set_result(request_output)
+                    _, fut = self.requests.pop(request_output.request_id)
+                    fut.set_result(request_output)
 
     @endpoint
     async def update_weights(self):
@@ -403,8 +399,8 @@ def __post_init__(self):
         self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)
 
     @endpoint
-    async def setup(self, store: MultiProcessStore = None):
-        self.torchstore = store
+    async def setup(self):
+        self.store = await ts.initialize()
         # TODO: remove ["gpus"] when monarch implements a flat rank
         self.rank = current_rank()["gpus"]
         self.worker = self.setup_worker()
@@ -428,7 +424,7 @@ async def _load_tensor_parallel_state_dict(
 
         # Load the full tensor from torchstore
         # TODO: only get the part of the tensor that is needed
-        stored_tensor = await self.torchstore.get(
+        stored_tensor = await self.store.get(
             f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}"
         )
         sharding.load_from_source_to_target(
@@ -440,7 +436,7 @@ async def _load_tensor_parallel_state_dict(
     @endpoint
     async def update(self, version: int):
         """Update model weights by reading state dict from torchstore"""
-        if self.torchstore is None:
+        if self.store is None:
             raise Exception("No torchstore configured, skipping model update")
         key = f"{self.state_dict_key}{DELIM}{version}"
         model = self.worker.model_runner.model

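One behavioral note on the Policy.run() hunk above: the membership guard around self.requests is gone, so a finished output whose request_id has already been removed would now raise KeyError instead of being silently skipped. The standalone sketch below illustrates the future-based bookkeeping pattern only; the names submit and complete are assumptions and this is not the actual Policy implementation.

import asyncio

# request_id -> (parent_request, future), matching the annotation set in Policy.setup()
requests: dict[str, tuple[None, asyncio.Future]] = {}


async def submit(request_id: str) -> asyncio.Future:
    # Register a pending request and hand back a future the caller can await.
    fut = asyncio.get_running_loop().create_future()
    requests[request_id] = (None, fut)
    return fut


def complete(request_output) -> None:
    # After this commit the pop is unconditional: an unknown request_id
    # raises KeyError rather than being ignored.
    _, fut = requests.pop(request_output.request_id)
    fut.set_result(request_output)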
src/forge/actors/trainer.py

Lines changed: 46 additions & 0 deletions
@@ -268,3 +268,49 @@ def push_weights(self) -> None:
     async def cleanup(self) -> None:
         if self.engine.checkpointer:
             self.engine.checkpointer.close()
+
+
+def _qwen3_hf_to_vllm(
+    sd: dict[str, torch.Tensor], num_layers: int
+) -> dict[str, torch.Tensor]:
+    """Convert transformers state dict to vLLM format. Specifically, this fuses
+    QKV projection and MLP gate_up_proj layers.
+
+    Args:
+        sd (dict): State dict from HF model.
+        num_layers (int): Number of layers in the model.
+
+    Returns:
+        dict: State dict in vLLM format.
+    """
+    load_sd = {}
+
+    # Copy over directly mapped keys
+    for k in sd:
+        if any(
+            x in k
+            for x in [
+                "down_proj",
+                "input_layernorm",
+                "post_attention_layernorm",
+                "o_proj",
+                "norm.weight",
+                "embed_tokens.weight",
+                "lm_head.weight",
+            ]
+        ):
+            load_sd[k] = sd[k]
+
+    for i in range(num_layers):
+        prefix = f"model.layers.{i}."
+        # QKV fusion
+        q = sd[prefix + "self_attn.q_proj.weight"]
+        k = sd[prefix + "self_attn.k_proj.weight"]
+        v = sd[prefix + "self_attn.v_proj.weight"]
+        load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
+        # MLP gate_up_proj fusion
+        gate = sd[prefix + "mlp.gate_proj.weight"]
+        up = sd[prefix + "mlp.up_proj.weight"]
+        load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
+
+    return load_sd

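As a quick sanity check of the new module-level helper, the sketch below feeds _qwen3_hf_to_vllm a toy single-layer state dict and verifies the fused shapes. The tensor dimensions are made up for illustration and do not correspond to Qwen3-1.7B.

import torch

from forge.actors.trainer import _qwen3_hf_to_vllm

hidden, kv_dim, inter = 8, 4, 16  # toy sizes, not the real model dims
sd = {
    "model.layers.0.self_attn.q_proj.weight": torch.randn(hidden, hidden),
    "model.layers.0.self_attn.k_proj.weight": torch.randn(kv_dim, hidden),
    "model.layers.0.self_attn.v_proj.weight": torch.randn(kv_dim, hidden),
    "model.layers.0.mlp.gate_proj.weight": torch.randn(inter, hidden),
    "model.layers.0.mlp.up_proj.weight": torch.randn(inter, hidden),
    "model.embed_tokens.weight": torch.randn(32, hidden),
}

out = _qwen3_hf_to_vllm(sd, num_layers=1)
# q/k/v are concatenated along dim 0, as are gate/up; embed_tokens passes through unchanged.
assert out["model.layers.0.self_attn.qkv_proj.weight"].shape == (hidden + 2 * kv_dim, hidden)
assert out["model.layers.0.mlp.gate_up_proj.weight"].shape == (2 * inter, hidden)
assert "model.embed_tokens.weight" in out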