@@ -971,6 +971,7 @@ def INPUT_TYPES(s):
971971 "model" : (
972972 [
973973 "kijai/CogVideoX-5b-Tora" ,
974+ "kijai/CogVideoX-5b-Tora-I2V" ,
974975 ],
975976 ),
976977 },
@@ -1000,14 +1001,17 @@ def loadmodel(self, model):
10001001 pass
10011002
10021003 download_path = os .path .join (folder_paths .models_dir , 'CogVideo' , "CogVideoX-5b-Tora" )
1003- fuser_path = os .path .join (download_path , "fuser" , "fuser.safetensors" )
1004+
1005+
1006+ fuser_model = "fuser.safetensors" if not "I2V" in model else "fuser_I2V.safetensors"
1007+ fuser_path = os .path .join (download_path , "fuser" , fuser_model )
10041008 if not os .path .exists (fuser_path ):
10051009 log .info (f"Downloading Fuser model to: { fuser_path } " )
10061010 from huggingface_hub import snapshot_download
10071011
10081012 snapshot_download (
10091013 repo_id = model ,
1010- allow_patterns = ["*fuser.safetensors*" ],
1014+ allow_patterns = [fuser_model ],
10111015 local_dir = download_path ,
10121016 local_dir_use_symlinks = False ,
10131017 )
@@ -1029,14 +1033,15 @@ def loadmodel(self, model):
10291033 param .data = param .data .to (torch .bfloat16 ).to (device )
10301034 del fuser_sd
10311035
1032- traj_extractor_path = os .path .join (download_path , "traj_extractor" , "traj_extractor.safetensors" )
1036+ traj_extractor_model = "traj_extractor.safetensors" if not "I2V" in model else "traj_extractor_I2V.safetensors"
1037+ traj_extractor_path = os .path .join (download_path , "traj_extractor" , traj_extractor_model )
10331038 if not os .path .exists (traj_extractor_path ):
10341039 log .info (f"Downloading trajectory extractor model to: { traj_extractor_path } " )
10351040 from huggingface_hub import snapshot_download
10361041
10371042 snapshot_download (
10381043 repo_id = "kijai/CogVideoX-5b-Tora" ,
1039- allow_patterns = ["*traj_extractor.safetensors*" ],
1044+ allow_patterns = [traj_extractor_model ],
10401045 local_dir = download_path ,
10411046 local_dir_use_symlinks = False ,
10421047 )
0 commit comments