
Commit d63a826

I2V multiresolution finetuning by removing learned PEs (#31)
* i2v finetuning without learned pe
* update
* update
* update
* update readme
* refactor
* refactor
1 parent cbbac06 commit d63a826

7 files changed: +163 / -97 lines

README.md

Lines changed: 16 additions & 2 deletions

@@ -53,9 +53,23 @@ video = pipe("<my-awesome-prompt>").frames[0]
 export_to_video(video, "output.mp4", fps=8)
 ```
 
-You can also check if your LoRA is correctly mounted [here](tests/test_lora_inference.py).
+For Image-to-Video LoRAs trained with multiresolution videos, one must also add the following lines (see [this](https://github.com/a-r-r-o-w/cogvideox-factory/issues/26) Issue for more details):
+
+```python
+from diffusers import CogVideoXImageToVideoPipeline
+
+pipe = CogVideoXImageToVideoPipeline.from_pretrained(
+    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
+).to("cuda")
 
-**Note:** For Image-to-Video finetuning, you must install diffusers from [this](https://github.com/huggingface/diffusers/pull/9482) branch (which adds lora loading support in CogVideoX image-to-video) until it is merged.
+
+# ...
+
+del pipe.transformer.patch_embed.pos_embedding
+pipe.transformer.patch_embed.use_learned_positional_embeddings = False
+pipe.transformer.config.use_learned_positional_embeddings = False
+```
+
+You can also check if your LoRA is correctly mounted [here](tests/test_lora_inference.py).
 
 Below we provide additional sections detailing on more options explored in this repository. They all attempt to make fine-tuning for video models as accessible as possible by reducing memory requirements as much as possible.
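For reference, the added README snippet fits into a complete inference flow roughly like the sketch below. The LoRA path, weight name, input image, prompt, and generation parameters are placeholders (not part of this commit), and `load_lora_weights` on the image-to-video pipeline assumes a diffusers version that already includes CogVideoX I2V LoRA-loading support (the PR referenced in the removed README note).

```python
import torch
from diffusers import CogVideoXImageToVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = CogVideoXImageToVideoPipeline.from_pretrained(
    "THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16
).to("cuda")
# Placeholder LoRA path and weight name, for illustration only.
pipe.load_lora_weights("my-cogvideox-i2v-lora", weight_name="pytorch_lora_weights.safetensors")

# Required for LoRAs trained on multiple resolutions: drop the learned
# positional embeddings (see issue #26 linked above).
del pipe.transformer.patch_embed.pos_embedding
pipe.transformer.patch_embed.use_learned_positional_embeddings = False
pipe.transformer.config.use_learned_positional_embeddings = False

video = pipe(
    image=load_image("input.png"),  # placeholder image
    prompt="<my-awesome-prompt>",   # placeholder prompt
    guidance_scale=6.0,
    use_dynamic_cfg=True,
).frames[0]
export_to_video(video, "output.mp4", fps=8)
```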

training/args.py

Lines changed: 17 additions & 1 deletion

@@ -131,7 +131,13 @@ def _get_validation_args(parser: argparse.ArgumentParser) -> None:
     parser.add_argument(
         "--validation_epochs",
         type=int,
-        default=50,
+        default=None,
+        help="Run validation every X training epochs. Validation consists of running the validation prompt `args.num_validation_videos` times.",
+    )
+    parser.add_argument(
+        "--validation_steps",
+        type=int,
+        default=None,
         help="Run validation every X training steps. Validation consists of running the validation prompt `args.num_validation_videos` times.",
     )
     parser.add_argument(

@@ -323,6 +329,16 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None:
         default=0.05,
         help="Image condition dropout probability when finetuning image-to-video.",
     )
+    parser.add_argument(
+        "--ignore_learned_positional_embeddings",
+        action="store_true",
+        default=False,
+        help=(
+            "Whether to ignore the learned positional embeddings when training CogVideoX Image-to-Video. This setting "
+            "should be used when performing multi-resolution training, because CogVideoX-I2V does not support it "
+            "otherwise. Please read the comments in https://github.com/a-r-r-o-w/cogvideox-factory/issues/26 to understand why."
+        ),
+    )
 
 
 def _get_optimizer_args(parser: argparse.ArgumentParser) -> None:
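Both cadence flags now default to `None`, so validation is skipped unless one of them is set explicitly. The sketch below restates the two gating checks the image-to-video script performs with these flags (simplified from the diff further down, not a verbatim excerpt):

```python
def should_validate_on_step(args, global_step: int) -> bool:
    # Step-based cadence: requires a validation prompt and --validation_steps.
    return (
        args.validation_prompt is not None
        and args.validation_steps is not None
        and global_step % args.validation_steps == 0
    )


def should_validate_on_epoch(args, epoch: int) -> bool:
    # Epoch-based cadence: requires a validation prompt and --validation_epochs.
    return (
        args.validation_prompt is not None
        and args.validation_epochs is not None
        and (epoch + 1) % args.validation_epochs == 0
    )
```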

training/cogvideox_image_to_video_lora.py

Lines changed: 91 additions & 65 deletions

@@ -46,7 +46,6 @@
 from diffusers.training_utils import cast_training_params
 from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video, load_image
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.torch_utils import is_compiled_module
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader

@@ -57,7 +56,14 @@
 from args import get_args  # isort:skip
 from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop  # isort:skip
 from text_encoder import compute_prompt_embeddings  # isort:skip
-from utils import get_gradient_norm, get_optimizer, prepare_rotary_positional_embeddings, print_memory, reset_memory  # isort:skip
+from utils import (
+    get_gradient_norm,
+    get_optimizer,
+    prepare_rotary_positional_embeddings,
+    print_memory,
+    reset_memory,
+    unwrap_model,
+)
 
 
 logger = get_logger(__name__)

@@ -155,7 +161,6 @@ def log_validation(
     pipe: CogVideoXImageToVideoPipeline,
     args: Dict[str, Any],
     pipeline_args: Dict[str, Any],
-    epoch,
     is_final_validation: bool = False,
 ):
     logger.info(

@@ -201,6 +206,64 @@ def log_validation(
     return videos
 
 
+def run_validation(
+    args: Dict[str, Any],
+    accelerator: Accelerator,
+    transformer,
+    scheduler,
+    model_config: Dict[str, Any],
+    weight_dtype: torch.dtype,
+) -> None:
+    accelerator.print("===== Memory before validation =====")
+    print_memory(accelerator.device)
+    torch.cuda.synchronize(accelerator.device)
+
+    pipe = CogVideoXImageToVideoPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        transformer=unwrap_model(accelerator, transformer),
+        scheduler=scheduler,
+        revision=args.revision,
+        variant=args.variant,
+        torch_dtype=weight_dtype,
+    )
+
+    if args.enable_slicing:
+        pipe.vae.enable_slicing()
+    if args.enable_tiling:
+        pipe.vae.enable_tiling()
+    if args.enable_model_cpu_offload:
+        pipe.enable_model_cpu_offload()
+
+    validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
+    validation_images = args.validation_images.split(args.validation_prompt_separator)
+    for validation_image, validation_prompt in zip(validation_images, validation_prompts):
+        pipeline_args = {
+            "image": load_image(validation_image),
+            "prompt": validation_prompt,
+            "guidance_scale": args.guidance_scale,
+            "use_dynamic_cfg": args.use_dynamic_cfg,
+            "height": args.height,
+            "width": args.width,
+            "max_sequence_length": model_config.max_text_seq_length,
+        }
+
+        log_validation(
+            pipe=pipe,
+            args=args,
+            accelerator=accelerator,
+            pipeline_args=pipeline_args,
+        )
+
+    accelerator.print("===== Memory after validation =====")
+    print_memory(accelerator.device)
+    reset_memory(accelerator.device)
+
+    del pipe
+    gc.collect()
+    torch.cuda.empty_cache()
+    torch.cuda.synchronize(accelerator.device)
+
+
 class CollateFunction:
     def __init__(self, weight_dtype: torch.dtype, load_tensors: bool) -> None:
         self.weight_dtype = weight_dtype

@@ -308,6 +371,12 @@ def main(args):
         variant=args.variant,
     )
 
+    # These changes will also be required when trying to run inference with the trained lora
+    if args.ignore_learned_positional_embeddings:
+        del transformer.patch_embed.pos_embedding
+        transformer.patch_embed.use_learned_positional_embeddings = False
+        transformer.config.use_learned_positional_embeddings = False
+
     vae = AutoencoderKLCogVideoX.from_pretrained(
         args.pretrained_model_name_or_path,
         subfolder="vae",

@@ -373,19 +442,14 @@
     )
     transformer.add_adapter(transformer_lora_config)
 
-    def unwrap_model(model):
-        model = accelerator.unwrap_model(model)
-        model = model._orig_mod if is_compiled_module(model) else model
-        return model
-
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
-                    model = unwrap_model(model)
+                if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
+                    model = unwrap_model(accelerator, model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

@@ -407,10 +471,10 @@ def load_model_hook(models, input_dir):
             while len(models) > 0:
                 model = models.pop()
 
-                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
-                    transformer_ = unwrap_model(model)
+                if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
+                    transformer_ = unwrap_model(accelerator, model)
                 else:
-                    raise ValueError(f"Unexpected save model: {unwrap_model(model).__class__}")
+                    raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
         else:
             transformer_ = CogVideoXTransformer3DModel.from_pretrained(
                 args.pretrained_model_name_or_path, subfolder="transformer"

@@ -776,6 +840,7 @@ def load_model_hook(models, input_dir):
                 progress_bar.update(1)
                 global_step += 1
 
+                # Checkpointing
                 if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`

@@ -802,6 +867,13 @@ def load_model_hook(models, input_dir):
                         accelerator.save_state(save_path)
                         logger.info(f"Saved state to {save_path}")
 
+                # Validation
+                should_run_validation = args.validation_prompt is not None and (
+                    args.validation_steps is not None and global_step % args.validation_steps == 0
+                )
+                if should_run_validation:
+                    run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
+
             last_lr = lr_scheduler.get_last_lr()[0] if lr_scheduler is not None else args.learning_rate
             logs = {"loss": loss.detach().item(), "lr": last_lr}
             # gradnorm + deepspeed: https://github.com/microsoft/DeepSpeed/issues/4555

@@ -819,61 +891,16 @@ def load_model_hook(models, input_dir):
                 break
 
         if accelerator.is_main_process:
-            if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
-                accelerator.print("===== Memory before validation =====")
-                print_memory(accelerator.device)
-                torch.cuda.synchronize(accelerator.device)
-
-                pipe = CogVideoXImageToVideoPipeline.from_pretrained(
-                    args.pretrained_model_name_or_path,
-                    transformer=unwrap_model(transformer),
-                    scheduler=scheduler,
-                    revision=args.revision,
-                    variant=args.variant,
-                    torch_dtype=weight_dtype,
-                )
-
-                if args.enable_slicing:
-                    pipe.vae.enable_slicing()
-                if args.enable_tiling:
-                    pipe.vae.enable_tiling()
-                if args.enable_model_cpu_offload:
-                    pipe.enable_model_cpu_offload()
-
-                validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
-                validation_images = args.validation_images.split(args.validation_prompt_separator)
-                for validation_image, validation_prompt in zip(validation_images, validation_prompts):
-                    pipeline_args = {
-                        "image": load_image(validation_image),
-                        "prompt": validation_prompt,
-                        "guidance_scale": args.guidance_scale,
-                        "use_dynamic_cfg": args.use_dynamic_cfg,
-                        "height": args.height,
-                        "width": args.width,
-                        "max_sequence_length": model_config.max_text_seq_length,
-                    }
-
-                    log_validation(
-                        pipe=pipe,
-                        args=args,
-                        accelerator=accelerator,
-                        pipeline_args=pipeline_args,
-                        epoch=epoch,
-                    )
-
-                accelerator.print("===== Memory after validation =====")
-                print_memory(accelerator.device)
-                reset_memory(accelerator.device)
-
-                del pipe
-                gc.collect()
-                torch.cuda.empty_cache()
-                torch.cuda.synchronize(accelerator.device)
+            should_run_validation = args.validation_prompt is not None and (
+                args.validation_epochs is not None and (epoch + 1) % args.validation_epochs == 0
+            )
+            if should_run_validation:
+                run_validation(args, accelerator, transformer, scheduler, model_config, weight_dtype)
 
     accelerator.wait_for_everyone()
 
     if accelerator.is_main_process:
-        transformer = unwrap_model(transformer)
+        transformer = unwrap_model(accelerator, transformer)
         dtype = (
             torch.float16
             if args.mixed_precision == "fp16"

@@ -944,7 +971,6 @@ def load_model_hook(models, input_dir):
                 pipe=pipe,
                 args=args,
                 pipeline_args=pipeline_args,
-                epoch=epoch,
                 is_final_validation=True,
             )
             validation_outputs.extend(video)
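
The local `unwrap_model` closure deleted above is replaced by a shared helper imported from `utils`, now taking the `accelerator` explicitly so it can be reused by `run_validation`. The helper itself is not shown in this diff; based on the removed closure, it presumably looks like this sketch:

```python
# Presumed shape of the shared helper in training/utils.py (not part of this diff).
from accelerate import Accelerator
from diffusers.utils.torch_utils import is_compiled_module


def unwrap_model(accelerator: Accelerator, model):
    # Strip the accelerate wrapper, then the torch.compile wrapper if present.
    model = accelerator.unwrap_model(model)
    model = model._orig_mod if is_compiled_module(model) else model
    return model
```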

training/cogvideox_text_to_video_lora.py

Lines changed: 15 additions & 14 deletions

@@ -45,7 +45,6 @@
 from diffusers.training_utils import cast_training_params
 from diffusers.utils import convert_unet_state_dict_to_peft, export_to_video
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
-from diffusers.utils.torch_utils import is_compiled_module
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader

@@ -56,7 +55,14 @@
 from args import get_args  # isort:skip
 from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop  # isort:skip
 from text_encoder import compute_prompt_embeddings  # isort:skip
-from utils import get_gradient_norm, get_optimizer, prepare_rotary_positional_embeddings, print_memory, reset_memory  # isort:skip
+from utils import (
+    get_gradient_norm,
+    get_optimizer,
+    prepare_rotary_positional_embeddings,
+    print_memory,
+    reset_memory,
+    unwrap_model,
+)  # isort:skip
 
 
 logger = get_logger(__name__)

@@ -366,19 +372,14 @@ def main(args):
     )
     transformer.add_adapter(transformer_lora_config)
 
-    def unwrap_model(model):
-        model = accelerator.unwrap_model(model)
-        model = model._orig_mod if is_compiled_module(model) else model
-        return model
-
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
 
             for model in models:
-                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
-                    model = unwrap_model(model)
+                if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
+                    model = unwrap_model(accelerator, model)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")

@@ -400,10 +401,10 @@ def load_model_hook(models, input_dir):
             while len(models) > 0:
                 model = models.pop()
 
-                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
-                    transformer_ = unwrap_model(model)
+                if isinstance(unwrap_model(accelerator, model), type(unwrap_model(accelerator, transformer))):
+                    transformer_ = unwrap_model(accelerator, model)
                 else:
-                    raise ValueError(f"Unexpected save model: {unwrap_model(model).__class__}")
+                    raise ValueError(f"Unexpected save model: {unwrap_model(accelerator, model).__class__}")
         else:
             transformer_ = CogVideoXTransformer3DModel.from_pretrained(
                 args.pretrained_model_name_or_path, subfolder="transformer"

@@ -797,7 +798,7 @@ def load_model_hook(models, input_dir):
 
                 pipe = CogVideoXPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
-                    transformer=unwrap_model(transformer),
+                    transformer=unwrap_model(accelerator, transformer),
                     scheduler=scheduler,
                     revision=args.revision,
                     variant=args.variant,

@@ -842,7 +843,7 @@ def load_model_hook(models, input_dir):
     accelerator.wait_for_everyone()
 
     if accelerator.is_main_process:
-        transformer = unwrap_model(transformer)
+        transformer = unwrap_model(accelerator, transformer)
         dtype = (
             torch.float16
             if args.mixed_precision == "fp16"
0 commit comments
