
Commit 91f77d2

Merge pull request #393 from modelscope/wan-train-update
support resume training
2 parents ee4b022 + eb4d518


examples/wanvideo/train_wan_t2v.py

Lines changed: 23 additions & 4 deletions
@@ -3,7 +3,7 @@
 from einops import rearrange
 import lightning as pl
 import pandas as pd
-from diffsynth import WanVideoPipeline, ModelManager
+from diffsynth import WanVideoPipeline, ModelManager, load_state_dict
 from peft import LoraConfig, inject_adapter_in_model
 import torchvision
 from PIL import Image
@@ -145,7 +145,7 @@ def __len__(self):
 
 
 class LightningModelForTrain(pl.LightningModule):
-    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True):
+    def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, train_architecture="lora", lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", use_gradient_checkpointing=True, pretrained_lora_path=None):
         super().__init__()
         model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cpu")
         model_manager.load_models([dit_path])
@@ -160,6 +160,7 @@ def __init__(self, dit_path, learning_rate=1e-5, lora_rank=4, lora_alpha=4, trai
                 lora_alpha=lora_alpha,
                 lora_target_modules=lora_target_modules,
                 init_lora_weights=init_lora_weights,
+                pretrained_lora_path=pretrained_lora_path,
             )
         else:
             self.pipe.denoising_model().requires_grad_(True)
@@ -175,7 +176,7 @@ def freeze_parameters(self):
         self.pipe.denoising_model().train()
 
 
-    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming"):
+    def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_modules="q,k,v,o,ffn.0,ffn.2", init_lora_weights="kaiming", pretrained_lora_path=None, state_dict_converter=None):
         # Add LoRA to UNet
         self.lora_alpha = lora_alpha
         if init_lora_weights == "kaiming":
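The widened add_lora_to_model signature also accepts an optional state_dict_converter callable, which the next hunk applies to the checkpoint before loading. The commit itself never passes one in; purely as a hedged illustration, such a converter could be a small key-remapping function like the hypothetical sketch below (the "pipe.dit." prefix is invented for the example):

    # Hypothetical example of a state_dict_converter; not part of this commit.
    # It remaps checkpoint keys to the model's naming before loading,
    # here by stripping an assumed "pipe.dit." prefix from every key.
    def strip_prefix_converter(state_dict, prefix="pipe.dit."):
        return {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()
        }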
@@ -192,6 +193,17 @@ def add_lora_to_model(self, model, lora_rank=4, lora_alpha=4, lora_target_module
             # Upcast LoRA parameters into fp32
             if param.requires_grad:
                 param.data = param.to(torch.float32)
+
+        # Load pretrained LoRA weights (used when resuming training)
+        if pretrained_lora_path is not None:
+            state_dict = load_state_dict(pretrained_lora_path)
+            if state_dict_converter is not None:
+                state_dict = state_dict_converter(state_dict)
+            missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+            all_keys = [i for i, _ in model.named_parameters()]
+            num_updated_keys = len(all_keys) - len(missing_keys)
+            num_unexpected_keys = len(unexpected_keys)
+            print(f"{num_updated_keys} parameters are loaded from {pretrained_lora_path}. {num_unexpected_keys} parameters are unexpected.")
 
 
     def training_step(self, batch, batch_idx):
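The resume logic above relies on PyTorch's non-strict state_dict loading: model keys absent from the checkpoint come back as missing_keys, checkpoint keys absent from the model as unexpected_keys, and the number of restored parameters is the total parameter count minus the missing ones. A self-contained sketch of the same accounting, using an invented TinyModel and a zero-filled stand-in checkpoint:

    # Standalone illustration, not the repository's model: a partial,
    # LoRA-only state dict is merged with strict=False and the loaded
    # keys are counted the same way as in the diff above.
    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.base = nn.Linear(8, 8)                # stands in for frozen base weights
            self.lora_A = nn.Linear(8, 2, bias=False)  # stands in for LoRA adapter layers
            self.lora_B = nn.Linear(2, 8, bias=False)

    model = TinyModel()

    # Pretend this came from a previously saved LoRA checkpoint.
    partial_state_dict = {
        "lora_A.weight": torch.zeros(2, 8),
        "lora_B.weight": torch.zeros(8, 2),
    }

    # strict=False tolerates model keys absent from the checkpoint (missing_keys)
    # and checkpoint keys absent from the model (unexpected_keys).
    missing_keys, unexpected_keys = model.load_state_dict(partial_state_dict, strict=False)

    all_keys = [name for name, _ in model.named_parameters()]
    num_updated_keys = len(all_keys) - len(missing_keys)
    print(f"{num_updated_keys} parameters are loaded. {len(unexpected_keys)} parameters are unexpected.")
    # Prints: 2 parameters are loaded. 0 parameters are unexpected.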
@@ -405,6 +417,12 @@ def parse_args():
         choices=["lora", "full"],
         help="Model structure to train. LoRA training or full training.",
     )
+    parser.add_argument(
+        "--pretrained_lora_path",
+        type=str,
+        default=None,
+        help="Pretrained LoRA path. Required if the training is resumed.",
+    )
     args = parser.parse_args()
     return args
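A quick sketch of how the new flag behaves (the checkpoint path is a made-up placeholder): leaving --pretrained_lora_path unset keeps the default None, so LoRA starts from its fresh initialization; setting it makes add_lora_to_model load that file before training continues.

    # Illustration only; "previous_run_lora.ckpt" is a placeholder path.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_lora_path",
        type=str,
        default=None,
        help="Pretrained LoRA path. Required if the training is resumed.",
    )

    print(parser.parse_args([]).pretrained_lora_path)
    # -> None: fresh run, LoRA weights keep their initial values.

    args = parser.parse_args(["--pretrained_lora_path", "previous_run_lora.ckpt"])
    print(args.pretrained_lora_path)
    # -> previous_run_lora.ckpt: resumed run, weights loaded via add_lora_to_model.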

@@ -460,7 +478,8 @@ def train(args):
         lora_alpha=args.lora_alpha,
         lora_target_modules=args.lora_target_modules,
         init_lora_weights=args.init_lora_weights,
-        use_gradient_checkpointing=args.use_gradient_checkpointing
+        use_gradient_checkpointing=args.use_gradient_checkpointing,
+        pretrained_lora_path=args.pretrained_lora_path,
     )
     trainer = pl.Trainer(
         max_epochs=args.max_epochs,