-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
177 lines (165 loc) · 8.74 KB
/
train.py
File metadata and controls
177 lines (165 loc) · 8.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import torch, os, argparse, accelerate, warnings
from diffsynth.core import UnifiedDataset
from diffsynth.core.data.operators import LoadVideo, LoadAudio, ImageCropAndResize, ToAbsolutePath
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from diffsynth.diffusion import *
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class WanTrainingModule(DiffusionTrainingModule):
def __init__(
self,
model_paths=None, model_id_with_origin_paths=None,
tokenizer_path=None, audio_processor_path=None,
trainable_models=None,
lora_base_model=None, lora_target_modules="", lora_rank=32, lora_checkpoint=None,
preset_lora_path=None, preset_lora_model=None,
use_gradient_checkpointing=True,
use_gradient_checkpointing_offload=False,
extra_inputs=None,
fp8_models=None,
offload_models=None,
device="cpu",
task="sft",
max_timestep_boundary=1.0,
min_timestep_boundary=0.0,
):
super().__init__()
if not use_gradient_checkpointing:
warnings.warn("Gradient checkpointing is detected as disabled. To prevent out-of-memory errors, the training framework will forcibly enable gradient checkpointing.")
use_gradient_checkpointing = True
model_configs = self.parse_model_configs(model_paths, model_id_with_origin_paths, fp8_models=fp8_models, offload_models=offload_models, device=device)
tokenizer_config = ModelConfig(model_id="Wan-AI/Wan2.1-T2V-1.3B", origin_file_pattern="google/umt5-xxl/") if tokenizer_path is None else ModelConfig(tokenizer_path)
audio_processor_config = self.parse_path_or_model_id(audio_processor_path)
self.pipe = WanVideoPipeline.from_pretrained(torch_dtype=torch.bfloat16, device=device, model_configs=model_configs, tokenizer_config=tokenizer_config, audio_processor_config=audio_processor_config)
self.pipe = self.split_pipeline_units(task, self.pipe, trainable_models, lora_base_model)
self.switch_pipe_to_training_mode(
self.pipe, trainable_models,
lora_base_model, lora_target_modules, lora_rank, lora_checkpoint,
preset_lora_path, preset_lora_model,
task=task,
)
self.use_gradient_checkpointing = use_gradient_checkpointing
self.use_gradient_checkpointing_offload = use_gradient_checkpointing_offload
self.extra_inputs = extra_inputs.split(",") if extra_inputs is not None else []
self.fp8_models = fp8_models
self.task = task
self.task_to_loss = {
"sft:data_process": lambda pipe, *args: args,
"direct_distill:data_process": lambda pipe, *args: args,
"sft": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"sft:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: FlowMatchSFTLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
"direct_distill:train": lambda pipe, inputs_shared, inputs_posi, inputs_nega: DirectDistillLoss(pipe, **inputs_shared, **inputs_posi),
}
self.max_timestep_boundary = max_timestep_boundary
self.min_timestep_boundary = min_timestep_boundary
def parse_extra_inputs(self, data, extra_inputs, inputs_shared):
for extra_input in extra_inputs:
if extra_input == "input_image":
inputs_shared["input_image"] = data["video"][0]
elif extra_input == "end_image":
inputs_shared["end_image"] = data["video"][-1]
elif extra_input == "reference_image" or extra_input == "vace_reference_image":
inputs_shared[extra_input] = data[extra_input][0]
else:
inputs_shared[extra_input] = data[extra_input]
return inputs_shared
def get_pipeline_inputs(self, data):
inputs_posi = {"prompt": data["prompt"]}
inputs_nega = {}
inputs_shared = {
"input_video": data["video"],
"height": data["video"][0].size[1],
"width": data["video"][0].size[0],
"num_frames": len(data["video"]),
"cfg_scale": 1,
"tiled": False,
"rand_device": self.pipe.device,
"use_gradient_checkpointing": self.use_gradient_checkpointing,
"use_gradient_checkpointing_offload": self.use_gradient_checkpointing_offload,
"cfg_merge": False,
"vace_scale": 1,
"max_timestep_boundary": self.max_timestep_boundary,
"min_timestep_boundary": self.min_timestep_boundary,
}
inputs_shared = self.parse_extra_inputs(data, self.extra_inputs, inputs_shared)
return inputs_shared, inputs_posi, inputs_nega
def forward(self, data, inputs=None):
if inputs is None: inputs = self.get_pipeline_inputs(data)
inputs = self.transfer_data_to_device(inputs, self.pipe.device, self.pipe.torch_dtype)
for unit in self.pipe.units:
inputs = self.pipe.unit_runner(unit, self.pipe, *inputs)
loss = self.task_to_loss[self.task](self.pipe, *inputs)
return loss
def wan_parser():
parser = argparse.ArgumentParser(description="Wan2.1 Video Model Fine-tuning")
parser = add_general_config(parser)
parser = add_video_size_config(parser)
parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to tokenizer.")
parser.add_argument("--audio_processor_path", type=str, default=None, help="Path to the audio processor.")
parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary.")
parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary.")
parser.add_argument("--initialize_model_on_cpu", default=False, action="store_true", help="Whether to initialize models on CPU.")
return parser
if __name__ == "__main__":
parser = wan_parser()
args = parser.parse_args()
accelerator = accelerate.Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
kwargs_handlers=[accelerate.DistributedDataParallelKwargs(find_unused_parameters=args.find_unused_parameters)],
)
dataset = UnifiedDataset(
base_path=args.dataset_base_path,
metadata_path=args.dataset_metadata_path,
repeat=args.dataset_repeat,
data_file_keys=args.data_file_keys.split(","),
main_data_operator=UnifiedDataset.default_video_operator(
base_path=args.dataset_base_path,
max_pixels=args.max_pixels,
height=args.height,
width=args.width,
height_division_factor=16,
width_division_factor=16,
num_frames=args.num_frames,
time_division_factor=4,
time_division_remainder=1,
),
special_operator_map={
"animate_face_video": ToAbsolutePath(args.dataset_base_path) >> LoadVideo(args.num_frames, 4, 1, frame_processor=ImageCropAndResize(512, 512, None, 16, 16)),
"input_audio": ToAbsolutePath(args.dataset_base_path) >> LoadAudio(sr=16000),
}
)
model = WanTrainingModule(
model_paths=args.model_paths,
model_id_with_origin_paths=args.model_id_with_origin_paths,
tokenizer_path=args.tokenizer_path,
audio_processor_path=args.audio_processor_path,
trainable_models=args.trainable_models,
lora_base_model=args.lora_base_model,
lora_target_modules=args.lora_target_modules,
lora_rank=args.lora_rank,
lora_checkpoint=args.lora_checkpoint,
preset_lora_path=args.preset_lora_path,
preset_lora_model=args.preset_lora_model,
use_gradient_checkpointing=args.use_gradient_checkpointing,
use_gradient_checkpointing_offload=args.use_gradient_checkpointing_offload,
extra_inputs=args.extra_inputs,
fp8_models=args.fp8_models,
offload_models=args.offload_models,
task=args.task,
device="cpu" if args.initialize_model_on_cpu else accelerator.device,
max_timestep_boundary=args.max_timestep_boundary,
min_timestep_boundary=args.min_timestep_boundary,
)
model_logger = ModelLogger(
args.output_path,
remove_prefix_in_ckpt=args.remove_prefix_in_ckpt,
)
launcher_map = {
"sft:data_process": launch_data_process_task,
"direct_distill:data_process": launch_data_process_task,
"sft": launch_training_task,
"sft:train": launch_training_task,
"direct_distill": launch_training_task,
"direct_distill:train": launch_training_task,
}
launcher_map[args.task](accelerator, dataset, model, model_logger, args=args)