16 | 16 | import torchstore as ts |
17 | 17 | from forge.controller import ForgeActor |
18 | 18 | from monarch.actor import current_rank, current_size, endpoint |
| 19 | +from torch.distributed.checkpoint._nested_dict import flatten_state_dict |
19 | 20 | from torch.distributed.checkpoint.state_dict_saver import _stateful_to_state_dict |
20 | 21 | from torchtitan.config.job_config import ( |
21 | 22 | ActivationCheckpoint, |
@@ -270,16 +271,47 @@ async def push_weights(self) -> None: |
270 | 271 | # 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL]. |
271 | 272 | # May need to replicate the same in this code path. |
272 | 273 | # 2. Unify CheckpointManager and TorchStore weights save control path. |
273 | | - print( |
274 | | - f"Getting keys from checkpointer state and pushing to TS ..." |
275 | | - ) |
| 274 | + print("Getting keys from checkpointer state and pushing to TS ...") |
276 | 275 | assert ( |
277 | 276 | "model" in self.engine.checkpointer.states |
278 | 277 | ), "Model state not found in checkpointer state" |
| 278 | + sd = self.engine.checkpointer.states["model"].state_dict() |
| 279 | + |
| 280 | + flattened_state_dict, _ = flatten_state_dict(sd) |
| 281 | + # Save the state dict using HF format. |
| 282 | + # 1. Use the torchtitan adapter's 'to_hf' routines to convert the state dict. |
| 283 | + # 2. Missing conversions (QKV and MLP fusion) are done with custom code below; we |
| 284 | + # should probably move that code into the 'to_hf' function. |
| 285 | + |
| 286 | + assert ( |
| 287 | + self.engine.checkpointer.sd_adapter is not None |
| 288 | + ), "trying to save checkpoint in HF safetensors format, but sd_adapter is not provided." |
| 289 | + hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict) |
| 290 | + |
| 291 | + for i in range(32): # TODO: hard-coded layer count; derive layer indices with a regex, similar to the to_hf function. |
| 292 | + prefix = f"model.layers.{i}." |
| 293 | + # QKV fusion |
| 294 | + q = hf_state_dict.pop(prefix + "self_attn.q_proj.weight") |
| 295 | + k = hf_state_dict.pop(prefix + "self_attn.k_proj.weight") |
| 296 | + v = hf_state_dict.pop(prefix + "self_attn.v_proj.weight") |
| 297 | + hf_state_dict[prefix + "self_attn.qkv_proj.weight"] = torch.cat( |
| 298 | + [q, k, v], dim=0 |
| 299 | + ) |
| 300 | + # MLP gate_up_proj fusion |
| 301 | + gate = hf_state_dict.pop(prefix + "mlp.gate_proj.weight") |
| 302 | + up = hf_state_dict.pop(prefix + "mlp.up_proj.weight") |
| 303 | + hf_state_dict[prefix + "mlp.gate_up_proj.weight"] = torch.cat( |
| 304 | + [gate, up], dim=0 |
| 305 | + ) |
| 306 | + |
| 307 | + # Remove before landing |
| 308 | + # key_str = "" |
| 309 | + # for key, _ in hf_state_dict.items(): |
| 310 | + # key_str += f" model_state_dict/{self.current_step}/{key}\n" |
| 311 | + # logger.warning(f"rltrainer, put_state_dict keys : {key_str}") |
| 312 | + |
279 | 313 | await ts.put_state_dict( |
280 | | - state_dict=_stateful_to_state_dict( |
281 | | - {"model": self.engine.checkpointer.states.pop("model")} |
282 | | - ), |
| 314 | + state_dict=hf_state_dict, |
283 | 315 | key=f"model_state_dict/{self.current_step}", |
284 | 316 | ) |
285 | 317 |
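The fusion loop above hard-codes range(32), i.e. it assumes a 32-layer model. As the inline TODO notes, the layer indices could instead be discovered from the converted HF keys with a regex, similar to how to_hf matches keys. Below is a minimal sketch of that idea, assuming Llama-style HF key names; fuse_qkv_and_gate_up is an illustrative helper, not an existing function in this PR or in torchtitan.

import re
from typing import Dict

import torch


def fuse_qkv_and_gate_up(hf_state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Fuse per-layer q/k/v and gate/up projection weights for any number of layers."""
    layer_pattern = re.compile(r"^model\.layers\.(\d+)\.")
    # Collect the distinct layer indices present in the converted HF keys.
    layer_ids = sorted({int(m.group(1)) for key in hf_state_dict if (m := layer_pattern.match(key))})

    for i in layer_ids:
        prefix = f"model.layers.{i}."
        # QKV fusion: concatenate along the output (row) dimension.
        q = hf_state_dict.pop(prefix + "self_attn.q_proj.weight")
        k = hf_state_dict.pop(prefix + "self_attn.k_proj.weight")
        v = hf_state_dict.pop(prefix + "self_attn.v_proj.weight")
        hf_state_dict[prefix + "self_attn.qkv_proj.weight"] = torch.cat([q, k, v], dim=0)

        # MLP fusion: gate_proj rows followed by up_proj rows form gate_up_proj.
        gate = hf_state_dict.pop(prefix + "mlp.gate_proj.weight")
        up = hf_state_dict.pop(prefix + "mlp.up_proj.weight")
        hf_state_dict[prefix + "mlp.gate_up_proj.weight"] = torch.cat([gate, up], dim=0)

    return hf_state_dict

With such a helper, the hard-coded loop in push_weights would reduce to hf_state_dict = fuse_qkv_and_gate_up(hf_state_dict) right before the ts.put_state_dict call.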
|
|