|
5 | 5 | # LICENSE file in the root directory of this source tree. |
6 | 6 |
|
7 | 7 |
|
| 8 | +import asyncio |
8 | 9 | import logging |
9 | 10 | import math |
10 | 11 | import os |
11 | 12 | from typing import Any |
12 | 13 |
|
13 | 14 | import torch |
14 | 15 | import torchtitan.experiments.forge.train_spec as forge_train_spec |
| 16 | + |
| 17 | +# from tqdm import tqdm |
| 18 | + |
| 19 | +from forge.controller import ForgeActor |
15 | 20 | from monarch.actor import current_rank, current_size, endpoint |
16 | 21 | from omegaconf import DictConfig, OmegaConf |
17 | 22 | from torch import nn |
| 23 | +from torchstore import MultiProcessStore |
| 24 | +from torchstore._state_dict_utils import push_state_dict |
18 | 25 | from torchtitan.components.loss import LossFunction |
19 | 26 |
|
20 | 27 | # from torchdata.stateful_dataloader import StatefulDataLoader |
|
25 | 32 | from torchtitan.experiments.forge.engine import ForgeEngine |
26 | 33 | from torchtitan.experiments.forge.job_config import ForgeJobConfig |
27 | 34 |
|
28 | | -# from tqdm import tqdm |
29 | | - |
30 | | -from forge.controller import ForgeActor |
31 | | - |
32 | 35 | # from forge.interfaces import RLLoss |
33 | 36 |
|
34 | 37 | # stubs for now |
@@ -68,6 +71,10 @@ def __init__(self, config: DictConfig): |
68 | 71 | self.gradient_accumulation_steps = 1 # Example value, adjust as needed |
69 | 72 | self._rank = current_rank().rank |
70 | 73 | self._size = math.prod(current_size().values()) |
| 74 | + |
| 75 | +        # Initialize the torchstore instance used to publish weights. |
| 76 | + self._tstore = asyncio.run(MultiProcessStore.create_store()) |
| 77 | + |
71 | 78 | self._init_dist() |
72 | 79 | super().__init__(job_config) |
73 | 80 |
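
> Note on the `__init__` hunk above: `MultiProcessStore.create_store()` is driven as a coroutine, so the synchronous constructor bridges it with `asyncio.run()`. Below is a minimal standalone sketch of that bridging pattern, assuming no event loop is already running in the trainer process; the `WeightPublisher` class name is illustrative and not part of this change.

```python
import asyncio

from torchstore import MultiProcessStore


class WeightPublisher:
    """Illustrative helper mirroring the trainer's torchstore setup (not part of this diff)."""

    def __init__(self) -> None:
        # create_store() is awaited via asyncio.run(), which blocks until the
        # store exists. This assumes no event loop is running in this thread;
        # if one were, the store would have to be created from async code instead.
        self._tstore = asyncio.run(MultiProcessStore.create_store())
```

> One consequence of this pattern: `asyncio.run()` creates and tears down its own event loop, so any later async torchstore calls would need their own loop or an explicit runner.
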
|
@@ -201,6 +208,15 @@ def train_step(self, batch) -> None: |
201 | 208 | # self.profiler.step() |
202 | 209 | self.current_step += 1 |
203 | 210 |
|
| 211 | +        # Save to torchstore. Hacking into the Checkpointer's prepped state dict for now. |
| 212 | +        # TODOs: |
| 213 | +        # 1. Figure out whether there is value in calling state_dict_adapter.to_hf(). |
| 214 | +        # 2. The Checkpointer invokes state-dict flattening during dcp_save for [MODEL]; |
| 215 | +        #    we may need to replicate the same in this code path. |
| 216 | +        # 3. Integrate a zero-overhead version of push_state_dict. |
| 217 | +        # 4. Figure out a way to notify the generator app that weights are ready. This is beyond the scope of the initial integration. |
| 218 | +        # 5. Unify the CheckpointManager and TorchStore weight-save control paths. |
| 219 | + push_state_dict(self._tstore, self.checkpointer.states, f"v{self.current_step}") |
204 | 220 | # if self.current_step % self.train_config.val_every_n_steps == 0: |
205 | 221 | # self.validate() |
206 | 222 | self.checkpointer.save( |
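
> On TODO 4 above: one possible direction, sketched purely as an illustration, is to reuse the same `push_state_dict` path to publish a tiny "latest version" record alongside the versioned weights, so a generator can discover the newest `v{step}` key by polling a well-known key. Everything here beyond `push_state_dict` and the `v{step}` naming (both taken from the diff) is an assumption, including the `latest` key, the single-tensor payload, and the synchronous call style, which simply mirrors the call in `train_step` above.

```python
import torch
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import push_state_dict


def publish_versioned_weights(store: MultiProcessStore, state: dict, step: int) -> None:
    # Push the prepared state dict under the same versioned key scheme used
    # in train_step: "v{current_step}". Call style (no await) mirrors the diff.
    push_state_dict(store, state, f"v{step}")
    # Illustrative only: advertise the newest version under a fixed key so a
    # generator can poll "latest" instead of guessing step numbers. The
    # "latest" key and single-tensor payload are assumptions, not conventions
    # established by this change.
    push_state_dict(store, {"latest_step": torch.tensor(step)}, "latest")
```

> Whether this counts as "notifying" the generator depends on the generator polling the store; a true push-style notification would need a separate channel.
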
|