@@ -21,20 +21,24 @@
 from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
+import torchvision.transforms as TT
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import resize
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer
 
 import diffusers
 from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
@@ -214,6 +218,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="How input videos are cropped to the target resolution. Choose between ['center', 'random', 'none'].",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to this many frames."
@@ -413,6 +423,7 @@ def __init__(
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -429,6 +440,7 @@ def __init__(
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
@@ -532,6 +544,38 @@ def _load_dataset_from_local_path(self):
 
         return instance_prompts, instance_videos
 
+    def _resize_for_rectangle_crop(self, arr):
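+        # Resize so the frame covers (self.height, self.width) while preserving aspect ratio, then crop to that size.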
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
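+        # If the video is proportionally wider than the target, match the height and scale the width; otherwise match the width.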
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
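+        # Amount by which the resized frame overshoots the target size in each dimension.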
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
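+        # NOTE: "none" falls through to the same random crop as "random".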
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
@@ -542,15 +586,14 @@ def _preprocess_data(self):
 
         decord.bridge.set_bridge("torch")
 
-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Resizing and cropping videos",
         )
+        videos = []
 
         for filename in self.instance_video_paths:
-            video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
+            video_reader = decord.VideoReader(uri=filename.as_posix())
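+            # Frames are decoded at their native resolution; resizing and cropping happen on the tensor below.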
             video_num_frames = len(video_reader)
 
             start_frame = min(self.skip_frames_start, video_num_frames)
@@ -576,10 +619,16 @@ def _preprocess_data(self):
             assert (selected_num_frames - 1) % 4 == 0
 
             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            frames = (frames - 127.5) / 127.5  # map pixel values from [0, 255] to [-1, 1]
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            progress_dataset_bar.set_description(
+                f"Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
+            )
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+            progress_dataset_bar.update(1)
 
+        progress_dataset_bar.close()
         return videos
 
 
@@ -694,8 +743,13 @@ def log_validation(
 
     videos = []
    for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
-        videos.append(video)
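+        # Convert the pipeline's tensor output to PIL images before logging to the trackers.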
+        pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
+        pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
+
+        image_np = VaeImageProcessor.pt_to_numpy(pt_images)
+        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
+
+        videos.append(image_pil)
 
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
@@ -1171,6 +1225,7 @@ def load_model_hook(models, input_dir):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1179,13 +1234,21 @@ def load_model_hook(models, input_dir):
         id_token=args.id_token,
     )
 
-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)  # advance the shared progress bar once per encoded video
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         latent_dist = vae.encode(video).latent_dist
         return latent_dist
 
-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Encoding videos",
+    )
+    train_dataset.instance_videos = [
+        encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
+    ]
+    progress_encode_bar.close()
 
     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]