24 changes: 20 additions & 4 deletions src/forge/actors/trainer.py
@@ -5,16 +5,23 @@
# LICENSE file in the root directory of this source tree.


import asyncio
import logging
import math
import os
from typing import Any

import torch
import torchtitan.experiments.forge.train_spec as forge_train_spec

# from tqdm import tqdm

from forge.controller import ForgeActor
from monarch.actor import current_rank, current_size, endpoint
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import push_state_dict
from torchtitan.components.loss import LossFunction

# from torchdata.stateful_dataloader import StatefulDataLoader
@@ -25,10 +25,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

# from tqdm import tqdm

from forge.controller import ForgeActor

# from forge.interfaces import RLLoss

# stubs for now
@@ -68,6 +71,10 @@ def __init__(self, config: DictConfig):
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
self._rank = current_rank().rank
self._size = math.prod(current_size().values())

# init torchstore
self._tstore = asyncio.run(MultiProcessStore.create_store())
Member:
Do you need the asyncio.run?

Contributor:
> Do you need the asyncio.run?

This is because __init__ is not async?

self._init_dist()
super().__init__(job_config)

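On the asyncio.run question in the review thread above: one possible alternative, not part of this PR, is to create the store lazily in an async setup() method (exposed as an actor endpoint in forge) rather than wrapping the coroutine in asyncio.run() inside the synchronous __init__. The class and method names below are illustrative assumptions, not the PR's code:

# Hypothetical sketch only: defer torchstore creation to an async setup()
# so __init__ (which cannot await) never needs a nested asyncio.run().
from typing import Optional

from torchstore import MultiProcessStore


class TrainerStoreMixin:
    """Illustrative mixin; names and structure are assumptions."""

    def __init__(self) -> None:
        self._tstore: Optional[MultiProcessStore] = None  # created in setup()

    async def setup(self) -> None:
        # Awaited on the actor's existing event loop, so no nested
        # asyncio.run() is required.
        if self._tstore is None:
            self._tstore = await MultiProcessStore.create_store()

Whether this is preferable depends on how the trainer actor's lifecycle is managed; the PR as written keeps the asyncio.run() call in __init__.
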
@@ -201,6 +208,15 @@ def train_step(self, batch) -> None:
# self.profiler.step()
self.current_step += 1

# Save to torchstore. Hacking into the Checkpointer's prepped state dict for now.
# TODOs:
# 1. Figure out whether there is value in calling state_dict_adapter.to_hf().
# 2. The Checkpointer invokes state-dict flattening during dcp_save for [MODEL];
#    we may need to replicate the same in this code path.
# 3. Integrate a zero-overhead version of push_state_dict.
# 4. Figure out a way to notify the generator app that weights are ready; this
#    can come after the initial integration is working.
# 5. Unify the CheckpointManager and TorchStore weight-save control paths.
push_state_dict(self._tstore, self.checkpointer.states, f"v{self.current_step}")
# if self.current_step % self.train_config.val_every_n_steps == 0:
# self.validate()
self.checkpointer.save(
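
Regarding TODO 4 above (letting the generator know when fresh weights are available), one minimal sketch is to publish the weights under the step-versioned key and then write a small marker entry under a fixed, well-known key that a consumer can poll. The key names, the marker layout, and the idea of reusing push_state_dict for the marker are assumptions, not part of this PR:

# Hypothetical sketch for TODO 4, not this PR's code: after pushing the
# step-versioned weights, also push a tiny marker entry under a fixed key so a
# generator can poll one location to learn the newest ready version.
import torch
from torchstore._state_dict_utils import push_state_dict

LATEST_KEY = "weights/latest"  # assumed well-known key, not in the PR


def publish_weights(store, states, step: int) -> None:
    version_key = f"v{step}"                     # mirrors the PR's f"v{self.current_step}"
    push_state_dict(store, states, version_key)  # weights first, marker second
    # The marker is itself a one-entry state dict, so the same helper writes it.
    push_state_dict(store, {"latest_step": torch.tensor(step)}, LATEST_KEY)

A generator could then resolve the marker to a version key and fetch that state dict; how the read side looks depends on torchstore's retrieval API, which this diff does not show.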