import json
import os

import torch
from torch.distributed.fsdp import FullStateDictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

from fastvideo.v1.logger import init_logger

logger = init_logger(__name__)


def save_checkpoint(transformer, rank, output_dir, step):
    """Gather the full FSDP state dict and save the weights and config on rank 0."""
    # Configure FSDP to save the full (unsharded) state dict
    FSDP.set_state_dict_type(
        transformer,
        state_dict_type=StateDictType.FULL_STATE_DICT,
        state_dict_config=FullStateDictConfig(offload_to_cpu=True,
                                              rank0_only=True),
    )

    # Now get the state dict (gathered on CPU, materialized only on rank 0)
    cpu_state = transformer.state_dict()

    # Save it (only on rank 0 since we used rank0_only=True)
    if rank <= 0:
        save_dir = os.path.join(output_dir, f"checkpoint-{step}")
        os.makedirs(save_dir, exist_ok=True)
        weight_path = os.path.join(save_dir, "diffusion_pytorch_model.pt")
        torch.save(cpu_state, weight_path)
        # Copy the config so deleting keys below does not mutate the model's
        # live hf_config
        config_dict = dict(transformer.hf_config)
        if "dtype" in config_dict:
            del config_dict["dtype"]  # TODO
        config_path = os.path.join(save_dir, "config.json")
        # Save the config dict as JSON alongside the weights
        with open(config_path, "w") as f:
            json.dump(config_dict, f, indent=4)
        logger.info("--> checkpoint saved at step %s to %s", step, weight_path)