 from einops import rearrange
 import lightning as pl
 import pandas as pd
-from diffsynth import WanVideoPipeline, ModelManager
+from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 import torchvision
 from PIL import Image
@@ -145,7 +145,7 @@ def __len__(self):
 
 
 class LightningModelForTrain(pl.LightningModule):
-    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
+    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None):
         super().__init__()
         model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
         model_manager.load_models([dit_path])
@@ -160,6 +160,7 @@ def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, trai
                 lora_alpha=lora_alpha,
                 lora_target_modules=lora_target_modules,
                 init_lora_weights=init_lora_weights,
+                pretrained_lora_path=pretrained_lora_path,
             )
         else:
             self.pipe.denoising_model().requires_grad_(True)
@@ -175,7 +176,7 @@ def freeze_parameters(self):
         self.pipe.denoising_model().train()
 
 
-    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
+    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
         # Add LoRA to UNet
         self.lora_alpha = lora_alpha
         if init_lora_weights == "kaiming":
@@ -192,6 +193,17 @@ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_module
             # Upcast LoRA parameters into fp32
             if param.requires_grad:
                 param.data = param.to(torch.float32)
+
+        # Load pretrained LoRA weights
+        if pretrained_lora_path is not None:
+            state_dict = load_state_dict(pretrained_lora_path)
+            if state_dict_converter is not None:
+                state_dict = state_dict_converter(state_dict)
+            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+            all_keys = [i for i, _ in model.named_parameters()]
+            num_updated_keys = len(all_keys) - len(missing_keys)
+            num_unexpected_keys = len(unexpected_keys)
+            print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
 
 
     def training_step(self, batch, batch_idx):
@@ -405,6 +417,12 @@ def parse_args():
         choices=["lora", "full"],
         help="Model structure to train. LoRA training or full training.",
     )
+    parser.add_argument(
+        "--pretrained_lora_path",
+        type=str,
+        default=None,
+        help="Pretrained LoRA path. Required if the training is resumed.",
+    )
     args = parser.parse_args()
     return args
 
@@ -460,7 +478,8 @@ def train(args):
         lora_alpha=args.lora_alpha,
         lora_target_modules=args.lora_target_modules,
         init_lora_weights=args.init_lora_weights,
-        use_gradient_checkpointing=args.use_gradient_checkpointing
+        use_gradient_checkpointing=args.use_gradient_checkpointing,
+        pretrained_lora_path=args.pretrained_lora_path,
     )
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,
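
In practice, resuming a LoRA run means launching the training script with the new flag, e.g. `--pretrained_lora_path path/to/lora_checkpoint.ckpt`, so that `add_lora_to_model` restores the adapter weights before training continues. The snippet below is a minimal, self-contained sketch of the same non-strict loading pattern, not part of the commit; the checkpoint path and model are hypothetical, and `torch.load` stands in for DiffSynth's `load_state_dict` helper used above.

# Illustrative sketch only: mirrors the strict=False loading logic added in
# add_lora_to_model, assuming the checkpoint is a plain PyTorch state dict.
import torch
import torch.nn as nn

def report_nonstrict_load(model: nn.Module, checkpoint_path: str) -> None:
    # A LoRA-only checkpoint contains just the adapter tensors, so the frozen
    # base-model parameters show up in missing_keys and are left untouched.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    all_keys = [name for name, _ in model.named_parameters()]
    print(f"{len(all_keys) - len(missing_keys)} parameters are loaded from "
          f"{checkpoint_path}. {len(unexpected_keys)} parameters are unexpected.")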