1 change: 1 addition & 0 deletions examples/cogvideo/README.md
@@ -180,6 +180,7 @@ Note that setting the `<ID_TOKEN>` is not necessary. From some limited experimen

> [!TIP]
> You can pass `--use_8bit_adam` to reduce the memory requirements of training.
> You can pass `--video_reshape_mode` to choose how input videos are cropped, with supported options: ['center', 'random', 'none'].
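
For illustration, a minimal sketch of how the new flag might be combined with the memory tip above (editor's note: apart from `--use_8bit_adam`, `--video_reshape_mode`, `--height`, and `--width`, the arguments shown are placeholders, and the full set of required training arguments is omitted):

```bash
accelerate launch train_cogvideox_lora.py \
  --pretrained_model_name_or_path THUDM/CogVideoX-2b \
  --use_8bit_adam \
  --video_reshape_mode center \
  --height 480 \
  --width 720
```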

> [!IMPORTANT]
> The following settings have been tested at the time of adding CogVideoX LoRA training support:
74 changes: 64 additions & 10 deletions examples/cogvideo/train_cogvideox_lora.py
@@ -42,6 +42,10 @@
from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.torch_utils import is_compiled_module
from torchvision.transforms.functional import center_crop, resize
from torchvision.transforms import InterpolationMode
import torchvision.transforms as TT
import numpy as np


if is_wandb_available():
@@ -214,6 +218,12 @@ def get_args():
default=720,
help="All input videos are resized to this width.",
)
parser.add_argument(
"--video_reshape_mode",
type=str,
default="center",
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
)
parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
parser.add_argument(
"--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -413,6 +423,7 @@ def __init__(
video_column: str = "video",
height: int = 480,
width: int = 720,
video_reshape_mode: str = "center",
fps: int = 8,
max_num_frames: int = 49,
skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ def __init__(
self.video_column = video_column
self.height = height
self.width = width
self.video_reshape_mode = video_reshape_mode
self.fps = fps
self.max_num_frames = max_num_frames
self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):

return instance_prompts, instance_videos

def _resize_for_rectangle_crop(self, arr):
image_size = self.height, self.width
reshape_mode = self.video_reshape_mode
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
arr = resize(
arr,
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
interpolation=InterpolationMode.BICUBIC,
)
else:
arr = resize(
arr,
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
interpolation=InterpolationMode.BICUBIC,
)

h, w = arr.shape[2], arr.shape[3]
arr = arr.squeeze(0)

delta_h = h - image_size[0]
delta_w = w - image_size[1]

if reshape_mode == "random" or reshape_mode == "none":
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
elif reshape_mode == "center":
top, left = delta_h // 2, delta_w // 2
else:
raise NotImplementedError
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
return arr
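
As a quick orientation (an editor's sketch with hypothetical sizes, not part of the PR), the helper above first resizes the clip so that one side matches the target while the aspect ratio is preserved, then crops the excess along the other side:

```python
# Hypothetical numbers: a 540x960 source frame and a 480x720 target,
# mirroring the branch and crop arithmetic in _resize_for_rectangle_crop.
h, w = 540, 960                # source frame height/width
target_h, target_w = 480, 720  # self.height, self.width

if w / h > target_w / target_h:                      # source is wider than the target aspect
    new_h, new_w = target_h, int(w * target_h / h)   # resize to 480 x 853
else:
    new_h, new_w = int(h * target_w / w), target_w

top, left = (new_h - target_h) // 2, (new_w - target_w) // 2  # "center" mode offsets
print((new_h, new_w), (top, left))                   # (480, 853) (0, 66)
```

This avoids the aspect-ratio distortion discussed in the review thread further down.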

def _preprocess_data(self):
try:
import decord
@@ -542,14 +586,14 @@ def _preprocess_data(self):

decord.bridge.set_bridge("torch")

videos = []
train_transforms = transforms.Compose(
    [
        transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
    ]
)
progress_dataset_bar = tqdm(
    range(0, len(self.instance_video_paths)),
    desc="Loading progress resize and crop videos",
)
videos = []

for filename in self.instance_video_paths:
progress_dataset_bar.update(1)
video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
video_num_frames = len(video_reader)

@@ -576,10 +620,12 @@ def _preprocess_data(self):
assert (selected_num_frames - 1) % 4 == 0

# Training transforms
frames = frames.float()
frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
videos.append(frames.permute(0, 3, 1, 2).contiguous()) # [F, C, H, W]
tensor = frames.float() / 255.0
@a-r-r-o-w (Contributor) commented on Oct 3, 2024:

I am not sure why we are mapping the tensors to the [0, 1] range instead of [-1, 1]. In the original codebase we also convert to [-1, 1] here, if I understand correctly, yes?

@glide-the (Contributor, Author) replied:

You're right, it should be in the [-1, 1] range. In fact, this is for matrix calculations during fine-tuning, and the [-1, 1] range is easier to compute with. I forgot that this step is handled in the latent2img process, so the image is in the [0, 1] range while the latent space is in the [-1, 1] range.

I've already verified that the cause of the blank-screen training results was that I fed 960x720 videos into the dataset and they were compressed directly to 460x720 for training.

@glide-the (Contributor, Author) replied on Oct 3, 2024:

https://github.com/THUDM/CogVideo/blob/111756a6a68a8df375ef9c31f9f325818699dfaa/sat/data_video.py#L437
Dividing by the constant 127.5 may introduce precision loss.

encode: images / 255.0 * 2.0 - 1.0
decode: (images / 2 + 0.5).clamp(0, 1)

encode: (frames - 127.5) / 127.5
decode: (images / 2 + 0.5).clamp(0, 1)

frames = tensor.permute(0, 3, 1, 2)
frames = self._resize_for_rectangle_crop(frames)
videos.append(frames.contiguous()) # [F, C, H, W]

progress_dataset_bar.close()
return videos
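
Related to the review thread above, a small check (an editor's sketch, not part of the PR) that the two encode expressions quoted there map 8-bit pixel values onto [-1, 1] identically up to float32 rounding:

```python
import torch

x = torch.arange(0, 256, dtype=torch.float32)  # every possible 8-bit pixel value
a = x / 255.0 * 2.0 - 1.0                      # form used by the removed train_transforms lambda
b = (x - 127.5) / 127.5                        # the 127.5 form quoted in the thread
print(torch.max(torch.abs(a - b)).item())      # worst-case gap is at float32 rounding level
```

The quoted decode expression `(images / 2 + 0.5).clamp(0, 1)` simply maps the [-1, 1] range back to [0, 1].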


@@ -1171,6 +1217,7 @@ def load_model_hook(models, input_dir):
video_column=args.video_column,
height=args.height,
width=args.width,
video_reshape_mode=args.video_reshape_mode,
fps=args.fps,
max_num_frames=args.max_num_frames,
skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1226,20 @@ def load_model_hook(models, input_dir):
id_token=args.id_token,
)

def encode_video(video):

def encode_video(video, bar):
bar.update(1)
video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
video = video.permute(0, 2, 1, 3, 4) # [B, C, F, H, W]
latent_dist = vae.encode(video).latent_dist
return latent_dist

train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
progress_encode_bar = tqdm(
range(0, len(train_dataset.instance_videos)),
desc="Loading Encode videos",
)
train_dataset.instance_videos = [encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos]
progress_encode_bar.close()

def collate_fn(examples):
videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
681 changes: 681 additions & 0 deletions examples/cogvideo/video_fix_rgb_float_and_crop.ipynb

Large diffs are not rendered by default.
