
Commit 75beab5

Removed the uint8 → float32 normalization to [-1, 1] (`x / 255.0 * 2.0 - 1.0`) from `train_transforms`, as it caused image overexposure; frames are now scaled to [0, 1] instead.
Added a `_resize_for_rectangle_crop` function to enable video cropping. The cropping mode can be configured via `--video_reshape_mode`, supporting options: ['center', 'random', 'none'].
1 parent 7f323f0 commit 75beab5
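At a glance, the value-range change this commit makes to the training frames (a minimal sketch; the clip shape is hypothetical, matching the script's 49-frame, 480x720 defaults):

```python
import torch

frames = torch.randint(0, 256, (49, 480, 720, 3), dtype=torch.uint8)  # [F, H, W, C], as decord returns

# Before this commit: uint8 frames were mapped to [-1, 1]
old = frames.float() / 255.0 * 2.0 - 1.0  # values in [-1.0, 1.0]

# After this commit: frames are only scaled to [0, 1]
new = frames.float() / 255.0              # values in [0.0, 1.0]
```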

File tree

examples/cogvideo/README.md
examples/cogvideo/train_cogvideox_lora.py
examples/cogvideo/video_fix_rgb_float_and_crop.ipynb

3 files changed: +746 −10 lines changed

examples/cogvideo/README.md

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen
 
 > [!TIP]
 > You can pass `--use_8bit_adam` to reduce the memory requirements of training.
+> You can pass `--video_reshape_mode` to select the video cropping mode, supporting options: ['center', 'random', 'none'].
 
 > [!IMPORTANT]
 > The following settings have been tested at the time of adding CogVideoX LoRA training support:

examples/cogvideo/train_cogvideox_lora.py

Lines changed: 64 additions & 10 deletions
@@ -42,6 +42,10 @@
 from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
+from torchvision.transforms.functional import center_crop, resize
+from torchvision.transforms import InterpolationMode
+import torchvision.transforms as TT
+import numpy as np
 
 
 if is_wandb_available():
@@ -214,6 +218,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -413,6 +423,7 @@ def __init__(
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ def __init__(
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):
 
         return instance_prompts, instance_videos
 
+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
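A standalone sketch of what `_resize_for_rectangle_crop` does for a clip whose aspect ratio is wider than the target (hypothetical 360x640 input against the script's 480x720 defaults); note that 'none' falls through to the same random-crop branch as 'random':

```python
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize

frames = torch.rand(49, 3, 360, 640)  # [F, C, H, W], hypothetical clip in [0, 1]
target_h, target_w = 480, 720

# 640/360 > 720/480, so the height is matched first and the excess width cropped.
if frames.shape[3] / frames.shape[2] > target_w / target_h:
    new_w = int(frames.shape[3] * target_h / frames.shape[2])  # 853
    frames = resize(frames, [target_h, new_w], interpolation=InterpolationMode.BICUBIC)
else:
    new_h = int(frames.shape[2] * target_w / frames.shape[3])
    frames = resize(frames, [new_h, target_w], interpolation=InterpolationMode.BICUBIC)

top = (frames.shape[2] - target_h) // 2    # 'center' mode: symmetric crop
left = (frames.shape[3] - target_w) // 2   # (853 - 720) // 2 = 66
cropped = frames[..., top:top + target_h, left:left + target_w]
assert cropped.shape == (49, 3, 480, 720)
```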
@@ -542,14 +586,14 @@ def _preprocess_data(self):
 
         decord.bridge.set_bridge("torch")
 
-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Resizing and cropping videos",
         )
+        videos = []
 
         for filename in self.instance_video_paths:
+            progress_dataset_bar.update(1)
             video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
             video_num_frames = len(video_reader)
 
@@ -576,10 +620,12 @@
             assert (selected_num_frames - 1) % 4 == 0
 
             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            tensor = frames.float() / 255.0
+            frames = tensor.permute(0, 3, 1, 2)
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
 
+        progress_dataset_bar.close()
         return videos
 
 
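The permute now happens before the transform rather than after it, since `_resize_for_rectangle_crop` reads H and W from dims 2 and 3; a shape-only sketch of the new per-clip flow (hypothetical sizes):

```python
import torch

frames = torch.randint(0, 256, (49, 360, 640, 3), dtype=torch.uint8)  # decord order: [F, H, W, C]
tensor = frames.float() / 255.0       # [0, 1] floats, per this commit
frames = tensor.permute(0, 3, 1, 2)   # [F, C, H, W]; dims 2/3 are what the crop helper expects
assert frames.shape == (49, 3, 360, 640)
```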

@@ -1171,6 +1217,7 @@ def load_model_hook(models, input_dir):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1226,20 @@ def load_model_hook(models, input_dir):
         id_token=args.id_token,
     )
 
-    def encode_video(video):
+
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         latent_dist = vae.encode(video).latent_dist
         return latent_dist
 
-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Encoding videos",
+    )
+    train_dataset.instance_videos = [encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos]
+    progress_encode_bar.close()
 
     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
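Worth noting about the surrounding code: the dataset stores each video's VAE posterior (`latent_dist`) rather than a fixed latent, so the `collate_fn` line above draws a fresh sample every batch. A small sketch of that behavior (the import path is an assumption about the current diffusers layout; the tensor shape is hypothetical):

```python
import torch
# Assumed import path; DiagonalGaussianDistribution has moved between diffusers versions.
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

# [B, 2*C, F, H, W]: mean and logvar stacked along the channel dim, as VAE encoders emit.
params = torch.randn(1, 2 * 16, 13, 60, 90)
latent_dist = DiagonalGaussianDistribution(params)

a = latent_dist.sample()  # what collate_fn calls for each batch
b = latent_dist.sample()
assert a.shape == (1, 16, 13, 60, 90)
assert not torch.equal(a, b)  # stochastic: a new latent per draw
```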

examples/cogvideo/video_fix_rgb_float_and_crop.ipynb

Lines changed: 681 additions & 0 deletions
Large diffs are not rendered by default.
