Skip to content

Commit 8c5e4f8

Browse files
committed
support other tora model
1 parent eaaa0f6 commit 8c5e4f8

File tree

3 files changed

+11
-6
lines changed

3 files changed

+11
-6
lines changed

model_loading.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -971,6 +971,7 @@ def INPUT_TYPES(s):
971971
"model": (
972972
[
973973
"kijai/CogVideoX-5b-Tora",
974+
"kijai/CogVideoX-5b-Tora-I2V",
974975
],
975976
),
976977
},
@@ -1000,14 +1001,17 @@ def loadmodel(self, model):
10001001
pass
10011002

10021003
download_path = os.path.join(folder_paths.models_dir, 'CogVideo', "CogVideoX-5b-Tora")
1003-
fuser_path = os.path.join(download_path, "fuser", "fuser.safetensors")
1004+
1005+
1006+
fuser_model = "fuser.safetensors" if not "I2V" in model else "fuser_I2V.safetensors"
1007+
fuser_path = os.path.join(download_path, "fuser", fuser_model)
10041008
if not os.path.exists(fuser_path):
10051009
log.info(f"Downloading Fuser model to: {fuser_path}")
10061010
from huggingface_hub import snapshot_download
10071011

10081012
snapshot_download(
10091013
repo_id=model,
1010-
allow_patterns=["*fuser.safetensors*"],
1014+
allow_patterns=[fuser_model],
10111015
local_dir=download_path,
10121016
local_dir_use_symlinks=False,
10131017
)
@@ -1029,14 +1033,15 @@ def loadmodel(self, model):
10291033
param.data = param.data.to(torch.bfloat16).to(device)
10301034
del fuser_sd
10311035

1032-
traj_extractor_path = os.path.join(download_path, "traj_extractor", "traj_extractor.safetensors")
1036+
traj_extractor_model = "traj_extractor.safetensors" if not "I2V" in model else "traj_extractor_I2V.safetensors"
1037+
traj_extractor_path = os.path.join(download_path, "traj_extractor", traj_extractor_model)
10331038
if not os.path.exists(traj_extractor_path):
10341039
log.info(f"Downloading trajectory extractor model to: {traj_extractor_path}")
10351040
from huggingface_hub import snapshot_download
10361041

10371042
snapshot_download(
10381043
repo_id="kijai/CogVideoX-5b-Tora",
1039-
allow_patterns=["*traj_extractor.safetensors*"],
1044+
allow_patterns=[traj_extractor_model],
10401045
local_dir=download_path,
10411046
local_dir_use_symlinks=False,
10421047
)

pipeline_cogvideox.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ def __call__(
814814
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
815815
progress_bar.update()
816816
if callback is not None:
817-
callback(i, latents.detach()[-1], None, num_inference_steps)
817+
callback(i, (latents - noise_pred * (t / 1000)).detach()[0], None, num_inference_steps)
818818
else:
819819
comfy_pbar.update(1)
820820

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-cogvideoxwrapper"
33
description = "Diffusers wrapper for CogVideoX -models: [a/https://github.com/THUDM/CogVideo](https://github.com/THUDM/CogVideo)"
4-
version = "1.5.0"
4+
version = "1.5.1"
55
license = {file = "LICENSE"}
66
dependencies = ["huggingface_hub", "diffusers>=0.31.0", "accelerate>=0.33.0"]
77

0 commit comments

Comments
 (0)