Skip to content

Commit 7eedc91

Browse files
committed
Make torchstore actually work!
1 parent bdd03a8 commit 7eedc91

File tree

2 files changed: +5 / -10 lines changed

apps/grpo/main.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from omegaconf import DictConfig
2929
from src.forge.data.utils import exclude_service
3030
from torch import nn
31-
from torchstore.state_dict_utils import DELIM, put_state_dict
31+
from torchstore.state_dict_utils import DELIM
3232
from transformers import AutoModelForCausalLM
3333
from vllm.transformers_utils.tokenizer import get_tokenizer
3434

@@ -167,8 +167,6 @@ async def setup(self):
167167

168168
self.loss = SimpleGRPOLoss(self.beta)
169169

170-
self.store = await ts.initialize()
171-
172170
self.logger.info(f"Trainer model initialized on {self.device}")
173171

174172
@endpoint
@@ -207,11 +205,10 @@ async def train_step(self, batch: list[list[Episode]]):
207205
@endpoint
208206
async def push_weights(self, version: int):
209207
"""Update policy model weights with trainer's current weights."""
210-
start_time = time.time()
211-
assert self.store is not None, "Store must be initialized to save weights"
212208
key = f"{self.state_dict_key}{DELIM}{version}" # Use version as unique id
213209
new_sd = _qwen3_hf_to_vllm(self.model.state_dict(), num_layers=28)
214-
await put_state_dict(self.store, new_sd, key)
210+
start_time = time.time()
211+
await ts.put_state_dict(new_sd, key)
215212
end_time = time.time()
216213
self.logger.debug(
217214
f"Pushed weights to {key} in {end_time - start_time:.2f} seconds"
@@ -344,6 +341,7 @@ async def main(cfg: DictConfig):
344341
)
345342

346343
# ---- Setup services ---- #
344+
await ts.initialize()
347345
(
348346
dataloader,
349347
policy,

src/forge/actors/policy.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,6 @@ def __post_init__(self):
400400

401401
@endpoint
402402
async def setup(self):
403-
self.store = await ts.initialize()
404403
# TODO: remove ["gpus"] when monarch implements a flat rank
405404
self.rank = current_rank()["gpus"]
406405
self.worker = self.setup_worker()
@@ -424,7 +423,7 @@ async def _load_tensor_parallel_state_dict(
424423

425424
# Load the full tensor from torchstore
426425
# TODO: only get the part of the tensor that is needed
427-
stored_tensor = await self.store.get(
426+
stored_tensor = await ts.get(
428427
f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}"
429428
)
430429
sharding.load_from_source_to_target(
@@ -436,8 +435,6 @@ async def _load_tensor_parallel_state_dict(
436435
@endpoint
437436
async def update(self, version: int):
438437
"""Update model weights by reading state dict from torchstore"""
439-
if self.store is None:
440-
raise Exception("No torchstore configured, skipping model update")
441438
key = f"{self.state_dict_key}{DELIM}{version}"
442439
model = self.worker.model_runner.model
443440
current_state_dict = model.state_dict()

0 commit comments