
Commit 27cb7a4

Merge branch 'main' into patch-2
2 parents 02f2bbc + ec9e526 · commit 27cb7a4

23 files changed: +226 -65 lines changed

docs/source/en/training/distributed_inference.md

Lines changed: 1 addition & 1 deletion

````diff
@@ -177,7 +177,7 @@ transformer = FluxTransformer2DModel.from_pretrained(
 ```
 
 > [!TIP]
-> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models.
+> At any point, you can try `print(pipeline.hf_device_map)` to see how the various models are distributed across devices. This is useful for tracking the device placement of the models. You can also try `print(transformer.hf_device_map)` to see how the transformer model is sharded across devices.
 
 Add the transformer model to the pipeline for denoising, but set the other model-level components like the text encoders and VAE to `None` because you don't need them yet.
 
````

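For context, the new tip can be exercised with a short snippet like the one below. The checkpoint name, `device_map` value, and dtype are assumptions mirroring the surrounding guide, not part of this diff:

```python
import torch
from diffusers import FluxTransformer2DModel

# Assumed checkpoint and settings; any model loaded with a device_map exposes hf_device_map.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
print(transformer.hf_device_map)  # module-to-device mapping showing how the transformer is sharded
```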
examples/cogvideo/README.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -180,6 +180,7 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen
 
 > [!TIP]
 > You can pass `--use_8bit_adam` to reduce the memory requirements of training.
+> You can pass `--video_reshape_mode` to enable video cropping functionality, supporting the options: ['center', 'random', 'none']. See [this](https://gist.github.com/glide-the/7658dbfd5f555be0a1a687a4139dba40) notebook for examples.
 
 > [!IMPORTANT]
 > The following settings have been tested at the time of adding CogVideoX LoRA training support:
```

examples/cogvideo/train_cogvideox_lora.py

Lines changed: 77 additions & 14 deletions

```diff
@@ -21,20 +21,24 @@
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
+import torchvision.transforms as TT
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import resize
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
 
 import diffusers
 from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
@@ -214,6 +218,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -413,6 +423,7 @@ def __init__(
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ def __init__(
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):
 
         return instance_prompts, instance_videos
 
+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
```
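For reference, here is a minimal standalone sketch of what the new resize-then-crop step does to a clip. The input shape, target resolution, and the `'center'` mode below are illustrative choices, not values taken from the diff:

```python
import torch
import torchvision.transforms as TT
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize

frames = torch.rand(49, 3, 360, 640)  # dummy clip, [F, C, H, W]
height, width = 480, 720              # target training resolution

# Resize so the target rectangle fits inside the clip, preserving aspect ratio.
if frames.shape[3] / frames.shape[2] > width / height:
    frames = resize(
        frames,
        size=[height, int(frames.shape[3] * height / frames.shape[2])],
        interpolation=InterpolationMode.BICUBIC,
    )
else:
    frames = resize(
        frames,
        size=[int(frames.shape[2] * width / frames.shape[3]), width],
        interpolation=InterpolationMode.BICUBIC,
    )

# 'center' crops symmetrically; 'random'/'none' would sample the top/left offsets instead.
top = (frames.shape[2] - height) // 2
left = (frames.shape[3] - width) // 2
frames = TT.functional.crop(frames, top=top, left=left, height=height, width=width)
print(frames.shape)  # torch.Size([49, 3, 480, 720])
```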
```diff
@@ -542,15 +586,14 @@ def _preprocess_data(self):
 
         decord.bridge.set_bridge("torch")
 
-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Loading progress resize and crop videos",
         )
+        videos = []
 
         for filename in self.instance_video_paths:
-            video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
+            video_reader = decord.VideoReader(uri=filename.as_posix())
             video_num_frames = len(video_reader)
 
             start_frame = min(self.skip_frames_start, video_num_frames)
@@ -576,10 +619,16 @@ def _preprocess_data(self):
             assert (selected_num_frames - 1) % 4 == 0
 
             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            frames = (frames - 127.5) / 127.5
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            progress_dataset_bar.set_description(
+                f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
+            )
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+            progress_dataset_bar.update(1)
 
+        progress_dataset_bar.close()
         return videos
 
 
@@ -694,8 +743,13 @@ def log_validation(
 
     videos = []
     for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
-        videos.append(video)
+        pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
+        pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
+
+        image_np = VaeImageProcessor.pt_to_numpy(pt_images)
+        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
+
+        videos.append(image_pil)
 
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
```
```diff
@@ -1171,6 +1225,7 @@ def load_model_hook(models, input_dir):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1234,21 @@ def load_model_hook(models, input_dir):
         id_token=args.id_token,
     )
 
-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         latent_dist = vae.encode(video).latent_dist
         return latent_dist
 
-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Loading Encode videos",
+    )
+    train_dataset.instance_videos = [
+        encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
+    ]
+    progress_encode_bar.close()
 
     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
```

examples/controlnet/train_controlnet_sd3.py

Lines changed: 9 additions & 1 deletion

```diff
@@ -357,6 +357,11 @@ def parse_args(input_args=None):
         action="store_true",
         help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
     )
+    parser.add_argument(
+        "--upcast_vae",
+        action="store_true",
+        help="Whether or not to upcast vae to fp32",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -1094,7 +1099,10 @@ def load_model_hook(models, input_dir):
         weight_dtype = torch.bfloat16
 
     # Move vae, transformer and text_encoder to device and cast to weight_dtype
-    vae.to(accelerator.device, dtype=torch.float32)
+    if args.upcast_vae:
+        vae.to(accelerator.device, dtype=torch.float32)
+    else:
+        vae.to(accelerator.device, dtype=weight_dtype)
     transformer.to(accelerator.device, dtype=weight_dtype)
     text_encoder_one.to(accelerator.device, dtype=weight_dtype)
     text_encoder_two.to(accelerator.device, dtype=weight_dtype)
```

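A compact, self-contained illustration of the logic behind the new `--upcast_vae` flag; the helper function below is purely illustrative and not part of the script:

```python
import torch

def pick_vae_dtype(upcast_vae: bool, weight_dtype: torch.dtype) -> torch.dtype:
    # With --upcast_vae the VAE is kept in fp32 for numerical stability;
    # otherwise it follows the mixed-precision weight dtype.
    return torch.float32 if upcast_vae else weight_dtype

print(pick_vae_dtype(True, torch.bfloat16))   # torch.float32
print(pick_vae_dtype(False, torch.bfloat16))  # torch.bfloat16
```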
src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -632,7 +632,7 @@ def handle_qkv(sds_sd, ait_sd, sds_key, ait_keys, dims=None):
             new_key += ".lora_B.weight"
 
         # Handle single_blocks
-        elif old_key.startswith("diffusion_model.single_blocks", "single_blocks"):
+        elif old_key.startswith(("diffusion_model.single_blocks", "single_blocks")):
             block_num = re.search(r"single_blocks\.(\d+)", old_key).group(1)
             new_key = f"transformer.single_transformer_blocks.{block_num}"
```

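The tuple matters because `str.startswith` interprets a second positional argument as a start index, so the old two-string call could never act as an OR-match and would fail at runtime. A quick demonstration with a made-up key:

```python
key = "single_blocks.0.linear1.lora_up.weight"  # made-up key for illustration

# Tuple form: matches if the key starts with either prefix.
print(key.startswith(("diffusion_model.single_blocks", "single_blocks")))  # True

# Old form: the second argument is treated as a slice index, not another prefix.
try:
    key.startswith("diffusion_model.single_blocks", "single_blocks")
except TypeError as exc:
    print(exc)  # slice indices must be integers or None or have an __index__ method
```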
src/diffusers/loaders/lora_pipeline.py

Lines changed: 33 additions & 6 deletions

```diff
@@ -99,7 +99,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
@@ -211,6 +211,11 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         network_alphas = None
         # TODO: replace it with a method from `state_dict_utils`
@@ -562,7 +567,8 @@ def load_lora_weights(
             unet_config=self.unet.config,
             **kwargs,
         )
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
@@ -684,6 +690,11 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         network_alphas = None
         # TODO: replace it with a method from `state_dict_utils`
@@ -1089,6 +1100,12 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )
 
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
         return state_dict
 
     def load_lora_weights(
@@ -1125,7 +1142,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
@@ -1587,9 +1604,13 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
 
         # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
-
         is_kohya = any(".lora_down.weight" in k for k in state_dict)
         if is_kohya:
             state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
@@ -1659,7 +1680,7 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
         )
 
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
 
@@ -2374,6 +2395,12 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )
 
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to `dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
         return state_dict
 
     def load_lora_weights(
@@ -2405,7 +2432,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
 
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
```

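Taken together, these hunks detect `dora_scale` entries once while the state dict is loaded, warn, and drop them, which lets the stricter `is_correct_format` check accept only plain LoRA keys. A standalone sketch with a toy state dict (the values and the `print` in place of `logger.warning` are illustrative):

```python
state_dict = {
    "transformer.blocks.0.attn.to_q.lora_A.weight": 0,  # toy values stand in for tensors
    "transformer.blocks.0.attn.to_q.lora_B.weight": 0,
    "transformer.blocks.0.attn.to_q.dora_scale": 0,
}

if any("dora_scale" in k for k in state_dict):
    print("DoRA checkpoint detected; filtering out `dora_scale` keys.")
    state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

# The format check now only needs to accept plain LoRA keys.
is_correct_format = all("lora" in key for key in state_dict)
print(is_correct_format)  # True
```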
src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -1234,7 +1234,7 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOut
             return self.tiled_decode(z, return_dict=return_dict)
 
         frame_batch_size = self.num_latent_frames_batch_size
-        num_batches = num_frames // frame_batch_size
+        num_batches = max(num_frames // frame_batch_size, 1)
        conv_cache = None
         dec = []
 
```

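The `max(..., 1)` guard matters for short inputs: when the number of latent frames is smaller than the per-batch frame count, plain integer division yields zero batches and the decode loop would be skipped entirely. A tiny numeric illustration (the batch size value is made up):

```python
frame_batch_size = 2  # illustrative value for num_latent_frames_batch_size

for num_frames in (1, 2, 5):
    old = num_frames // frame_batch_size          # could be 0 for short videos
    new = max(num_frames // frame_batch_size, 1)  # always decodes at least one batch
    print(num_frames, old, new)
# 1 -> old 0, new 1; 2 -> old 1, new 1; 5 -> old 2, new 2
```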