From 8ad4f35e08aa76a7d940abbcf018d9a0966c52c8 Mon Sep 17 00:00:00 2001
From: pradeepfn
Date: Fri, 29 Aug 2025 07:54:51 -0700
Subject: [PATCH] skeleton code of ts integration

---
 src/forge/actors/trainer.py | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 30ed1b69d..46e5db19c 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
+import asyncio
 import logging
 import math
 import os
@@ -12,9 +13,15 @@
 
 import torch
 import torchtitan.experiments.forge.train_spec as forge_train_spec
+
+# from tqdm import tqdm
+
+from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
+from torchstore import MultiProcessStore
+from torchstore._state_dict_utils import push_state_dict
 from torchtitan.components.loss import LossFunction
 
 # from torchdata.stateful_dataloader import StatefulDataLoader
@@ -25,10 +32,6 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-# from tqdm import tqdm
-
-from forge.controller import ForgeActor
-
 # from forge.interfaces import RLLoss
 
 # stubs for now
@@ -68,6 +71,10 @@ def __init__(self, config: DictConfig):
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())
+
+        # Initialize torchstore for publishing weights.
+        self._tstore = asyncio.run(MultiProcessStore.create_store())
+
         self._init_dist()
         super().__init__(job_config)
@@ -201,6 +208,15 @@ def train_step(self, batch) -> None:
         # self.profiler.step()
         self.current_step += 1
+        # Save to torchstore. Hacking into the Checkpointer's prepped state-dict for now.
+        # TODOs:
+        # 1. Figure out if there is value in calling state_dict_adapter.to_hf().
+        # 2. The Checkpointer invokes state-dict flattening during dcp_save for [MODEL].
+        #    May need to replicate the same in this code path.
+        # 3. Integrate a zero-overhead version of push_state_dict.
+        # 4. Figure out a way to notify the generator app that weights are ready. This is beyond the initial integration milestone.
+        # 5. Unify the CheckpointManager and TorchStore weight-save control paths.
+        push_state_dict(self._tstore, self.checkpointer.states, f"v{self.current_step}")
         # if self.current_step % self.train_config.val_every_n_steps == 0:
         #     self.validate()
         self.checkpointer.save(
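
A minimal, hypothetical sketch of how TODO 4 (notifying the generator app that weights are ready) could be handled, assuming only plain asyncio primitives. The names WeightVersionSignal, publish, and wait_for_version are illustrative and do not exist in torchstore, monarch, or forge; in the real system the trainer and generator run as separate actors, so a signal like this would have to travel over monarch endpoints or torchstore itself rather than a shared event loop. The sketch only illustrates the versioned publish/wait handshake that pairs with the f"v{self.current_step}" keys used in the patch.

# Hypothetical sketch for TODO 4: letting the generator know a new weight
# version has been pushed. Uses only asyncio primitives; all names here are
# illustrative and are not part of torchstore, monarch, or forge.
import asyncio


class WeightVersionSignal:
    """Tracks the latest pushed weight version and wakes waiting consumers."""

    def __init__(self) -> None:
        self._latest = 0
        self._cond = asyncio.Condition()

    async def publish(self, version: int) -> None:
        # Trainer side: call right after push_state_dict(..., f"v{version}") completes.
        async with self._cond:
            self._latest = max(self._latest, version)
            self._cond.notify_all()

    async def wait_for_version(self, version: int) -> int:
        # Generator side: block until weights at least as new as `version` exist,
        # then return the newest available version number.
        async with self._cond:
            await self._cond.wait_for(lambda: self._latest >= version)
            return self._latest


async def _demo() -> None:
    signal = WeightVersionSignal()

    async def generator() -> None:
        ready = await signal.wait_for_version(1)
        print(f"generator: weights v{ready} are ready to pull from torchstore")

    consumer = asyncio.create_task(generator())
    await signal.publish(1)  # trainer publishes after the torchstore push
    await consumer


if __name__ == "__main__":
    asyncio.run(_demo())

Because wait_for_version checks its predicate before sleeping, the handshake is safe regardless of whether the trainer publishes before or after the generator starts waiting; a cross-process version would need the same "check latest, then wait" semantics.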