Commit 1cad3c7

updates
1 parent c68e5bd commit 1cad3c7

File tree

1 file changed: +18 -5 lines changed

scripts/extract_lora_from_model.py

Lines changed: 18 additions & 5 deletions
@@ -81,13 +81,25 @@ def parse_args():
         required=True,
         help="Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub.",
     )
+    parser.add_argument(
+        "--base_subfolder",
+        default="transformer",
+        type=str,
+        help="Subfolder to load the base checkpoint from, if any.",
+    )
     parser.add_argument(
         "--finetune_ckpt_path",
         default=None,
         type=str,
         required=True,
         help="Fully fine-tuned checkpoint path. Can be a model ID on the Hub.",
     )
+    parser.add_argument(
+        "--finetune_subfolder",
+        default=None,
+        type=str,
+        help="Subfolder to load the fully fine-tuned checkpoint from, if any.",
+    )
     parser.add_argument("--rank", default=64, type=int)
     parser.add_argument("--lora_out_path", default=None, type=str, required=True)
     args = parser.parse_args()
@@ -100,14 +112,14 @@ def parse_args():
 
 @torch.no_grad()
 def main(args):
-    # Fully fine-tuned checkpoints usually don't have any other components. So, we
-    # don't need the `subfolder`. You can add that if needed.
-    model_finetuned = CogVideoXTransformer3DModel.from_pretrained(args.finetune_ckpt_path, torch_dtype=torch.bfloat16)
+    model_finetuned = CogVideoXTransformer3DModel.from_pretrained(
+        args.finetune_ckpt_path, subfolder=args.finetune_subfolder, torch_dtype=torch.bfloat16
+    )
     state_dict_ft = model_finetuned.state_dict()
 
     # Change the `subfolder` as needed.
     base_model = CogVideoXTransformer3DModel.from_pretrained(
-        args.base_ckpt_path, subfolder="transformer", torch_dtype=torch.bfloat16
+        args.base_ckpt_path, subfolder=args.base_subfolder, torch_dtype=torch.bfloat16
     )
     state_dict = base_model.state_dict()
     output_dict = {}
@@ -135,4 +147,5 @@ def main(args):
 
 
 if __name__ == "__main__":
-    main()
+    args = parse_args()
+    main(args)
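
The diff above only touches argument parsing and checkpoint loading; the factorization that actually produces the LoRA weights lives elsewhere in the script and is not shown in these hunks. For context, here is a minimal sketch of the standard delta-SVD step such a script typically performs per weight matrix. The helper name extract_lora_pair is illustrative, not the script's real internals:

    import torch

    def extract_lora_pair(w_base: torch.Tensor, w_ft: torch.Tensor, rank: int):
        """Return (down, up) matrices whose product approximates w_ft - w_base."""
        delta = (w_ft - w_base).float()  # run the SVD in fp32 for numerical stability
        u, s, vh = torch.linalg.svd(delta, full_matrices=False)
        up = u[:, :rank] * s[:rank]      # fold singular values into the "up" factor
        down = vh[:rank, :]
        return down, up                  # up @ down is the rank-limited approximation of delta

Applied to each weight shared by state_dict and state_dict_ft, with the rank taken from args.rank (default 64), this yields the low-rank pairs that would be written out to --lora_out_path.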
