"""
This script demonstrates how to extract a LoRA checkpoint from a fully fine-tuned model, using CogVideoX as the
example model.

To make it work for other models:

* Change the model class. Here we use `CogVideoXTransformer3DModel`. For Flux, it would be `FluxTransformer2DModel`,
for example. (TODO: more reason to add `AutoModel`).
* Supply the path to the base checkpoint via `--base_ckpt_path`.
* Supply the path to the fully fine-tuned checkpoint via `--finetune_ckpt_path`.
* Change the `--rank` as needed.

Example usage:

```bash
python extract_lora_from_model.py \
    --base_ckpt_path=THUDM/CogVideoX-5b \
    --finetune_ckpt_path=finetrainers/cakeify-v0 \
    --lora_out_path=cakeify_lora.safetensors
```
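
The extracted file uses PEFT-style `lora_A`/`lora_B` keys under a `transformer.` prefix, so it should load like
any other diffusers LoRA. A minimal sketch of loading it (the pipeline class and dtype here are assumptions for
CogVideoX; adjust for your model):

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("cakeify_lora.safetensors")
```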

Script is adapted from
https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py
"""

import argparse

import torch
from safetensors.torch import save_file
from tqdm.auto import tqdm

from diffusers import CogVideoXTransformer3DModel


RANK = 64
CLAMP_QUANTILE = 0.99


# Comes from
# https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py#L9
def extract_lora(diff, rank):
    # Important to use CUDA, otherwise this is very slow!
    if torch.cuda.is_available():
        diff = diff.to("cuda")

    is_conv2d = len(diff.shape) == 4
    kernel_size = None if not is_conv2d else diff.size()[2:4]
    is_conv2d_3x3 = is_conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]
    rank = min(rank, in_dim, out_dim)

    if is_conv2d:
        if is_conv2d_3x3:
            diff = diff.flatten(start_dim=1)
        else:
            diff = diff.squeeze()

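    # Truncated SVD of the weight delta: diff ≈ U[:, :rank] @ diag(S[:rank]) @ Vh[:rank, :].
    # The singular values are folded into U, so U becomes the up (`lora_B`) factor and Vh the
    # down (`lora_A`) factor.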
    U, S, Vh = torch.linalg.svd(diff.float())
    U = U[:, :rank]
    S = S[:rank]
    U = U @ torch.diag(S)
    Vh = Vh[:rank, :]

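    # Clamp both factors to [-q, q], where q is the CLAMP_QUANTILE quantile of all entries in
    # U and Vh, to suppress extreme outliers in the extracted weights.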
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val

    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    if is_conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U.cpu(), Vh.cpu())


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub.",
    )
    parser.add_argument(
        "--finetune_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Fully fine-tuned checkpoint path. Can be a model ID on the Hub.",
    )
    parser.add_argument("--rank", default=RANK, type=int)
    parser.add_argument("--lora_out_path", default=None, type=str, required=True)
    args = parser.parse_args()

    if not args.lora_out_path.endswith(".safetensors"):
        raise ValueError("`lora_out_path` must end with `.safetensors`.")

    return args


@torch.no_grad()
def main(args):
    # Fully fine-tuned checkpoints usually don't have any other components. So, we
    # don't need the `subfolder`. You can add that if needed.
    model_finetuned = CogVideoXTransformer3DModel.from_pretrained(args.finetune_ckpt_path, torch_dtype=torch.bfloat16)
    state_dict_ft = model_finetuned.state_dict()

    # Change the `subfolder` as needed.
    base_model = CogVideoXTransformer3DModel.from_pretrained(
        args.base_ckpt_path, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    state_dict = base_model.state_dict()
    output_dict = {}

    for k in tqdm(state_dict, desc="Extracting LoRA..."):
        original_param = state_dict[k]
        finetuned_param = state_dict_ft[k]
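        # Only parameters with 2+ dimensions (linear/conv weights) get a low-rank
        # decomposition; biases and norm parameters are skipped.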
        if len(original_param.shape) >= 2:
            diff = finetuned_param.float() - original_param.float()
            out = extract_lora(diff, args.rank)
            name = k

            if name.endswith(".weight"):
                name = name[: -len(".weight")]
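            # PEFT naming convention: `lora_A` is the down-projection (Vh) and `lora_B`
            # the up-projection (U).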
            down_key = "{}.lora_A.weight".format(name)
            up_key = "{}.lora_B.weight".format(name)

            output_dict[up_key] = out[0].contiguous().to(finetuned_param.dtype)
            output_dict[down_key] = out[1].contiguous().to(finetuned_param.dtype)

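    # diffusers LoRA state dicts are namespaced by the component they target
    # ("transformer." here; "unet." for UNet-based models).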
    prefix = "transformer" if "transformer" in base_model.__class__.__name__.lower() else "unet"
    output_dict = {f"{prefix}.{k}": v for k, v in output_dict.items()}
    save_file(output_dict, args.lora_out_path)
    print(f"LoRA saved to {args.lora_out_path}; it contains {len(output_dict)} keys.")


if __name__ == "__main__":
    main(parse_args())