1+ import  argparse 
2+ import  json 
3+ import  os 
4+ 
5+ import  torch 
6+ import  torch .distributed  as  dist 
7+ 
8+ from  fastvideo .v1 .logger  import  init_logger 
9+ from  fastvideo .v1 .utils  import  maybe_download_model , shallow_asdict 
10+ from  fastvideo .v1 .distributed  import  init_distributed_environment , initialize_model_parallel 
11+ from  fastvideo .v1 .fastvideo_args  import  FastVideoArgs 
12+ from  fastvideo .v1 .configs .models .vaes  import  WanVAEConfig 
13+ from  fastvideo  import  PipelineConfig 
14+ from  fastvideo .v1 .pipelines .preprocess_pipeline  import  PreprocessPipeline 
15+ 
# Module-level logger for this preprocessing entry point.
logger = init_logger(__name__)

# Checkpoint location handed to maybe_download_model() and the pipeline.
BASE_MODEL_PATH = "/workspace/data/Wan-AI/Wan2.1-T2V-1.3B-Diffusers"

# NOTE: BASE_MODEL_PATH is an absolute path, so os.path.join() discards the
# leading 'data' segment and the resulting local_dir equals BASE_MODEL_PATH.
# This call runs at import time.
MODEL_PATH = maybe_download_model(
    BASE_MODEL_PATH,
    local_dir=os.path.join('data', BASE_MODEL_PATH),
)
22+ 
def main(args):
    """Set up the distributed state and run the preprocessing pipeline.

    The process layout is read from the environment variables exported by
    ``torchrun`` (``RANK``, ``LOCAL_RANK``, ``WORLD_SIZE``). The pipeline
    config stored at ``MODEL_PATH`` is loaded, overridden with the
    preprocessing-specific settings below, and used to build the
    ``FastVideoArgs`` passed to ``PreprocessPipeline``.

    Args:
        args: Parsed command-line namespace; forwarded unchanged to
            ``PreprocessPipeline.forward``.
    """
    # Assume using torchrun.
    rank = int(os.environ.get("RANK", 0))
    # LOCAL_RANK is the per-node process index and therefore the CUDA device
    # to bind to; RANK is the global index and only coincides with it on a
    # single node. Fall back to RANK when LOCAL_RANK is not exported so that
    # launches which only set RANK keep their previous behaviour.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    init_distributed_environment(world_size=world_size,
                                 rank=rank,
                                 local_rank=local_rank)
    initialize_model_parallel(tensor_model_parallel_size=world_size,
                              sequence_model_parallel_size=world_size)
    torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        # The process group is keyed by the global rank, not the local one.
        dist.init_process_group(backend="nccl",
                                init_method="env://",
                                world_size=world_size,
                                rank=rank)

    pipeline_config = PipelineConfig.from_pretrained(MODEL_PATH)
    # Settings that override the values loaded from the pretrained config.
    overrides = {
        "use_cpu_offload": False,
        "vae_precision": "fp32",
        "vae_config": WanVAEConfig(load_encoder=True, load_decoder=False),
    }
    pipeline_config_args = shallow_asdict(pipeline_config)
    pipeline_config_args.update(overrides)

    fastvideo_args = FastVideoArgs(
        model_path=MODEL_PATH,
        num_gpus=world_size,
        device_str="cuda",
        **pipeline_config_args,
    )
    fastvideo_args.check_fastvideo_args()
    # The device string must not contain stray whitespace.
    fastvideo_args.device = torch.device(f"cuda:{local_rank}")

    pipeline = PreprocessPipeline(MODEL_PATH, fastvideo_args)
    pipeline.forward(batch=None, fastvideo_args=fastvideo_args, args=args)
52+ 
53+ 
def build_arg_parser():
    """Build the command-line parser for the preprocessing script.

    Returns:
        argparse.ArgumentParser: Parser exposing every option of this script.
        ``--data_merge_path`` is the only required argument.
    """
    parser = argparse.ArgumentParser()

    # dataset & dataloader
    parser.add_argument("--model_path", type=str, default="data/mochi")
    parser.add_argument("--model_type", type=str, default="mochi")
    parser.add_argument("--data_merge_path", type=str, required=True)
    parser.add_argument("--validation_prompt_txt", type=str)
    parser.add_argument("--num_frames", type=int, default=163)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=1,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--preprocess_video_batch_size",
        type=int,
        default=2,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--preprocess_text_batch_size",
        type=int,
        default=8,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--samples_per_file", type=int, default=64)
    parser.add_argument(
        "--flush_frequency",
        type=int,
        default=256,
        help="how often to save to parquet files",
    )
    parser.add_argument("--num_latent_t",
                        type=int,
                        default=28,
                        help="Number of latent timesteps.")
    parser.add_argument("--max_height", type=int, default=480)
    parser.add_argument("--max_width", type=int, default=848)
    # The default is a float, so the option is parsed as float as well;
    # with type=int a value such as "1.5" was rejected and user-supplied
    # values had a different type than the default.
    parser.add_argument("--video_length_tolerance_range",
                        type=float,
                        default=2.0)
    parser.add_argument("--group_frame", action="store_true")  # TODO
    parser.add_argument("--group_resolution", action="store_true")  # TODO
    parser.add_argument("--dataset", default="t2v")
    parser.add_argument("--train_fps", type=int, default=30)
    parser.add_argument("--use_image_num", type=int, default=0)
    parser.add_argument("--text_max_length", type=int, default=256)
    parser.add_argument("--speed_factor", type=float, default=1.0)
    parser.add_argument("--drop_short_ratio", type=float, default=1.0)

    # text encoder & vae & diffusion model
    parser.add_argument("--text_encoder_name",
                        type=str,
                        default="google/t5-v1_1-xxl")
    parser.add_argument("--cache_dir", type=str, default="./cache_dir")
    parser.add_argument("--cfg", type=float, default=0.0)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=("[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
              " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."),
    )
    return parser


if __name__ == "__main__":
    # Parse the command line and run preprocessing with the result.
    args = build_arg_parser().parse_args()
    main(args)
0 commit comments