
Commit 5f7cf3c

working code
1 parent bc58196 commit 5f7cf3c

2 files changed (+43, -11 lines)


src/forge/actors/policy.py (5 additions, 5 deletions)

@@ -415,11 +415,11 @@ async def _load_tensor_parallel_state_dict(
         # setting explictly to llama3 for now as its our only use case
         sharding = VLLMSharding(self.tensor_parallel_size, self.rank)

-        key_str = ""
-        for param_name in current_state_dict.keys():
-            key_str += f" {self.state_dict_key}/{version}/{param_name}\n"
-        logger.warning(f"############### policy get keys : {key_str}")
-        return
+        # key_str = ""
+        # for param_name in current_state_dict.keys():
+        #     key_str += f" {self.state_dict_key}/{version}/{param_name}\n"
+        # logger.warning(f"############### policy get keys : {key_str}")
+
         for param_name in current_state_dict.keys():
             current_tensor = current_state_dict[param_name]
             # Load the full tensor from torchstore
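Side note (not part of the commit): the debug logging commented out above and the logging removed from trainer.py below both spell out the torchstore key layout, which suggests each parameter is stored under a versioned per-parameter key. A minimal sketch of that layout, assuming ts.put_state_dict flattens the dict into "{key}/{param_name}" entries; the helper names below are hypothetical and illustrative only:

    # Sketch only; helper names are hypothetical and not part of the commit.
    def trainer_keys(hf_state_dict: dict, current_step: int) -> list[str]:
        # Trainer side: pushed under "model_state_dict/{step}", so each parameter
        # presumably lands at "model_state_dict/{step}/{param_name}".
        return [f"model_state_dict/{current_step}/{name}" for name in hf_state_dict]

    def policy_keys(current_state_dict: dict, state_dict_key: str, version: int) -> list[str]:
        # Policy side: the commented-out logging formats keys as
        # "{state_dict_key}/{version}/{param_name}".
        return [f"{state_dict_key}/{version}/{name}" for name in current_state_dict]

    # The two line up only if state_dict_key == "model_state_dict", version equals the
    # trainer's current_step, and the parameter names are already in the fused HF form
    # produced by push_weights.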

src/forge/actors/trainer.py (38 additions, 6 deletions)

@@ -16,6 +16,7 @@
 import torchstore as ts
 from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
+from torch.distributed.checkpoint._nested_dict import flatten_state_dict
 from torch.distributed.checkpoint.state_dict_saver import _stateful_to_state_dict
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
@@ -270,16 +271,47 @@ async def push_weights(self) -> None:
         # 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
         #    May need to replicate the same in this code path.
         # 2. Unify CheckpointManager and TorchStore weights save control path.
-        print(
-            f"Getting keys from checkpointer state and pushing to TS ..."
-        )
+        print(f"Getting keys from checkpointer state and pushing to TS ...")
         assert (
             "model" in self.engine.checkpointer.states
         ), "Model state not found in checkpointer state"
+        sd = self.engine.checkpointer.states["model"].state_dict()
+
+        flattened_state_dict, _ = flatten_state_dict(sd)
+        # Save the state dict using HF format.
+        # 1. Use the torchtitan adapter's 'to_hf' routines to convert the state dict.
+        # 2. Missing conversions (QKV, MLP fusion) are done using custom code. Probably
+        #    we should move that code to the 'to_hf' function.
+
+        assert (
+            self.engine.checkpointer.sd_adapter is not None
+        ), "trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
+        hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
+
+        for i in range(32):  # improve this using regex similar to to_hf function.
+            prefix = f"model.layers.{i}."
+            # QKV fusion
+            q = hf_state_dict.pop(prefix + "self_attn.q_proj.weight")
+            k = hf_state_dict.pop(prefix + "self_attn.k_proj.weight")
+            v = hf_state_dict.pop(prefix + "self_attn.v_proj.weight")
+            hf_state_dict[prefix + "self_attn.qkv_proj.weight"] = torch.cat(
+                [q, k, v], dim=0
+            )
+            # MLP gate_up_proj fusion
+            gate = hf_state_dict.pop(prefix + "mlp.gate_proj.weight")
+            up = hf_state_dict.pop(prefix + "mlp.up_proj.weight")
+            hf_state_dict[prefix + "mlp.gate_up_proj.weight"] = torch.cat(
+                [gate, up], dim=0
+            )
+
+        # Remove before landing
+        # key_str = ""
+        # for key, _ in hf_state_dict.items():
+        #     key_str += f" model_state_dict/{self.current_step}/{key}\n"
+        # logger.warning(f"rltrainer, put_state_dict keys : {key_str}")
+
         await ts.put_state_dict(
-            state_dict=_stateful_to_state_dict(
-                {"model": self.engine.checkpointer.states.pop("model")}
-            ),
+            state_dict=hf_state_dict,
             key=f"model_state_dict/{self.current_step}",
         )
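The fusion loop above is hard-coded to 32 layers, and its inline comment already suggests deriving the layer indices with a regex the way to_hf does. A minimal sketch of that generalization, assuming the HF key naming shown in the diff (model.layers.{i}.self_attn.q_proj.weight and so on); the function name and structure here are illustrative, not the commit's code:

    import re
    import torch

    def fuse_qkv_and_gate_up(hf_state_dict: dict) -> dict:
        # Sketch only: derive layer indices from the keys instead of hard-coding range(32).
        layer_ids = sorted(
            {
                int(m.group(1))
                for key in hf_state_dict
                if (m := re.match(r"model\.layers\.(\d+)\.", key))
            }
        )
        for i in layer_ids:
            prefix = f"model.layers.{i}."
            # Same QKV fusion as push_weights: q, k, v concatenated along dim 0.
            q = hf_state_dict.pop(prefix + "self_attn.q_proj.weight")
            k = hf_state_dict.pop(prefix + "self_attn.k_proj.weight")
            v = hf_state_dict.pop(prefix + "self_attn.v_proj.weight")
            hf_state_dict[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)
            # Same MLP fusion: gate_proj and up_proj concatenated into gate_up_proj.
            gate = hf_state_dict.pop(prefix + "mlp.gate_proj.weight")
            up = hf_state_dict.pop(prefix + "mlp.up_proj.weight")
            hf_state_dict[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)
        return hf_state_dict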
