
Commit 8ad4f35

skeleton code of ts integration
1 parent ce91430 commit 8ad4f35

File tree

1 file changed: +20 -4 lines changed

src/forge/actors/trainer.py

Lines changed: 20 additions & 4 deletions
@@ -5,16 +5,23 @@
 # LICENSE file in the root directory of this source tree.
 
 
+import asyncio
 import logging
 import math
 import os
 from typing import Any
 
 import torch
 import torchtitan.experiments.forge.train_spec as forge_train_spec
+
+# from tqdm import tqdm
+
+from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
+from torchstore import MultiProcessStore
+from torchstore._state_dict_utils import push_state_dict
 from torchtitan.components.loss import LossFunction
 
 # from torchdata.stateful_dataloader import StatefulDataLoader
@@ -25,10 +32,6 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-# from tqdm import tqdm
-
-from forge.controller import ForgeActor
-
 # from forge.interfaces import RLLoss
 
 # stubs for now
@@ -68,6 +71,10 @@ def __init__(self, config: DictConfig):
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())
+
+        # init torchstore
+        self._tstore = asyncio.run(MultiProcessStore.create_store())
+
         self._init_dist()
         super().__init__(job_config)
 
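A note on the __init__ hunk above: the store is created by driving the MultiProcessStore.create_store() coroutine with asyncio.run(), so construction blocks until the store exists. Below is a minimal sketch of that same pattern, using only the torchstore call already shown in this diff; the helper name make_store and the return annotation are illustrative assumptions, not part of the commit.

import asyncio

from torchstore import MultiProcessStore


def make_store() -> MultiProcessStore:
    # asyncio.run() starts a fresh event loop and blocks until create_store()
    # finishes. It raises RuntimeError if the caller is already inside a
    # running event loop; in that case the coroutine should be awaited
    # directly instead of wrapped in asyncio.run().
    return asyncio.run(MultiProcessStore.create_store())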
@@ -201,6 +208,15 @@ def train_step(self, batch) -> None:
         # self.profiler.step()
         self.current_step += 1
 
+        # Save to torchstore. Hacking into the Checkpointer's prepped state dict for now.
+        # TODOs:
+        # 1. Figure out whether there is value in calling state_dict_adapter.to_hf().
+        # 2. The checkpointer invokes state-dict flattening during dcp_save for [MODEL];
+        #    we may need to replicate the same in this code path.
+        # 3. Integrate a zero-overhead version of push_state_dict.
+        # 4. Figure out a way to notify the generator app that weights are ready; this is beyond the initial integration milestone.
+        # 5. Unify the CheckpointManager and TorchStore weight-save control paths.
+        push_state_dict(self._tstore, self.checkpointer.states, f"v{self.current_step}")
         # if self.current_step % self.train_config.val_every_n_steps == 0:
         #     self.validate()
         self.checkpointer.save(
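For the train_step hunk above: each step pushes the checkpointer's prepped state dict under the version key f"v{current_step}", so weights can be addressed by training step. The read path is not part of this commit; the sketch below is hypothetical — get_state_dict is an assumed counterpart to push_state_dict, and its existence and signature are not confirmed by this diff.

from torchstore import MultiProcessStore
# Hypothetical import: a read-side counterpart to push_state_dict is assumed
# here; this commit only exercises the write path.
from torchstore._state_dict_utils import get_state_dict


def load_versioned_weights(store: MultiProcessStore, step: int):
    # The trainer writes one version of the prepped state dict per step under
    # the key f"v{step}" ("v1", "v2", ...), so a consumer that knows the step
    # number can fetch exactly that version. How the generator learns that a
    # new version exists is still open (see TODO 4 above).
    return get_state_dict(store, f"v{step}")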
