|
28 | 28 | from omegaconf import DictConfig |
29 | 29 | from src.forge.data.utils import exclude_service |
30 | 30 | from torch import nn |
31 | | -from torchstore.state_dict_utils import DELIM, put_state_dict |
| 31 | +from torchstore.state_dict_utils import DELIM |
32 | 32 | from transformers import AutoModelForCausalLM |
33 | 33 | from vllm.transformers_utils.tokenizer import get_tokenizer |
34 | 34 |
|
@@ -167,8 +167,6 @@ async def setup(self): |
167 | 167 |
|
168 | 168 | self.loss = SimpleGRPOLoss(self.beta) |
169 | 169 |
|
170 | | - self.store = await ts.initialize() |
171 | | - |
172 | 170 | self.logger.info(f"Trainer model initialized on {self.device}") |
173 | 171 |
|
174 | 172 | @endpoint |
@@ -207,11 +205,10 @@ async def train_step(self, batch: list[list[Episode]]): |
207 | 205 | @endpoint |
208 | 206 | async def push_weights(self, version: int): |
209 | 207 | """Update policy model weights with trainer's current weights.""" |
210 | | - start_time = time.time() |
211 | | - assert self.store is not None, "Store must be initialized to save weights" |
212 | 208 | key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id |
213 | 209 | new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28) |
214 | | - await put_state_dict(self.store, new_sd, key) |
| 210 | + start_time = time.time() |
| 211 | + await ts.put_state_dict(new_sd, key) |
215 | 212 | end_time = time.time() |
216 | 213 | self.logger.debug( |
217 | 214 | f"Pushed weights to {key} in {end_time - start_time:.2f} seconds" |
@@ -344,6 +341,7 @@ async def main(cfg: DictConfig): |
344 | 341 | ) |
345 | 342 |
|
346 | 343 | # ---- Setup services ---- # |
| 344 | + await ts.initialize() |
347 | 345 | ( |
348 | 346 | dataloader, |
349 | 347 | policy, |
|
0 commit comments