29 changes: 23 additions & 6 deletions src/forge/actors/trainer.py
@@ -12,7 +12,11 @@
from dataclasses import dataclass, field, fields

import torch
from forge.controller import ForgeActor
from monarch.actor import current_rank, current_size, endpoint

from torchstore import MultiProcessStore
from torchstore._state_dict_utils import push_state_dict
from torchtitan.config.job_config import (
ActivationCheckpoint,
Checkpoint,
@@ -25,19 +29,17 @@
Parallelism,
Training,
)

from torchtitan.distributed import utils as dist_utils
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.controller import ForgeActor

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


@dataclass
class RLTrainer(ForgeActor):
store: MultiProcessStore = field(default=None)
model: Model = field(default_factory=Model)
optimizer: Optimizer = field(default_factory=Optimizer)
lr_scheduler: LRScheduler = field(default_factory=LRScheduler)
@@ -68,7 +70,7 @@ def __post_init__(self):
f"{f.name} should be a {f.type} type or a dict like object"
)

self.current_step = 0
self.current_step = 1 # fragile contract.
Contributor
Why are we starting at 1? Also, we probably want a TODO to update this from the checkpoint.

Contributor Author
Because the policy engine starts at 1. Let's keep this fragile contract as it is; the true version has to come from a config or an external book-keeping entity.

Contributor
I don't think we can change this without risking breaking checkpoint expectations from the titan side. I'd rather just use a separate variable in the trainer for "checkpoint name" (it can be a property that's just current_step + 1 for now). This could also be passed in from the controller, which would be better.
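A minimal sketch of that suggestion, assuming a hypothetical checkpoint_name property and a current_step that stays 0-based for the titan checkpointer (name and placement are illustrative, not part of this PR):

@property
def checkpoint_name(self) -> int:
    # Hypothetical: keep current_step 0-based for titan's checkpointer and derive
    # the 1-based key the policy engine expects, instead of shifting current_step itself.
    return self.current_step + 1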

self.num_training_steps = self.training.steps
self.gradient_accumulation_steps = 1
self.rank = current_rank().rank
@@ -91,7 +93,9 @@ def __post_init__(self):
@endpoint
async def setup(self):
# TODO: update ForgeEngine to not use ForgeJobConfig
engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
engine_config = {
f.name: getattr(self, f.name) for f in fields(self) if f.name != "store"
}
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
self.engine.checkpointer.load(step=self.current_step)
self.engine.optimizers.zero_grad()
@@ -262,7 +266,20 @@ def train_step(self, batch) -> None:

@endpoint
def push_weights(self) -> None:
pass
# Save to torchstore. Hacking into the Checkpointer's prepped state-dict for now.
# TODOs:
# 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
# May need to replicate the same in this code path.
# 2. Unify CheckpointManager and TorchStore weights save control path.
assert self.store is not None, "TorchStore is not initialized"
print(
f"Getting keys from checkpointer state and pushing to TS {self.engine.checkpointer.states}"
)
push_state_dict(
self.store,
self.engine.checkpointer.states,
f"model_state_dict/{self.current_step}",
)
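
For reference, a rough sketch of the consuming side of this key contract. The get_state_dict helper and its exact (async) signature are assumptions here — only push_state_dict appears in this diff; the point is that the reader must agree on the same model_state_dict/{step} numbering:

from torchstore._state_dict_utils import get_state_dict  # assumed counterpart to push_state_dict

async def fetch_pushed_weights(store, step: int):
    # The key must match what push_weights writes above; the step numbering is the
    # "fragile contract" from the review thread (the policy engine expects 1-based steps).
    return await get_state_dict(store, f"model_state_dict/{step}")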

@endpoint
async def cleanup(self) -> None:
35 changes: 24 additions & 11 deletions tests/integration_tests/test_policy_update.py
@@ -10,10 +10,13 @@
import pytest_asyncio

import torch

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig

from forge.actors.trainer import RLTrainer
from forge.controller.service import ServiceConfig, spawn_service
from forge.data.sharding import VLLMSharding

from omegaconf import DictConfig, OmegaConf
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import push_state_dict
from transformers import AutoModelForCausalLM
@@ -179,7 +182,7 @@ def get_configs(
)

sampling_params = SamplingOverrides(
num_samples=3,
n=3,
guided_decoding=True,
)

@@ -193,6 +196,22 @@
return policy_config, service_config


async def run_rl_trainer(store, worker_size) -> None:
"""
1. Spawn the trainer.
2. Inject the torchstore reference via the setup call.
3. Call push_weights.
"""
cfg: DictConfig = OmegaConf.load("apps/rl/llama3_8b.yaml")
rl_trainer = await spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
RLTrainer,
**{**cfg.trainer, "store": store},
)
# Push the weights to torchstore
await rl_trainer.push_weights.choose()


async def run_policy_integration(
store, original_state_dict, worker_size
) -> Dict[str, torch.Tensor]:
@@ -229,7 +248,7 @@ async def run_policy_integration(
@pytest_asyncio.fixture(scope="session")
async def llama3_torchstore_setup():
"""
Pytest fixture to load Llama 3.1 8B-Instruct and write state dict to torchstore.
Pytest fixture to load Llama 3.1 8B-Instruct. We use the loaded state dict as SOT for validation.
Uses session scope so it's only called once when both tests are run.
"""
print("=== PHASE 1: Writing Llama 3.1 8B-Instruct to TorchStore ===")
@@ -248,17 +267,9 @@ async def llama3_torchstore_setup():

original_state_dict = model.state_dict()
print(f"Original state dict has {len(original_state_dict)} parameters")

print("Converting transformers state dict to vLLM format...")
converted_state_dict = convert_state_dict(original_state_dict)
print(f"Converted state dict has {len(converted_state_dict)} parameters")

state_dict_key = "model_state_dict/1" # {app_namespace}/{version}
await save_state_dict(store, converted_state_dict, state_dict_key)
print(
f"Successfully wrote converted state dict to torchstore with key: {state_dict_key}"
)

return store, converted_state_dict
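
The fixture's returned state dict serves as the expected values for validation inside run_policy_integration (not shown in this diff). A minimal sketch of the kind of check that validation implies — the helper name and tolerances are hypothetical:

import torch

def assert_state_dicts_close(expected: dict, actual: dict, atol: float = 1e-6) -> None:
    # Hypothetical validation helper; the real comparison lives in run_policy_integration.
    assert expected.keys() == actual.keys(), "parameter sets differ"
    for name, tensor in expected.items():
        torch.testing.assert_close(actual[name], tensor, atol=atol, rtol=0)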


@@ -268,6 +279,8 @@ async def test_llama3_policy_update_single(llama3_torchstore_setup):
print("Starting Llama 3 8B torchstore test (single GPU)...")

store, original_state_dict = llama3_torchstore_setup
# store = MultiProcessStore.create_store()
await run_rl_trainer(store, worker_size=1)

loaded_state_dict = await run_policy_integration(
store, original_state_dict, worker_size=1