 
 import torch
 import torch.nn.functional as F
+import torchstore as ts
 from datasets import load_dataset
 from forge.actors.policy import Policy
 from forge.actors.replay_buffer import ReplayBuffer
+from forge.actors.trainer import _qwen3_hf_to_vllm
 from forge.cli.config import parse
 from forge.controller.actor import ForgeActor
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from omegaconf import DictConfig
 from src.forge.data.utils import exclude_service
 from torch import nn
-from torchstore import MultiProcessStore
-from torchstore._state_dict_utils import DELIM, push_state_dict
+from torchstore.state_dict_utils import DELIM, put_state_dict
 from transformers import AutoModelForCausalLM
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
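The import changes replace the handle-passing `MultiProcessStore` API with a module-level one: each actor calls `ts.initialize()` itself, and `put_state_dict` (renamed from `push_state_dict`, now in the public `torchstore.state_dict_utils` module) writes a state dict under a key. A minimal sketch of that lifecycle, using only names visible in this diff; the surrounding function is illustrative:

```python
import torchstore as ts
from torchstore.state_dict_utils import DELIM, put_state_dict


async def publish_weights(model, version: int) -> None:
    # Each process obtains its own store handle instead of receiving one
    # from the caller (the old MultiProcessStore.create_store() pattern).
    store = await ts.initialize()
    # Versioned key, mirroring Trainer.push_weights below.
    key = f"model_state_dict{DELIM}{version}"
    await put_state_dict(store, model.state_dict(), key)
```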
@@ -144,12 +145,11 @@ class Trainer(ForgeActor):
     learning_rate: float = 1e-5
     beta: float = 0.1
     device: torch.device | None = None
-    store: MultiProcessStore | None = None
     state_dict_key: str = "model_state_dict"
     dp_rank: int = 0  # TODO: support data parallelism, hard code it for now
 
     @endpoint
-    def setup(self):
+    async def setup(self):
         if self.device is None:
             self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -167,45 +167,9 @@ def setup(self):
 
         self.loss = SimpleGRPOLoss(self.beta)
 
-        self.logger.info(f"Trainer model initialized on {self.device}")
+        self.store = await ts.initialize()
 
-    def _qwen3_hf_to_vllm(self, saved_sd):
-        """Convert transformers state dict to vLLM format."""
-        load_sd = {}
-        num_layers = 28  # For Qwen3-1.7B
-
-        # Copy over directly mapped keys
-        for k in saved_sd:
-            if any(
-                x in k
-                for x in [
-                    "down_proj",
-                    "input_layernorm",
-                    "post_attention_layernorm",
-                    "o_proj",
-                    "norm.weight",
-                    "embed_tokens.weight",
-                    "lm_head.weight",
-                ]
-            ):
-                load_sd[k] = saved_sd[k]
-
-        # Fuse QKV and gate_up_proj
-        for i in range(num_layers):
-            prefix = f"model.layers.{i}."
-
-            # QKV fusion
-            q = saved_sd[prefix + "self_attn.q_proj.weight"]
-            k = saved_sd[prefix + "self_attn.k_proj.weight"]
-            v = saved_sd[prefix + "self_attn.v_proj.weight"]
-            load_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
-
-            # MLP gate_up_proj fusion
-            gate = saved_sd[prefix + "mlp.gate_proj.weight"]
-            up = saved_sd[prefix + "mlp.up_proj.weight"]
-            load_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
-
-        return load_sd
+        self.logger.info(f"Trainer model initialized on {self.device}")
 
     @endpoint
     async def train_step(self, batch: list[list[Episode]]):
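The deleted helper has not vanished: it moved into `forge.actors.trainer` and grew a `num_layers` parameter instead of hard-coding 28. Its core job is fusing per-projection HF weights into vLLM's packed layout. A self-contained shape check of that fusion (the dimensions are made up for illustration, not Qwen3-1.7B's real sizes):

```python
import torch

# Toy dimensions for illustration only.
hidden, q_out, kv_out, intermediate = 64, 64, 16, 128

q = torch.randn(q_out, hidden)
k = torch.randn(kv_out, hidden)
v = torch.randn(kv_out, hidden)
# vLLM's qkv_proj stacks the three projections along the output dim...
qkv = torch.cat([q, k, v], dim=0)
assert qkv.shape == (q_out + 2 * kv_out, hidden)

gate = torch.randn(intermediate, hidden)
up = torch.randn(intermediate, hidden)
# ...and gate_up_proj does the same for the two MLP input projections.
gate_up = torch.cat([gate, up], dim=0)
assert gate_up.shape == (2 * intermediate, hidden)
```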
@@ -238,16 +202,16 @@ async def train_step(self, batch: list[list[Episode]]):
         loss.backward()
         self.optimizer.step()
 
-        return loss.detach()
+        return loss.item()
 
     @endpoint
     async def push_weights(self, version: int):
         """Update policy model weights with trainer's current weights."""
         start_time = time.time()
-        assert self.store is not None, "Store must be provided to save weights"
+        assert self.store is not None, "Store must be initialized to save weights"
         key = f"{self.state_dict_key}{DELIM}{version}"  # Use version as unique id
-        new_sd = self._qwen3_hf_to_vllm(self.model.state_dict())
-        await push_state_dict(self.store, new_sd, key)
+        new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28)
+        await put_state_dict(self.store, new_sd, key)
         end_time = time.time()
         self.logger.debug(
             f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
@@ -322,11 +286,11 @@ class DatasetActor(ForgeActor):
     revision: str = "main"
     data_split: str = "train"
     streaming: bool = True
-    tokenizer: str = "Qwen/Qwen3-1.7B"
+    model: str = "Qwen/Qwen3-1.7B"
 
     @endpoint
     def setup(self):
-        self._tokenizer = get_tokenizer(self.tokenizer)
+        self._tokenizer = get_tokenizer(self.model)
 
         def gsm8k_transform(sample):
             system_prompt = """
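The rename from `tokenizer` to `model` reflects what the value actually is: a model id, which vLLM's `get_tokenizer` resolves to the matching tokenizer. A quick usage sketch (the prompt string is made up):

```python
from vllm.transformers_utils.tokenizer import get_tokenizer

# get_tokenizer accepts a model/tokenizer id; "Qwen/Qwen3-1.7B" is the
# DatasetActor default above.
tok = get_tokenizer("Qwen/Qwen3-1.7B")
ids = tok.encode("Natalia sold clips to 48 of her friends...")
```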
@@ -380,7 +344,6 @@ async def main(cfg: DictConfig):
     )
 
     # ---- Setup services ---- #
-    store = await MultiProcessStore.create_store()
     (
         dataloader,
         policy,
@@ -399,13 +362,11 @@
             ServiceConfig(**cfg.policy.service),
             Policy,
             **exclude_service(cfg.policy),
-            store=store,
         ),
         spawn_service(
             ServiceConfig(**cfg.trainer.service),
             Trainer,
             **exclude_service(cfg.trainer),
-            store=store,
         ),
         spawn_service(
             ServiceConfig(**cfg.replay_buffer.service),
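With each actor now calling `ts.initialize()` in its own `setup()`, `main` no longer constructs a shared store and threads it through every `spawn_service` call; the `store=store` kwargs simply disappear. The post-diff shape of one spawn, taken directly from the hunk above:

```python
# After this diff: no store kwarg; Trainer.setup() awaits ts.initialize()
# itself, so spawning needs only the service config and actor kwargs.
spawn_service(
    ServiceConfig(**cfg.trainer.service),
    Trainer,
    **exclude_service(cfg.trainer),
)
```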