1515
1616logger  =  init_logger (__name__ )
1717
18- BASE_MODEL_PATH  =  "/workspace/data/Wan-AI/Wan2.1-T2V-1.3B-Diffusers" 
19- MODEL_PATH  =  maybe_download_model (BASE_MODEL_PATH ,
20-                                   local_dir = os .path .join (
21-                                       'data' , BASE_MODEL_PATH ))
22- 
2318def  main (args ):
2419    # Assume using torchrun 
2520    local_rank  =  int (os .getenv ("RANK" , 0 ))
@@ -31,23 +26,23 @@ def main(args):
3126    if  not  dist .is_initialized ():
3227        dist .init_process_group (backend = "nccl" , init_method = "env://" , world_size = world_size , rank = local_rank )
3328
34-     pipeline_config  =  PipelineConfig .from_pretrained (MODEL_PATH )
29+     pipeline_config  =  PipelineConfig .from_pretrained (args . model_path )
3530    kwargs  =  {
3631        "use_cpu_offload" : False ,
3732        "vae_precision" : "fp32" ,
3833        "vae_config" : WanVAEConfig (load_encoder = True , load_decoder = False ),
3934    }
4035    pipeline_config_args  =  shallow_asdict (pipeline_config )
4136    pipeline_config_args .update (kwargs )
42-     fastvideo_args  =  FastVideoArgs (model_path = MODEL_PATH ,
37+     fastvideo_args  =  FastVideoArgs (model_path = args . model_path ,
4338                                   num_gpus = world_size ,
4439                                   device_str = "cuda" ,
4540                                   ** pipeline_config_args ,
4641                                   )
4742    fastvideo_args .check_fastvideo_args ()
4843    fastvideo_args .device  =  torch .device (f"cuda:{ local_rank }  )
4944
50-     pipeline  =  PreprocessPipeline (MODEL_PATH , fastvideo_args )
45+     pipeline  =  PreprocessPipeline (args . model_path , fastvideo_args )
5146    pipeline .forward (batch = None , fastvideo_args = fastvideo_args , args = args )
5247
5348
0 commit comments