|  | 1 | +""" | 
|  | 2 | +This script demonstrates how to extract a LoRA checkpoint from a fully finetuned model with the CogVideoX model. | 
|  | 3 | +
 | 
|  | 4 | +To make it work for other models: | 
|  | 5 | +
 | 
|  | 6 | +* Change the model class. Here we use `CogVideoXTransformer3DModel`. For Flux, it would be `FluxTransformer2DModel`, | 
|  | 7 | +for example. (TODO: more reason to add `AutoModel`). | 
|  | 8 | +* Spply path to the base checkpoint via `base_ckpt_path`. | 
|  | 9 | +* Supply path to the fully fine-tuned checkpoint via `--finetune_ckpt_path`. | 
|  | 10 | +* Change the `--rank` as needed. | 
|  | 11 | +
 | 
|  | 12 | +Example usage: | 
|  | 13 | +
 | 
|  | 14 | +```bash | 
|  | 15 | +python extract_lora_from_model.py \ | 
|  | 16 | +    --base_ckpt_path=THUDM/CogVideoX-5b \ | 
|  | 17 | +    --finetune_ckpt_path=finetrainers/cakeify-v0 \ | 
|  | 18 | +    --lora_out_path=cakeify_lora.safetensors | 
|  | 19 | +``` | 
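
The resulting `.safetensors` file can be loaded like any other diffusers LoRA. A minimal sketch, reusing the
checkpoint and output path from the example above:

```python
import torch

from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("cakeify_lora.safetensors")
```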

Script is adapted from
https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py
"""

import argparse

import torch
from safetensors.torch import save_file
from tqdm.auto import tqdm

from diffusers import CogVideoXTransformer3DModel


CLAMP_QUANTILE = 0.99


# Comes from
# https://github.com/Stability-AI/stability-ComfyUI-nodes/blob/001154622564b17223ce0191803c5fff7b87146c/control_lora_create.py#L9
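# Factorizes the difference between a fine-tuned weight and its base weight into
# a low-rank LoRA pair via a truncated SVD.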
def extract_lora(diff, rank):
    # Important to use CUDA; otherwise, this is very slow!
    if torch.cuda.is_available():
        diff = diff.to("cuda")

    is_conv2d = len(diff.shape) == 4
    kernel_size = None if not is_conv2d else diff.size()[2:4]
    is_conv2d_3x3 = is_conv2d and kernel_size != (1, 1)
    out_dim, in_dim = diff.size()[0:2]
    rank = min(rank, in_dim, out_dim)

    if is_conv2d:
        # Flatten conv kernels so the SVD operates on a 2-D matrix.
        if is_conv2d_3x3:
            diff = diff.flatten(start_dim=1)
        else:
            diff = diff.squeeze()

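    # Factorize the delta with an SVD: diff ≈ U @ diag(S) @ Vh. Keep only the
    # top-`rank` singular values and fold diag(S) into U, giving the LoRA
    # up-projection (U) and down-projection (Vh).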
    U, S, Vh = torch.linalg.svd(diff.float())
    U = U[:, :rank]
    S = S[:rank]
    U = U @ torch.diag(S)
    Vh = Vh[:rank, :]

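    # Clamp outliers at the CLAMP_QUANTILE quantile of all factor values,
    # symmetrically around zero, to limit extreme entries.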
    dist = torch.cat([U.flatten(), Vh.flatten()])
    hi_val = torch.quantile(dist, CLAMP_QUANTILE)
    low_val = -hi_val

    U = U.clamp(low_val, hi_val)
    Vh = Vh.clamp(low_val, hi_val)
    if is_conv2d:
        U = U.reshape(out_dim, rank, 1, 1)
        Vh = Vh.reshape(rank, in_dim, kernel_size[0], kernel_size[1])
    return (U.cpu(), Vh.cpu())


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--base_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub.",
    )
    parser.add_argument(
        "--base_subfolder",
        default="transformer",
        type=str,
        help="Subfolder to load the base checkpoint from, if any.",
    )
    parser.add_argument(
        "--finetune_ckpt_path",
        default=None,
        type=str,
        required=True,
        help="Fully fine-tuned checkpoint path. Can be a model ID on the Hub.",
    )
    parser.add_argument(
        "--finetune_subfolder",
        default=None,
        type=str,
        help="Subfolder to load the fully fine-tuned checkpoint from, if any.",
    )
    parser.add_argument("--rank", default=64, type=int, help="Rank of the extracted LoRA.")
    parser.add_argument("--lora_out_path", default=None, type=str, required=True)
    args = parser.parse_args()

    if not args.lora_out_path.endswith(".safetensors"):
        raise ValueError("`lora_out_path` must end with `.safetensors`.")

    return args


@torch.no_grad()
def main(args):
    model_finetuned = CogVideoXTransformer3DModel.from_pretrained(
        args.finetune_ckpt_path, subfolder=args.finetune_subfolder, torch_dtype=torch.bfloat16
    )
    state_dict_ft = model_finetuned.state_dict()

    # The base checkpoint's `subfolder` can be changed via `--base_subfolder`.
    base_model = CogVideoXTransformer3DModel.from_pretrained(
        args.base_ckpt_path, subfolder=args.base_subfolder, torch_dtype=torch.bfloat16
    )
    state_dict = base_model.state_dict()
    output_dict = {}

    for k in tqdm(state_dict, desc="Extracting LoRA..."):
        original_param = state_dict[k]
        finetuned_param = state_dict_ft[k]
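        # LoRA only applies to weights with at least two dims (linear, conv);
        # 1-D params such as biases and norm scales are skipped.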
        if len(original_param.shape) >= 2:
            diff = finetuned_param.float() - original_param.float()
            out = extract_lora(diff, args.rank)
            name = k

            if name.endswith(".weight"):
                name = name[: -len(".weight")]
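            # Follow the PEFT naming convention used by diffusers' LoRA loaders:
            # `lora_A` is the down-projection (Vh) and `lora_B` the up-projection
            # (U @ diag(S)).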
            down_key = "{}.lora_A.weight".format(name)
            up_key = "{}.lora_B.weight".format(name)

            output_dict[up_key] = out[0].contiguous().to(finetuned_param.dtype)
            output_dict[down_key] = out[1].contiguous().to(finetuned_param.dtype)

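    # Prefix the keys with the module the LoRA targets ("transformer" for
    # DiT-style models, "unet" for UNet-based ones), as diffusers' loaders expect.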
    prefix = "transformer" if "transformer" in base_model.__class__.__name__.lower() else "unet"
    output_dict = {f"{prefix}.{k}": v for k, v in output_dict.items()}
    save_file(output_dict, args.lora_out_path)
    print(f"LoRA saved to {args.lora_out_path} with {len(output_dict)} keys.")


if __name__ == "__main__":
    args = parse_args()
    main(args)