1515
1616logger = init_logger (__name__ )
1717
18- BASE_MODEL_PATH = "/workspace/data/Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
19- MODEL_PATH = maybe_download_model (BASE_MODEL_PATH ,
20- local_dir = os .path .join (
21- 'data' , BASE_MODEL_PATH ))
22-
2318def main (args ):
2419 # Assume using torchrun
2520 local_rank = int (os .getenv ("RANK" , 0 ))
@@ -31,23 +26,23 @@ def main(args):
3126 if not dist .is_initialized ():
3227 dist .init_process_group (backend = "nccl" , init_method = "env://" , world_size = world_size , rank = local_rank )
3328
34- pipeline_config = PipelineConfig .from_pretrained (MODEL_PATH )
29+ pipeline_config = PipelineConfig .from_pretrained (args . model_path )
3530 kwargs = {
3631 "use_cpu_offload" : False ,
3732 "vae_precision" : "fp32" ,
3833 "vae_config" : WanVAEConfig (load_encoder = True , load_decoder = False ),
3934 }
4035 pipeline_config_args = shallow_asdict (pipeline_config )
4136 pipeline_config_args .update (kwargs )
42- fastvideo_args = FastVideoArgs (model_path = MODEL_PATH ,
37+ fastvideo_args = FastVideoArgs (model_path = args . model_path ,
4338 num_gpus = world_size ,
4439 device_str = "cuda" ,
4540 ** pipeline_config_args ,
4641 )
4742 fastvideo_args .check_fastvideo_args ()
4843 fastvideo_args .device = torch .device (f"cuda:{ local_rank } " )
4944
50- pipeline = PreprocessPipeline (MODEL_PATH , fastvideo_args )
45+ pipeline = PreprocessPipeline (args . model_path , fastvideo_args )
5146 pipeline .forward (batch = None , fastvideo_args = fastvideo_args , args = args )
5247
5348
0 commit comments