fix: CogVideoX train dataset `_preprocess_data` crop video #9574
```diff
@@ -42,6 +42,10 @@
 from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
+from torchvision.transforms.functional import center_crop, resize
+from torchvision.transforms import InterpolationMode
+import torchvision.transforms as TT
+import numpy as np

 if is_wandb_available():
```
|
|
```diff
@@ -214,6 +218,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
```
|
|
```diff
@@ -413,6 +423,7 @@ def __init__(
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
```
|
|
```diff
@@ -429,6 +440,7 @@ def __init__(
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
```
|
|
```diff
@@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):

         return instance_prompts, instance_videos

+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
```
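The aspect-ratio comparison `arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]` asks whether the clip is relatively wider than the target, so the resize always matches one target dimension exactly and overshoots the other, which the crop then removes. A minimal sketch of the same flow on a dummy clip, assuming only that `torch` and `torchvision` are installed (the 540x1280 input and 480x720 target are illustrative):

```python
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import resize
import torchvision.transforms as TT

frames = torch.rand(49, 3, 540, 1280)  # [F, C, H, W], relatively wider than 480x720
target_h, target_w = 480, 720

if frames.shape[3] / frames.shape[2] > target_w / target_h:
    # Clip is relatively wider: match the target height, let the width overshoot.
    frames = resize(
        frames,
        size=[target_h, int(frames.shape[3] * target_h / frames.shape[2])],
        interpolation=InterpolationMode.BICUBIC,
    )
else:
    # Clip is relatively taller: match the target width, let the height overshoot.
    frames = resize(
        frames,
        size=[int(frames.shape[2] * target_w / frames.shape[3]), target_w],
        interpolation=InterpolationMode.BICUBIC,
    )

# 540x1280 -> 480x1137; a center crop then removes the 417 extra columns.
top, left = (frames.shape[2] - target_h) // 2, (frames.shape[3] - target_w) // 2
frames = TT.functional.crop(frames, top=top, left=left, height=target_h, width=target_w)
assert frames.shape == (49, 3, 480, 720)
```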
|
|
```diff
@@ -542,14 +586,14 @@ def _preprocess_data(self):
         decord.bridge.set_bridge("torch")

-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Loading progress resize and crop videos",
         )
+        videos = []

         for filename in self.instance_video_paths:
+            progress_dataset_bar.update(1)
             video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
             video_num_frames = len(video_reader)
```
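The bar here is sized from the path list and stepped manually at the top of the loop. For comparison only (not what the patch does), the equivalent idiomatic pattern wraps the iterable directly:

```python
from tqdm import tqdm

instance_video_paths = ["a.mp4", "b.mp4"]  # placeholder paths
for filename in tqdm(instance_video_paths, desc="Resizing and cropping videos"):
    pass  # decode, resize, and crop each clip here
```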
|
|
```diff
@@ -576,10 +620,12 @@ def _preprocess_data(self):
             assert (selected_num_frames - 1) % 4 == 0

             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            tensor = frames.float() / 255.0
+            frames = tensor.permute(0, 3, 1, 2)
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]

+        progress_dataset_bar.close()
         return videos
```
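Note that the pixel value range also changes in this hunk: the removed `transforms.Lambda` mapped frames to [-1, 1], while the new path only divides by 255.0, leaving them in [0, 1]. A two-line check of the difference:

```python
x = 0.0  # darkest pixel value
old = x / 255.0 * 2.0 - 1.0  # -1.0 under the removed Lambda transform
new = x / 255.0              # 0.0 under the new scaling
```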
|
|
```diff
@@ -1171,6 +1217,7 @@ def load_model_hook(models, input_dir):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
```
|
|
```diff
@@ -1179,13 +1226,20 @@ def load_model_hook(models, input_dir):
         id_token=args.id_token,
     )

-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         latent_dist = vae.encode(video).latent_dist
         return latent_dist

-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Loading Encode videos",
+    )
+    train_dataset.instance_videos = [encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos]
+    progress_encode_bar.close()

     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]
```
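For reference, the tensor shapes along this encode path: each preprocessed video arrives as [F, C, H, W], `unsqueeze(0)` adds a batch dimension, and the permute reorders it to the [B, C, F, H, W] layout the video VAE expects. A shape-only sketch, assuming the 49-frame 480x720 clip from earlier and involving no model weights:

```python
import torch

video = torch.rand(49, 3, 480, 720)   # [F, C, H, W] from _preprocess_data
batch = video.unsqueeze(0)            # [1, F, C, H, W]
batch = batch.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
assert batch.shape == (1, 3, 49, 480, 720)
# collate_fn later samples from the returned latent distribution and applies
# vae.config.scaling_factor before the latents reach the transformer.
```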
|
|