Publishing weights into torchstore from RLTrainer and getting them from the policy engine. #138
Changes from 14 commits
@@ -10,9 +10,15 @@

```python
import os
from collections.abc import Mapping
from dataclasses import dataclass, field, fields
from typing import Any, Dict

import torch

import torchstore as ts
from forge.controller import ForgeActor
from monarch.actor import current_rank, current_size, endpoint
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint.state_dict_saver import _stateful_to_state_dict
from torchtitan.config.job_config import (
    ActivationCheckpoint,
    Checkpoint,
```
@@ -25,13 +31,10 @@

```python
    Parallelism,
    Training,
)

from torchtitan.distributed import utils as dist_utils
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
```
|
@@ -68,7 +71,7 @@ def __post_init__(self): | |
f"{f.name} should be a {f.type} type or a dict like object" | ||
) | ||
|
||
self.current_step = 0 | ||
self.current_step = 1 # fragile contract. | ||
self.num_training_steps = self.training.steps | ||
self.gradient_accumulation_steps = 1 | ||
self.rank = current_rank().rank | ||
|
@@ -261,10 +264,60 @@ def train_step(self, batch) -> None:

```python
        # return {"loss": avg_loss, "groups_processed": num_groups_processed}

    @endpoint
    async def push_weights(self) -> None:
        # Save to torchstore. Hacking into the Checkpointer's prepped state dict for now.
        # TODOs:
        # 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
        #    We may need to replicate the same in this code path.
        # 2. Unify the CheckpointManager and TorchStore weight-save control paths.
        print("Getting keys from checkpointer state and pushing to TS ...")

        assert (
            "model" in self.engine.checkpointer.states
        ), "Model state not found in checkpointer state"

        sd = self.engine.checkpointer.states["model"].state_dict()

        flattened_state_dict, _ = flatten_state_dict(sd)
        # Save the state dict in HF format:
        # 1. Use the torchtitan adapter's 'to_hf' routines to convert the state dict.
        # 2. Conversions that 'to_hf' misses (QKV and MLP fusion) are done with custom
        #    code below; we should probably move that code into the 'to_hf' function.

        assert (
            self.engine.checkpointer.sd_adapter is not None
        ), "Trying to save a checkpoint in HF safetensors format, but sd_adapter is not provided."
        hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)

        vllm_ready_hf_sd = llama3_hf_to_vllm(hf_trainer_sd=hf_state_dict)

        await ts.put_state_dict(
            state_dict=vllm_ready_hf_sd,
            key=f"model_state_dict/{self.current_step}",
        )

    @endpoint
    async def cleanup(self) -> None:
        if self.engine.checkpointer:
            self.engine.checkpointer.close()


def llama3_hf_to_vllm(hf_trainer_sd: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert an HF-formatted state dict to vLLM format. Ideally this conversion
    would not be needed if vLLM fully supported loading HF-formatted llama3
    models.
    """
    for i in range(32):  # number of layers in the llama3 8B model
        prefix = f"model.layers.{i}."
        # QKV fusion
        q = hf_trainer_sd.pop(prefix + "self_attn.q_proj.weight")
        k = hf_trainer_sd.pop(prefix + "self_attn.k_proj.weight")
        v = hf_trainer_sd.pop(prefix + "self_attn.v_proj.weight")
        hf_trainer_sd[prefix + "self_attn.qkv_proj.weight"] = torch.cat(
            [q, k, v], dim=0
        )
        # MLP gate_up_proj fusion
        gate = hf_trainer_sd.pop(prefix + "mlp.gate_proj.weight")
        up = hf_trainer_sd.pop(prefix + "mlp.up_proj.weight")
        hf_trainer_sd[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)

    return hf_trainer_sd
```
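
For reference, a quick sanity check (not part of the PR) of the shapes `llama3_hf_to_vllm` produces. It assumes the standard Llama 3 8B dimensions (hidden size 4096, 32 query heads and 8 KV heads with head dim 128, MLP intermediate size 14336) and uses meta-device tensors so no real memory is allocated:

```python
import torch

# Hypothetical check, not in the PR: build a dummy HF-style state dict with
# Llama 3 8B shapes and confirm the fused keys and shapes that vLLM's fused
# qkv_proj and gate_up_proj layers expect.
HIDDEN, HEAD_DIM = 4096, 128
N_HEADS, N_KV_HEADS = 32, 8
INTERMEDIATE = 14336

dummy_sd = {}
for i in range(32):  # llama3_hf_to_vllm iterates over all 32 layers
    p = f"model.layers.{i}."
    dummy_sd[p + "self_attn.q_proj.weight"] = torch.empty(N_HEADS * HEAD_DIM, HIDDEN, device="meta")
    dummy_sd[p + "self_attn.k_proj.weight"] = torch.empty(N_KV_HEADS * HEAD_DIM, HIDDEN, device="meta")
    dummy_sd[p + "self_attn.v_proj.weight"] = torch.empty(N_KV_HEADS * HEAD_DIM, HIDDEN, device="meta")
    dummy_sd[p + "mlp.gate_proj.weight"] = torch.empty(INTERMEDIATE, HIDDEN, device="meta")
    dummy_sd[p + "mlp.up_proj.weight"] = torch.empty(INTERMEDIATE, HIDDEN, device="meta")

fused = llama3_hf_to_vllm(hf_trainer_sd=dummy_sd)

# q (4096) + k (1024) + v (1024) rows are concatenated along dim 0.
assert fused["model.layers.0.self_attn.qkv_proj.weight"].shape == (6144, HIDDEN)
# gate (14336) + up (14336) rows are concatenated along dim 0.
assert fused["model.layers.0.mlp.gate_up_proj.weight"].shape == (2 * INTERMEDIATE, HIDDEN)
# The unfused per-projection keys are popped out of the dict.
assert "model.layers.0.self_attn.q_proj.weight" not in fused
```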
Review discussion on `self.current_step = 1`:

> Why are we starting at 1? Also, we probably want a TODO to update this from the checkpoint.

> Because the policy engine starts at 1. Let's keep this fragile contract as it is. The true version has to come from a config or an external bookkeeping entity.

> I don't think we can change this without risking breaking checkpoint expectations on the titan side. I'd rather just use a separate variable in the trainer for "checkpoint name" (it can be a property that's just `current_step + 1` for now). This could also be passed in from the controller, which would be better.
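
A minimal sketch of that last suggestion, assuming the trainer class is the `RLTrainer` actor named in the PR title; the property name `checkpoint_name` and returning the full torchstore key are illustrative choices, not part of this PR:

```python
# Hypothetical sketch (not in this PR): leave titan's current_step semantics alone
# (starting at 0) and derive the torchstore key from a separate property, so the
# "+1" contract with the policy engine lives in exactly one place. The value could
# also be passed in from the controller instead.
class RLTrainer(ForgeActor):
    ...

    @property
    def checkpoint_name(self) -> str:
        # current_step + 1 mirrors the policy engine's 1-based versioning for now;
        # longer term this should come from config or an external bookkeeping entity.
        return f"model_state_dict/{self.current_step + 1}"
```

`push_weights` would then publish under `key=self.checkpoint_name`, and the policy engine would read back the same key.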