Merged
Changes from 11 commits

32 changes: 12 additions & 20 deletions src/forge/actors/policy.py
@@ -13,9 +13,14 @@
from typing import Dict, List

import torch
import torchstore as ts

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM

from vllm.engine.arg_utils import EngineArgs
from vllm.entrypoints.utils import _validate_truncation_size
@@ -37,12 +42,6 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig


logger = logging.getLogger(__name__)

@@ -108,7 +107,6 @@ class Policy(PolicyInterface):
lora_request: LoRARequest | None = None
tokenization_kwargs: dict = field(default_factory=dict)
policy_worker: "PolicyWorker" = None
store: MultiProcessStore | None = None

def __post_init__(self):
self._run_task: asyncio.Task | None = None
@@ -122,7 +120,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
*,
process_config: ProcessConfig,
config: PolicyConfig,
store: MultiProcessStore | None = None,
**kwargs,
) -> "Policy":
# Note - get_proc_mesh will set MASTER_ADDR, MASTER_PORT and CUDA_VISIBLE_DEVICES
@@ -146,7 +143,6 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
cls,
config=config,
policy_worker=workers,
store=store,
)
policy._policy_proc = policy_proc
policy._worker_procs = worker_procs
@@ -174,7 +170,7 @@ async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride]
async def setup(self):
# Set up policy_worker
assert self.policy_worker is not None, "Policy worker should not be None"
await self.policy_worker.setup.call(store=self.store)
await self.policy_worker.setup.call()

self.request_id = 0
self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
@@ -397,8 +393,7 @@ def __post_init__(self):
self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)

@endpoint
async def setup(self, store: MultiProcessStore = None):
self.torchstore = store
async def setup(self):
# TODO: remove ["gpus"] when monarch implements a flat rank
self.rank = current_rank()["gpus"]
self.worker = self.setup_worker()
@@ -420,11 +415,10 @@ async def _load_tensor_parallel_state_dict(

for param_name in current_state_dict.keys():
current_tensor = current_state_dict[param_name]

# Load the full tensor from torchstore
# TODO: only get the part of the tensor that is needed
stored_tensor = await self.torchstore.get(
f"{self.state_dict_key}{DELIM}{version}{DELIM}{param_name}"
stored_tensor = await ts.get(
f"{self.state_dict_key}/{version}/{param_name}"
)
sharding.load_from_source_to_target(
param_name,
@@ -437,11 +431,9 @@
@endpoint
async def update(self, version: int):
"""Update model weights by reading state dict from torchstore"""
if self.torchstore is None:
raise Exception("No torchstore configured, skipping model update")

logger.debug(
f"Starting model update from torchstore with key: {self.state_dict_key}{DELIM}{version}"
f"Starting model update from torchstore with key: {self.state_dict_key}/{version}"
)

model = self.worker.model_runner.model
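For orientation before the trainer.py half of this change: a hedged sketch of the key contract between the two sides, assuming ts.put_state_dict fans each entry out under "<key>/<param_name>" and that state_dict_key matches the trainer's prefix (neither is spelled out in this diff):

# Trainer side (push_weights in trainer.py below) publishes one versioned prefix:
#     await ts.put_state_dict(state_dict=hf_state_dict, key=f"model_state_dict/{current_step}")
# Policy side (_load_tensor_parallel_state_dict above) reads per-parameter keys:
#     stored_tensor = await ts.get(f"{state_dict_key}/{version}/{param_name}")
# The read only resolves if state_dict_key lines up with the trainer's "model_state_dict"
# prefix and the requested version equals the trainer's current_step at push time.
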
59 changes: 52 additions & 7 deletions src/forge/actors/trainer.py
@@ -12,7 +12,12 @@
from dataclasses import dataclass, field, fields

import torch

import torchstore as ts
from forge.controller import ForgeActor
from monarch.actor import current_rank, current_size, endpoint
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint.state_dict_saver import _stateful_to_state_dict
from torchtitan.config.job_config import (
ActivationCheckpoint,
Checkpoint,
@@ -25,13 +30,10 @@
Parallelism,
Training,
)

from torchtitan.distributed import utils as dist_utils
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.controller import ForgeActor

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@@ -68,7 +70,7 @@ def __post_init__(self):
f"{f.name} should be a {f.type} type or a dict like object"
)

self.current_step = 0
self.current_step = 1  # fragile contract: must match the policy engine, which starts at version 1.
Contributor:
Why are we starting at 1? Also, we probably want a todo to update this from the checkpoint

Contributor (Author):
Because the policy engine starts at 1. Let's keep this fragile contract as it is. The true version has to come from a config or an external book-keeping entity.

Contributor:
I don't think we can change this without risking breaking checkpoint expectations from titan side. I'd rather just use a separate variable in the trainer for "checkpoint name" (can be a property that's just current_step + 1 for now). This could also be passed in from the controller which would be better.
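One minimal sketch of that "checkpoint name" suggestion, not part of this PR (weight_version is an illustrative name):

# Hedged sketch: keep torchtitan's step convention in current_step and expose the
# policy-facing version separately, as the reviewer proposes.
@property
def weight_version(self) -> int:
    return self.current_step + 1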

self.num_training_steps = self.training.steps
self.gradient_accumulation_steps = 1
self.rank = current_rank().rank
@@ -91,7 +93,9 @@ def __post_init__(self):
@endpoint
async def setup(self):
# TODO: update ForgeEngine to not use ForgeJobConfig
engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
engine_config = {
f.name: getattr(self, f.name) for f in fields(self) if f.name != "store"
}
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
self.engine.checkpointer.load(step=self.current_step)
self.engine.optimizers.zero_grad()
@@ -261,8 +265,49 @@ def train_step(self, batch) -> None:
# return {"loss": avg_loss, "groups_processed": num_groups_processed}

@endpoint
def push_weights(self) -> None:
pass
async def push_weights(self) -> None:
# Save to torchstore. Hacking into the Checkpointer's prepped state dict for now.
# TODOs:
# 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
# May need to replicate the same in this code path.
# 2. Unify CheckpointManager and TorchStore weights save control path.
print(f"Getting keys from checkpointer state and pushing to TS ...")
assert (
"model" in self.engine.checkpointer.states
), "Model state not found in checkpointer state"
Contributor:
When would "model" not be in self.engine.checkpointer.states? In other words, can we update the assertion error to be more informative? Does this fail if the user didn't initialize the trainer properly? What do they need to do to make it work?

Contributor (Author):
Essentially, this only happens if torchtitan's checkpoint_manager is not initialized prior to calling the push_weights routine. I can update the error message with that in a follow-up PR.
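A minimal sketch of what a more informative message could look like, assuming the failure mode is the uninitialized checkpoint manager described above (wording is illustrative, not the follow-up PR):

assert "model" in self.engine.checkpointer.states, (
    "Model state not found in checkpointer state: torchtitan's checkpoint manager "
    "must be initialized (trainer setup()) before calling push_weights()."
)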

sd = self.engine.checkpointer.states["model"].state_dict()
Contributor:
Where is this coming from? When you call this, does it create the sd right then or did it have to be saved in the train step earlier? Does it return the sd on GPU or CPU? Also does it handle blocking the trainer from updating the weights while it's getting them?

Contributor (Author):
I'm accessing the module state dict prepped by torchtitan as part of checkpoint save.

  1. This is an in-memory state dict (Tensor/DTensor).
  2. It returns tensors with their original storage, i.e. GPU/UVM-backed tensors.

> Also does it handle blocking the trainer from updating the weights while it's getting them?

Hmm, it does not block the trainer. However, ForgeEngine drives the trainer through train_step, so there are no race conditions with the current code.

There are improvements to be made to this code. In the ideal case:

  1. The state dict gets prepped for both weight-exchange and checkpoint-save purposes.
  2. Once the initial state dict is prepped, we can cache it for later training steps for efficiency (if there is an opportunity).
  3. We move all the model weights and optimizer state to torchstore.
  4. The policy engine looks up (only) the model weights from torchstore.
  5. Async checkpointing looks up model weights and optimizer states and uploads them to remote persistent storage.

We don't have all the pieces right now, but tapping into the checkpoint state dict is the right thing to do as a first step.
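A minimal sketch of item 2 above, assuming the prepped tensors alias the live model storage so the cached dict stays current across steps (an assumption, not something this PR establishes; _cached_model_sd is an illustrative name):

def _prepped_model_state_dict(self):
    # Hedged sketch only; cache invalidation is still an open question in the plan above.
    if getattr(self, "_cached_model_sd", None) is None:
        self._cached_model_sd = self.engine.checkpointer.states["model"].state_dict()
    return self._cached_model_sd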

Contributor:
I guess you're right that it should be mostly safe since we control the update from the controller. But since they're async calls they could be overlapped so we'll have to be careful for now.
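A minimal sketch of the "careful for now" sequencing from the controller side, reusing the endpoint-call style shown in this diff; sync_weights and the handles are illustrative, and the real controller wiring is not part of this change:

async def sync_weights(trainer, policy_worker, version: int) -> None:
    # 1) Trainer publishes weights for this version into torchstore.
    await trainer.push_weights.call()
    # 2) Only then do the policy workers read that version back, so the two
    #    async calls cannot overlap for the same version.
    await policy_worker.update.call(version=version)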


flattened_state_dict, _ = flatten_state_dict(sd)
# Save the state dict using HF format.
# 1. Use the torchtitan adapter's 'to_hf' routines to convert the state dict.
# 2. Missing conversions (QKV, MLP fusion) are done with custom code below.
#    Probably we should move that code into the 'to_hf' function.

assert (
self.engine.checkpointer.sd_adapter is not None
), "trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)

for i in range(32): # improve this using regex similar to to_hf function.
prefix = f"model.layers.{i}."
# QKV fusion
q = hf_state_dict.pop(prefix + "self_attn.q_proj.weight")
k = hf_state_dict.pop(prefix + "self_attn.k_proj.weight")
v = hf_state_dict.pop(prefix + "self_attn.v_proj.weight")
hf_state_dict[prefix + "self_attn.qkv_proj.weight"] = torch.cat(
[q, k, v], dim=0
)
# MLP gate_up_proj fusion
gate = hf_state_dict.pop(prefix + "mlp.gate_proj.weight")
up = hf_state_dict.pop(prefix + "mlp.up_proj.weight")
hf_state_dict[prefix + "mlp.gate_up_proj.weight"] = torch.cat(
[gate, up], dim=0
)

await ts.put_state_dict(
state_dict=hf_state_dict,
key=f"model_state_dict/{self.current_step}",
)

@endpoint
async def cleanup(self) -> None:
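On the inline TODO about the hard-coded range(32) in push_weights: a hedged sketch of a regex-based pass over the converted HF keys, so the fusion loop does not assume a 32-layer model (illustrative only, not part of this PR):

import re

import torch

def fuse_hf_layers(hf_state_dict: dict) -> dict:
    # Discover layer indices from the key names instead of hard-coding 32.
    layer_ids = sorted(
        {
            int(m.group(1))
            for key in hf_state_dict
            if (m := re.match(r"model\.layers\.(\d+)\.", key))
        }
    )
    for i in layer_ids:
        prefix = f"model.layers.{i}."
        # QKV fusion
        q = hf_state_dict.pop(prefix + "self_attn.q_proj.weight")
        k = hf_state_dict.pop(prefix + "self_attn.k_proj.weight")
        v = hf_state_dict.pop(prefix + "self_attn.v_proj.weight")
        hf_state_dict[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
        # MLP gate_up_proj fusion
        gate = hf_state_dict.pop(prefix + "mlp.gate_proj.weight")
        up = hf_state_dict.pop(prefix + "mlp.up_proj.weight")
        hf_state_dict[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
    return hf_state_dict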