 from pathlib import Path
 from typing import List, Optional, Tuple, Union

+import numpy as np
 import torch
+import torchvision.transforms as TT
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
 from peft import LoraConfig, get_peft_model_state_dict, set_peft_model_state_dict
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import resize
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, T5EncoderModel, T5Tokenizer

 import diffusers
 from diffusers import AutoencoderKLCogVideoX, CogVideoXDPMScheduler, CogVideoXPipeline, CogVideoXTransformer3DModel
+from diffusers.image_processor import VaeImageProcessor
 from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
-from diffusers.training_utils import (
-    cast_training_params,
-    clear_objs_and_retain_memory,
-)
+from diffusers.training_utils import cast_training_params, free_memory
 from diffusers.utils import check_min_version, convert_unet_state_dict_to_peft, export_to_video, is_wandb_available
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.torch_utils import is_compiled_module
@@ -217,6 +218,12 @@ def get_args():
         default=720,
         help="All input videos are resized to this width.",
     )
+    parser.add_argument(
+        "--video_reshape_mode",
+        type=str,
+        default="center",
+        help="All input videos are reshaped to this mode. Choose between ['center', 'random', 'none']",
+    )
     parser.add_argument("--fps", type=int, default=8, help="All input videos will be used at this FPS.")
     parser.add_argument(
         "--max_num_frames", type=int, default=49, help="All input videos will be truncated to these many frames."
@@ -416,6 +423,7 @@ def __init__(
         video_column: str = "video",
         height: int = 480,
         width: int = 720,
+        video_reshape_mode: str = "center",
         fps: int = 8,
         max_num_frames: int = 49,
         skip_frames_start: int = 0,
@@ -432,6 +440,7 @@ def __init__(
         self.video_column = video_column
         self.height = height
         self.width = width
+        self.video_reshape_mode = video_reshape_mode
         self.fps = fps
         self.max_num_frames = max_num_frames
         self.skip_frames_start = skip_frames_start
@@ -535,6 +544,38 @@ def _load_dataset_from_local_path(self):

         return instance_prompts, instance_videos

+    def _resize_for_rectangle_crop(self, arr):
+        image_size = self.height, self.width
+        reshape_mode = self.video_reshape_mode
+        if arr.shape[3] / arr.shape[2] > image_size[1] / image_size[0]:
+            arr = resize(
+                arr,
+                size=[image_size[0], int(arr.shape[3] * image_size[0] / arr.shape[2])],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+        else:
+            arr = resize(
+                arr,
+                size=[int(arr.shape[2] * image_size[1] / arr.shape[3]), image_size[1]],
+                interpolation=InterpolationMode.BICUBIC,
+            )
+
+        h, w = arr.shape[2], arr.shape[3]
+        arr = arr.squeeze(0)
+
+        delta_h = h - image_size[0]
+        delta_w = w - image_size[1]
+
+        if reshape_mode == "random" or reshape_mode == "none":
+            top = np.random.randint(0, delta_h + 1)
+            left = np.random.randint(0, delta_w + 1)
+        elif reshape_mode == "center":
+            top, left = delta_h // 2, delta_w // 2
+        else:
+            raise NotImplementedError
+        arr = TT.functional.crop(arr, top=top, left=left, height=image_size[0], width=image_size[1])
+        return arr
+
     def _preprocess_data(self):
         try:
             import decord
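For context, the new crop helper first performs an aspect-ratio-preserving bicubic resize so the target rectangle fits inside the frame, then crops the excess along the longer side; note that `--video_reshape_mode none` falls into the same random-offset branch as `random`. A minimal, self-contained sketch of the same idea (names such as `resize_then_crop` and `demo_frames` are illustrative, not part of the patch):

import numpy as np
import torch
from torchvision.transforms import InterpolationMode
from torchvision.transforms.functional import crop, resize

def resize_then_crop(frames, height, width, reshape_mode="center"):
    # Scale so the target rectangle fits inside the frame while preserving aspect ratio.
    if frames.shape[3] / frames.shape[2] > width / height:
        new_size = [height, int(frames.shape[3] * height / frames.shape[2])]
    else:
        new_size = [int(frames.shape[2] * width / frames.shape[3]), width]
    frames = resize(frames, size=new_size, interpolation=InterpolationMode.BICUBIC)

    # Crop the excess: centered, or at a random offset.
    delta_h, delta_w = frames.shape[2] - height, frames.shape[3] - width
    if reshape_mode == "center":
        top, left = delta_h // 2, delta_w // 2
    else:  # "random" (and "none", which the training script treats the same way)
        top, left = np.random.randint(0, delta_h + 1), np.random.randint(0, delta_w + 1)
    return crop(frames, top=top, left=left, height=height, width=width)

demo_frames = torch.rand(49, 3, 720, 1280)  # dummy [F, C, H, W] clip
print(resize_then_crop(demo_frames, height=480, width=720).shape)  # torch.Size([49, 3, 480, 720])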
@@ -545,15 +586,14 @@ def _preprocess_data(self):

         decord.bridge.set_bridge("torch")

-        videos = []
-        train_transforms = transforms.Compose(
-            [
-                transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0),
-            ]
+        progress_dataset_bar = tqdm(
+            range(0, len(self.instance_video_paths)),
+            desc="Loading progress resize and crop videos",
         )
+        videos = []

         for filename in self.instance_video_paths:
-            video_reader = decord.VideoReader(uri=filename.as_posix(), width=self.width, height=self.height)
+            video_reader = decord.VideoReader(uri=filename.as_posix())
             video_num_frames = len(video_reader)

             start_frame = min(self.skip_frames_start, video_num_frames)
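Dropping the width/height arguments from decord.VideoReader means clips are decoded at their native resolution and only resized afterwards by the crop helper, so the decoder no longer distorts the aspect ratio. A minimal read pattern for reference ("video.mp4" is a placeholder path, not from the patch):

import decord

decord.bridge.set_bridge("torch")  # get_batch returns torch tensors

video_reader = decord.VideoReader(uri="video.mp4")  # no width/height: keep native resolution
frames = video_reader.get_batch(list(range(len(video_reader))))  # [F, H, W, C], uint8
print(frames.shape)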
@@ -579,10 +619,16 @@ def _preprocess_data(self):
             assert (selected_num_frames - 1) % 4 == 0

             # Training transforms
-            frames = frames.float()
-            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
-            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
+            frames = (frames - 127.5) / 127.5
+            frames = frames.permute(0, 3, 1, 2)  # [F, C, H, W]
+            progress_dataset_bar.set_description(
+                f"Loading progress Resizing video from {frames.shape[2]}x{frames.shape[3]} to {self.height}x{self.width}"
+            )
+            frames = self._resize_for_rectangle_crop(frames)
+            videos.append(frames.contiguous())  # [F, C, H, W]
+            progress_dataset_bar.update(1)

+        progress_dataset_bar.close()
         return videos

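The inlined `(frames - 127.5) / 127.5` is numerically the same mapping as the removed `transforms.Lambda(lambda x: x / 255.0 * 2.0 - 1.0)`, taking uint8 pixel values into [-1, 1]; a quick illustrative check (not part of the patch):

import torch

x = torch.arange(0, 256, dtype=torch.float32)  # all possible uint8 pixel values
old = x / 255.0 * 2.0 - 1.0                    # removed Lambda transform
new = (x - 127.5) / 127.5                      # inlined replacement
print(torch.allclose(old, new), new.min().item(), new.max().item())  # True -1.0 1.0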
@@ -697,8 +743,13 @@ def log_validation(

     videos = []
     for _ in range(args.num_validation_videos):
-        video = pipe(**pipeline_args, generator=generator, output_type="np").frames[0]
-        videos.append(video)
+        pt_images = pipe(**pipeline_args, generator=generator, output_type="pt").frames[0]
+        pt_images = torch.stack([pt_images[i] for i in range(pt_images.shape[0])])
+
+        image_np = VaeImageProcessor.pt_to_numpy(pt_images)
+        image_pil = VaeImageProcessor.numpy_to_pil(image_np)
+
+        videos.append(image_pil)

     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
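With `output_type="pt"` the pipeline returns frames as float tensors in [0, 1], and the two VaeImageProcessor static calls then turn them into a list of PIL images for logging. Roughly, the conversion amounts to the following (an illustrative sketch, not the library implementation; `pt_frames_to_pil` and `demo` are made-up names):

import torch
from PIL import Image

def pt_frames_to_pil(frames):
    # frames: [F, C, H, W] float tensor in [0, 1]
    arr = frames.cpu().permute(0, 2, 3, 1).float().numpy()  # -> [F, H, W, C]
    arr = (arr * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in arr]

demo = torch.rand(49, 3, 480, 720)           # dummy pipeline output
pil_frames = pt_frames_to_pil(demo)
print(len(pil_frames), pil_frames[0].size)   # 49 (720, 480)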
@@ -726,7 +777,8 @@ def log_validation(
                 }
             )

-    clear_objs_and_retain_memory([pipe])
+    del pipe
+    free_memory()

     return videos

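The explicit `del pipe` plus `free_memory()` replaces the removed `clear_objs_and_retain_memory` helper; the effect is roughly the following (a sketch assuming a CUDA device; the real helper lives in diffusers.training_utils and may cover other backends as well):

import gc

import torch

def free_memory_sketch():
    gc.collect()                  # reclaim the Python objects just deleted
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached allocator blocks to the driver
        torch.cuda.ipc_collect()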
@@ -1173,6 +1225,7 @@ def load_model_hook(models, input_dir):
         video_column=args.video_column,
         height=args.height,
         width=args.width,
+        video_reshape_mode=args.video_reshape_mode,
         fps=args.fps,
         max_num_frames=args.max_num_frames,
         skip_frames_start=args.skip_frames_start,
@@ -1181,13 +1234,21 @@ def load_model_hook(models, input_dir):
         id_token=args.id_token,
     )

-    def encode_video(video):
+    def encode_video(video, bar):
+        bar.update(1)
         video = video.to(accelerator.device, dtype=vae.dtype).unsqueeze(0)
         video = video.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
         latent_dist = vae.encode(video).latent_dist
         return latent_dist

-    train_dataset.instance_videos = [encode_video(video) for video in train_dataset.instance_videos]
+    progress_encode_bar = tqdm(
+        range(0, len(train_dataset.instance_videos)),
+        desc="Loading Encode videos",
+    )
+    train_dataset.instance_videos = [
+        encode_video(video, progress_encode_bar) for video in train_dataset.instance_videos
+    ]
+    progress_encode_bar.close()

     def collate_fn(examples):
         videos = [example["instance_video"].sample() * vae.config.scaling_factor for example in examples]