1+ import argparse
2+ import json
3+ import os
4+
5+ import torch
6+ import torch .distributed as dist
7+
8+ from fastvideo .v1 .logger import init_logger
9+ from fastvideo .v1 .utils import maybe_download_model , shallow_asdict
10+ from fastvideo .v1 .distributed import init_distributed_environment , initialize_model_parallel
11+ from fastvideo .v1 .fastvideo_args import FastVideoArgs
12+ from fastvideo .v1 .configs .models .vaes import WanVAEConfig
13+ from fastvideo import PipelineConfig
14+ from fastvideo .v1 .pipelines .preprocess_pipeline import PreprocessPipeline
15+
# Module-scoped logger obtained from FastVideo's logging helper.
logger = init_logger(__name__)

# Location of the Wan2.1 T2V 1.3B (Diffusers layout) checkpoint.
BASE_MODEL_PATH = "/workspace/data/Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

# NOTE: this runs at import time. BASE_MODEL_PATH is an absolute path, so
# os.path.join() discards the leading 'data' component and ``local_dir``
# ends up equal to BASE_MODEL_PATH itself.
# NOTE(review): presumably maybe_download_model returns a usable local path
# (downloading only when needed) -- confirm against fastvideo.v1.utils.
MODEL_PATH = maybe_download_model(
    BASE_MODEL_PATH,
    local_dir=os.path.join('data', BASE_MODEL_PATH),
)
22+
def main(args):
    """Set up the distributed environment and run the preprocessing pipeline.

    The process is expected to be launched by ``torchrun`` (or an equivalent
    launcher) that exports ``RANK``, ``LOCAL_RANK`` and ``WORLD_SIZE``. When
    those variables are missing the function falls back to a single-process
    setup (rank 0, world size 1).

    Args:
        args: Parsed command-line namespace. It is not inspected here; it is
            forwarded as-is to ``PreprocessPipeline.forward``. Note that the
            model location comes from the module-level ``MODEL_PATH``, not
            from ``args.model_path``.
    """
    # Global rank of this process across all nodes.
    rank = int(os.environ.get("RANK", 0))
    # Per-node rank selects the CUDA device. The previous code read "RANK"
    # here, which picks a non-existent device on multi-node jobs. Fall back
    # to the global rank so launches without LOCAL_RANK behave as before.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    init_distributed_environment(world_size=world_size,
                                 rank=rank,
                                 local_rank=local_rank)
    initialize_model_parallel(tensor_model_parallel_size=world_size,
                              sequence_model_parallel_size=world_size)
    torch.cuda.set_device(local_rank)

    # Safety net in case the helpers above did not create the default group.
    # The process group must be keyed by the *global* rank.
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl",
                                init_method="env://",
                                world_size=world_size,
                                rank=rank)

    # Start from the checkpoint's pipeline configuration, then override the
    # settings needed for preprocessing (VAE encoder only, fp32, no offload).
    pipeline_config = PipelineConfig.from_pretrained(MODEL_PATH)
    overrides = {
        "use_cpu_offload": False,
        "vae_precision": "fp32",
        "vae_config": WanVAEConfig(load_encoder=True, load_decoder=False),
    }
    pipeline_config_args = shallow_asdict(pipeline_config)
    pipeline_config_args.update(overrides)

    fastvideo_args = FastVideoArgs(
        model_path=MODEL_PATH,
        num_gpus=world_size,
        device_str="cuda",
        **pipeline_config_args,
    )
    fastvideo_args.check_fastvideo_args()
    # Pin this process to its own GPU.
    fastvideo_args.device = torch.device(f"cuda:{local_rank}")

    pipeline = PreprocessPipeline(MODEL_PATH, fastvideo_args)
    pipeline.forward(batch=None, fastvideo_args=fastvideo_args, args=args)
52+
53+
def build_arg_parser():
    """Build the command-line parser for the preprocessing script.

    Returns:
        argparse.ArgumentParser: Parser exposing the dataset/dataloader,
        model and output options. ``--data_merge_path`` is the only required
        argument.
    """
    parser = argparse.ArgumentParser()

    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    parser.add_argument("--data_merge_path", type=str, required=True)
    parser.add_argument("--validation_prompt_txt", type=str)
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--preprocess_video_batch_size",
        type=int,
        default=2,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--preprocess_text_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--samples_per_file",
        type=int,
        default=64,
    )
    parser.add_argument(
        "--flush_frequency",
        type=int,
        default=256,
        help="how often to save to parquet files",
    )
    parser.add_argument("--num_latent_t", type=int, default=28, help="Number of latent timesteps.")
    parser.add_argument("--max_height", type=int, default=480)
    parser.add_argument("--max_width", type=int, default=848)
    # The default is a float (2.0); parse user input as float too, so that
    # fractional tolerances such as 2.5 are accepted instead of rejected.
    parser.add_argument("--video_length_tolerance_range", type=float, default=2.0)
    parser.add_argument("--group_frame", action="store_true")  # TODO
    parser.add_argument("--group_resolution", action="store_true")  # TODO
    parser.add_argument("--dataset", default="t2v")
    parser.add_argument("--train_fps", type=int, default=30)
    parser.add_argument("--use_image_num", type=int, default=0)
    parser.add_argument("--text_max_length", type=int, default=256)
    parser.add_argument("--speed_factor", type=float, default=1.0)
    parser.add_argument("--drop_short_ratio", type=float, default=1.0)

    # text encoder & vae & diffusion model
    parser.add_argument("--text_encoder_name", type=str, default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument("--cfg", type=float, default=0.0)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    return parser


if __name__ == "__main__":
    # Script entry point: parse the CLI and hand the namespace to main().
    args = build_arg_parser().parse_args()
    main(args)
0 commit comments