Skip to content

Commit 00f7519

Browse files
glide-thea-r-r-o-w
andauthored
add VideoDatasetWithResizeAndRectangleCrop dataset resize crop (#13)
* fix issue 12: train dataset _preprocess_data crop video * add_argument video_reshape_mode input videos are reshaped to this mode * arg enable_model_cpu_offloading * Add imageio-ffmpeg and imageio to requirements.txt * apply changes to i2v lora script * make style --------- Co-authored-by: --unset <--unset> Co-authored-by: Aryan <[email protected]>
1 parent b1b72c0 commit 00f7519

File tree

7 files changed

+215
-45
lines changed

7 files changed

+215
-45
lines changed

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,7 @@ pandas
1111
torch
1212
torchvision
1313
torchao
14+
sentencepiece
15+
imageio-ffmpeg
16+
imageio
17+
numpy==2.1.1

training/args.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,12 @@ def _get_training_args(parser: argparse.ArgumentParser) -> None:
192192
default=720,
193193
help="All input videos are resized to this width.",
194194
)
195+
parser.add_argument(
196+
"--video_reshape_mode",
197+
type=str,
198+
default=None,
199+
help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
200+
)
195201
parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
196202
parser.add_argument(
197203
"--max_num_frames",

training/cogvideox_image_to_video_lora.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555

5656

5757
from args import get_args # isort:skip
58-
from dataset import BucketSampler, VideoDatasetWithResizing # isort:skip
58+
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
5959
from text_encoder import compute_prompt_embeddings # isort:skip
6060
from utils import get_gradient_norm, get_optimizer, prepare_rotary_positional_embeddings, print_memory, reset_memory # isort:skip
6161

@@ -465,20 +465,37 @@ def load_model_hook(models, input_dir):
465465
)
466466

467467
# Dataset and DataLoader
468-
train_dataset = VideoDatasetWithResizing(
469-
data_root=args.data_root,
470-
dataset_file=args.dataset_file,
471-
caption_column=args.caption_column,
472-
video_column=args.video_column,
473-
max_num_frames=args.max_num_frames,
474-
id_token=args.id_token,
475-
height_buckets=args.height_buckets,
476-
width_buckets=args.width_buckets,
477-
frame_buckets=args.frame_buckets,
478-
load_tensors=args.load_tensors,
479-
random_flip=args.random_flip,
480-
image_to_video=True,
481-
)
468+
if not args.video_reshape_mode:
469+
train_dataset = VideoDatasetWithResizing(
470+
data_root=args.data_root,
471+
dataset_file=args.dataset_file,
472+
caption_column=args.caption_column,
473+
video_column=args.video_column,
474+
max_num_frames=args.max_num_frames,
475+
id_token=args.id_token,
476+
height_buckets=args.height_buckets,
477+
width_buckets=args.width_buckets,
478+
frame_buckets=args.frame_buckets,
479+
load_tensors=args.load_tensors,
480+
random_flip=args.random_flip,
481+
image_to_video=True,
482+
)
483+
else:
484+
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
485+
video_reshape_mode=args.video_reshape_mode,
486+
data_root=args.data_root,
487+
dataset_file=args.dataset_file,
488+
caption_column=args.caption_column,
489+
video_column=args.video_column,
490+
max_num_frames=args.max_num_frames,
491+
id_token=args.id_token,
492+
height_buckets=args.height_buckets,
493+
width_buckets=args.width_buckets,
494+
frame_buckets=args.frame_buckets,
495+
load_tensors=args.load_tensors,
496+
random_flip=args.random_flip,
497+
image_to_video=True,
498+
)
482499

483500
def collate_fn(data):
484501
prompts = [x["prompt"] for x in data[0]]

training/cogvideox_text_to_video_lora.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454

5555

5656
from args import get_args # isort:skip
57-
from dataset import BucketSampler, VideoDatasetWithResizing # isort:skip
57+
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
5858
from text_encoder import compute_prompt_embeddings # isort:skip
5959
from utils import get_gradient_norm, get_optimizer, prepare_rotary_positional_embeddings, print_memory, reset_memory # isort:skip
6060

@@ -462,19 +462,35 @@ def load_model_hook(models, input_dir):
462462
)
463463

464464
# Dataset and DataLoader
465-
train_dataset = VideoDatasetWithResizing(
466-
data_root=args.data_root,
467-
dataset_file=args.dataset_file,
468-
caption_column=args.caption_column,
469-
video_column=args.video_column,
470-
max_num_frames=args.max_num_frames,
471-
id_token=args.id_token,
472-
height_buckets=args.height_buckets,
473-
width_buckets=args.width_buckets,
474-
frame_buckets=args.frame_buckets,
475-
load_tensors=args.load_tensors,
476-
random_flip=args.random_flip,
477-
)
465+
if not args.video_reshape_mode:
466+
train_dataset = VideoDatasetWithResizing(
467+
data_root=args.data_root,
468+
dataset_file=args.dataset_file,
469+
caption_column=args.caption_column,
470+
video_column=args.video_column,
471+
max_num_frames=args.max_num_frames,
472+
id_token=args.id_token,
473+
height_buckets=args.height_buckets,
474+
width_buckets=args.width_buckets,
475+
frame_buckets=args.frame_buckets,
476+
load_tensors=args.load_tensors,
477+
random_flip=args.random_flip,
478+
)
479+
else:
480+
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
481+
video_reshape_mode=args.video_reshape_mode,
482+
data_root=args.data_root,
483+
dataset_file=args.dataset_file,
484+
caption_column=args.caption_column,
485+
video_column=args.video_column,
486+
max_num_frames=args.max_num_frames,
487+
id_token=args.id_token,
488+
height_buckets=args.height_buckets,
489+
width_buckets=args.width_buckets,
490+
frame_buckets=args.frame_buckets,
491+
load_tensors=args.load_tensors,
492+
random_flip=args.random_flip,
493+
)
478494

479495
def collate_fn(data):
480496
prompts = [x["prompt"] for x in data[0]]

training/cogvideox_text_to_video_sft.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@
5353

5454

5555
from args import get_args # isort:skip
56-
from dataset import BucketSampler, VideoDatasetWithResizing # isort:skip
56+
from dataset import BucketSampler, VideoDatasetWithResizing, VideoDatasetWithResizeAndRectangleCrop # isort:skip
5757
from text_encoder import compute_prompt_embeddings # isort:skip
5858
from utils import get_gradient_norm, get_optimizer, prepare_rotary_positional_embeddings, print_memory, reset_memory # isort:skip
5959

@@ -426,19 +426,35 @@ def load_model_hook(models, input_dir):
426426
)
427427

428428
# Dataset and DataLoader
429-
train_dataset = VideoDatasetWithResizing(
430-
data_root=args.data_root,
431-
dataset_file=args.dataset_file,
432-
caption_column=args.caption_column,
433-
video_column=args.video_column,
434-
max_num_frames=args.max_num_frames,
435-
id_token=args.id_token,
436-
height_buckets=args.height_buckets,
437-
width_buckets=args.width_buckets,
438-
frame_buckets=args.frame_buckets,
439-
load_tensors=args.load_tensors,
440-
random_flip=args.random_flip,
441-
)
429+
if not args.video_reshape_mode:
430+
train_dataset = VideoDatasetWithResizing(
431+
data_root=args.data_root,
432+
dataset_file=args.dataset_file,
433+
caption_column=args.caption_column,
434+
video_column=args.video_column,
435+
max_num_frames=args.max_num_frames,
436+
id_token=args.id_token,
437+
height_buckets=args.height_buckets,
438+
width_buckets=args.width_buckets,
439+
frame_buckets=args.frame_buckets,
440+
load_tensors=args.load_tensors,
441+
random_flip=args.random_flip,
442+
)
443+
else:
444+
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
445+
video_reshape_mode=args.video_reshape_mode,
446+
data_root=args.data_root,
447+
dataset_file=args.dataset_file,
448+
caption_column=args.caption_column,
449+
video_column=args.video_column,
450+
max_num_frames=args.max_num_frames,
451+
id_token=args.id_token,
452+
height_buckets=args.height_buckets,
453+
width_buckets=args.width_buckets,
454+
frame_buckets=args.frame_buckets,
455+
load_tensors=args.load_tensors,
456+
random_flip=args.random_flip,
457+
)
442458

443459
def collate_fn(data):
444460
prompts = [x["prompt"] for x in data[0]]

training/dataset.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,14 @@
22
from pathlib import Path
33
from typing import Any, Dict, List, Optional, Tuple
44

5+
import numpy as np
56
import pandas as pd
67
import torch
8+
import torchvision.transforms as TT
79
from accelerate.logging import get_logger
810
from torch.utils.data import Dataset, Sampler
911
from torchvision import transforms
12+
from torchvision.transforms import InterpolationMode
1013
from torchvision.transforms.functional import resize
1114

1215

@@ -281,6 +284,71 @@ def _find_nearest_resolution(self, height, width):
281284
return nearest_res[1], nearest_res[2]
282285

283286

287+
class VideoDatasetWithResizeAndRectangleCrop(VideoDataset):
288+
def __init__(self, video_reshape_mode: str = "center", *args, **kwargs) -> None:
289+
super().__init__(*args, **kwargs)
290+
self.video_reshape_mode = video_reshape_mode
291+
292+
def _resize_for_rectangle_crop(self, arr, image_size):
293+
reshape_mode = self.video_reshape_mode
294+
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
295+
arr = resize(
296+
arr,
297+
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
298+
interpolation=InterpolationMode.BICUBIC,
299+
)
300+
else:
301+
arr = resize(
302+
arr,
303+
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
304+
interpolation=InterpolationMode.BICUBIC,
305+
)
306+
307+
h, w = arr.shape[2], arr.shape[3]
308+
arr = arr.squeeze(0)
309+
310+
delta_h = h - image_size[0]
311+
delta_w = w - image_size[1]
312+
313+
if reshape_mode == "random" or reshape_mode == "none":
314+
top = np.random.randint(0, delta_h + 1)
315+
left = np.random.randint(0, delta_w + 1)
316+
elif reshape_mode == "center":
317+
top, left = delta_h // 2, delta_w // 2
318+
else:
319+
raise NotImplementedError
320+
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
321+
return arr
322+
323+
def _preprocess_video(self, path: Path) -> torch.Tensor:
324+
if self.load_tensors:
325+
return self._load_preprocessed_latents_and_embeds(path)
326+
else:
327+
video_reader = decord.VideoReader(uri=path.as_posix())
328+
video_num_frames = len(video_reader)
329+
nearest_frame_bucket = min(
330+
self.frame_buckets, key=lambda x: abs(x - min(video_num_frames, self.max_num_frames))
331+
)
332+
333+
frame_indices = list(range(0, video_num_frames, video_num_frames // nearest_frame_bucket))
334+
335+
frames = video_reader.get_batch(frame_indices)
336+
frames = frames[:nearest_frame_bucket].float()
337+
frames = frames.permute(0, 3, 1, 2).contiguous()
338+
339+
nearest_res = self._find_nearest_resolution(frames.shape[2], frames.shape[3])
340+
frames_resized = self._resize_for_rectangle_crop(frames, nearest_res)
341+
frames = torch.stack([self.video_transforms(frame) for frame in frames_resized], dim=0)
342+
343+
image = frames[:1].clone() if self.image_to_video else None
344+
345+
return image, frames, None
346+
347+
def _find_nearest_resolution(self, height, width):
348+
nearest_res = min(self.resolutions, key=lambda x: abs(x[1] - height) + abs(x[2] - width))
349+
return nearest_res[1], nearest_res[2]
350+
351+
284352
class BucketSampler(Sampler):
285353
def __init__(self, data_source: VideoDataset, batch_size: int = 8, shuffle: bool = True) -> None:
286354
self.data_source = data_source

training/prepare_dataset.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,16 @@
77
import traceback
88
from typing import Any, Dict, List, Optional, Tuple, Union
99

10+
import numpy as np
1011
import pandas as pd
1112
import torch
1213
import torch.distributed as dist
14+
import torchvision.transforms as TT
1315
from diffusers import AutoencoderKLCogVideoX
1416
from diffusers.utils import export_to_video, get_logger
1517
from torchvision import transforms
18+
from torchvision.transforms import InterpolationMode
19+
from torchvision.transforms.functional import resize
1620
from tqdm import tqdm
1721
from transformers import T5EncoderModel, T5Tokenizer
1822

@@ -153,13 +157,51 @@ def load_dataset_from_csv(
153157
return prompts, video_paths
154158

155159

160+
def resize_for_rectangle_crop(arr, height, width, reshape_mode):
161+
image_size = height, width
162+
if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
163+
arr = resize(
164+
arr,
165+
size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
166+
interpolation=InterpolationMode.BICUBIC,
167+
)
168+
else:
169+
arr = resize(
170+
arr,
171+
size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
172+
interpolation=InterpolationMode.BICUBIC,
173+
)
174+
175+
h, w = arr.shape[2], arr.shape[3]
176+
arr = arr.squeeze(0)
177+
178+
delta_h = h - image_size[0]
179+
delta_w = w - image_size[1]
180+
181+
if reshape_mode == "random" or reshape_mode == "none":
182+
top = np.random.randint(0, delta_h + 1)
183+
left = np.random.randint(0, delta_w + 1)
184+
elif reshape_mode == "center":
185+
top, left = delta_h // 2, delta_w // 2
186+
else:
187+
raise NotImplementedError
188+
arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
189+
return arr
190+
191+
156192
def load_and_preprocess_video(
157-
path: pathlib.Path, height: int, width: int, max_num_frames: int, video_transforms, num_threads: int = 0
193+
path: pathlib.Path,
194+
height: int,
195+
width: int,
196+
max_num_frames: int,
197+
video_transforms,
198+
num_threads: int = 0,
199+
video_reshape_mode: str = "center",
158200
) -> Optional[torch.Tensor]:
159201
frames = None
160202

161203
try:
162-
video_reader = decord.VideoReader(uri=path.as_posix(), height=height, width=width, num_threads=num_threads)
204+
video_reader = decord.VideoReader(uri=path.as_posix(), num_threads=num_threads)
163205
video_num_frames = len(video_reader)
164206

165207
if video_num_frames < max_num_frames:
@@ -172,6 +214,7 @@ def load_and_preprocess_video(
172214
frames: torch.Tensor = video_reader.get_batch(indices)
173215
frames = frames[:max_num_frames].float()
174216
frames = frames.permute(0, 3, 1, 2).contiguous()
217+
frames = resize_for_rectangle_crop(frames, height, width, video_reshape_mode)
175218
frames = torch.stack([video_transforms(frame) for frame in frames], dim=0)
176219
except Exception as e:
177220
logger.error(f"Error: {e}. Skipping video located at `{path.as_posix()}`")

0 commit comments

Comments
 (0)