@@ -81,13 +81,25 @@ def parse_args():
8181 required = True ,
8282 help = "Base checkpoint path from which the model was finetuned. Can be a model ID on the Hub." ,
8383 )
84+ parser .add_argument (
85+ "--base_subfolder" ,
86+ default = "transformer" ,
87+ type = str ,
88+ help = "subfolder to load the base checkpoint from if any." ,
89+ )
8490 parser .add_argument (
8591 "--finetune_ckpt_path" ,
8692 default = None ,
8793 type = str ,
8894 required = True ,
8995 help = "Fully fine-tuned checkpoint path. Can be a model ID on the Hub." ,
9096 )
97+ parser .add_argument (
98+ "--finetune_subfolder" ,
99+ default = None ,
100+ type = str ,
101+ help = "subfolder to load the fulle finetuned checkpoint from if any." ,
102+ )
91103 parser .add_argument ("--rank" , default = 64 , type = int )
92104 parser .add_argument ("--lora_out_path" , default = None , type = str , required = True )
93105 args = parser .parse_args ()
@@ -100,14 +112,14 @@ def parse_args():
100112
101113@torch .no_grad ()
102114def main (args ):
103- # Fully fine-tuned checkpoints usually don't have any other components. So, we
104- # don't need the ` subfolder`. You can add that if needed.
105- model_finetuned = CogVideoXTransformer3DModel . from_pretrained ( args . finetune_ckpt_path , torch_dtype = torch . bfloat16 )
115+ model_finetuned = CogVideoXTransformer3DModel . from_pretrained (
116+ args . finetune_ckpt_path , subfolder = args . finetune_subfolder , torch_dtype = torch . bfloat16
117+ )
106118 state_dict_ft = model_finetuned .state_dict ()
107119
108120 # Change the `subfolder` as needed.
109121 base_model = CogVideoXTransformer3DModel .from_pretrained (
110- args .base_ckpt_path , subfolder = "transformer" , torch_dtype = torch .bfloat16
122+ args .base_ckpt_path , subfolder = args . base_subfolder , torch_dtype = torch .bfloat16
111123 )
112124 state_dict = base_model .state_dict ()
113125 output_dict = {}
@@ -135,4 +147,5 @@ def main(args):
135147
136148
137149if __name__ == "__main__" :
138- main ()
150+ args = parse_args ()
151+ main (args )
0 commit comments