Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
a3fd845
Initializing nodes_framepack.py for integrating framepack
Tekleab15 Aug 12, 2025
e0da993
Initializing the nodes_framepack.py for optimization
Tekleab15 Aug 12, 2025
ce3c122
Initial starting point framepack implementation
Tekleab15 Aug 12, 2025
4ad4f4f
Passing the framepack specific into the list of required parameters d…
Tekleab15 Aug 12, 2025
446cdb3
Refactoring predict_with_cfg function for better modularity and scala…
Tekleab15 Aug 14, 2025
85c2a0c
Created new nodes_framepack_new for better enhancement
Tekleab15 Aug 17, 2025
06187a9
Adopting the wan_vace specific files and implementations
Tekleab15 Aug 17, 2025
75ad835
Adopting vace_model specific classes
Tekleab15 Aug 17, 2025
22c0ac8
Initializing wan_vace specific model architecture
Tekleab15 Aug 17, 2025
717e257
Adapting preprocessor.py file for vace specific files
Tekleab15 Aug 17, 2025
900e6da
Creating new distributed with its vace specific migrations
Tekleab15 Aug 17, 2025
d6eb2ae
Importing Vace specific files to the new node implementation
Tekleab15 Aug 17, 2025
552e198
Identifying the main required inputs of the sampler
Tekleab15 Aug 17, 2025
85848d3
Extracting positive and negative prompt embeds from the embed_dictionary
Tekleab15 Aug 17, 2025
c988f61
Calling the generator(sampler) generate_with_framepack with return of…
Tekleab15 Aug 17, 2025
4aa4265
Adopting the wan_vace.py file from the vace_specific implementation
Tekleab15 Aug 17, 2025
c2971a4
Structure file naming for Custom Node installation
Tekleab15 Aug 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 102 additions & 1 deletion nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
import hashlib
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler

from .wanvideo.modules.model import rope_params
# Importing the Vace related classes
from .wanvideo.modules.model import rope_params,VaceWanModel, VaceWanAttentionBlock, BaseWanAttentionBlock

from .custom_linear import remove_lora_from_module, set_lora_params
from .wanvideo.schedulers import get_scheduler, get_sampling_sigmas, retrieve_timesteps, scheduler_list
from .gguf.gguf import set_lora_params_gguf
Expand All @@ -27,6 +29,10 @@
from comfy.cli_args import args, LatentPreviewMethod
import folder_paths

# Import the necessary FramePack classes
from .wanvideo.framepack_vace import FramepackVace
from .wanvideo.wan_video_vae import WanVideoVAE

script_directory = os.path.dirname(os.path.abspath(__file__))

device = mm.get_torch_device()
Expand Down Expand Up @@ -3191,6 +3197,99 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i
"samples": callback_latent.unsqueeze(0).cpu() if callback is not None else None,
})

# Framepack VACE specific Sampler
class WanVACEVideoFramepackSampler:
    """Sampler node implementing the FramePack algorithm for long video generation.

    Requires a loaded FramepackVace model (wrapped in a WANVIDEOMODEL); prepares the
    optional source video / mask / reference images into the layouts the model
    expects, then delegates to ``FramepackVace.generate_with_framepack``.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("WANVIDEOMODEL",), # Expects the loaded FramepackVace model
                "steps": ("INT", {"default": 30, "min": 1}),
                "cfg": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 30.0, "step": 0.01}),
                "shift": ("FLOAT", {"default": 5.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
                "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
                "scheduler": (scheduler_list, {"default": "uni_pc",}),
                "text_embeds": ("WANVIDEOTEXTEMBEDS", ),
                "frame_num": ("INT", {"default": 81, "min": 1}), # Total number of frames for the output video
                "context_scale": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 2.0, "step": 0.1}),
                "image_width": ("INT", {"default": 832, "min": 16}),
                "image_height": ("INT", {"default": 480, "min": 16}),
                "src_video": ("VIDEO", {"default": None}),
                "src_mask": ("MASK", {"default": None}),
                "src_ref_images": ("IMAGE", {"default": None}),
                "force_offload": ("BOOLEAN", {"default": True, "tooltip": "Moves the model to the offload device after sampling"}),
            },
            "optional": {
                "denoise_strength": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
                "freeinit_args": ("FREEINITARGS", ),
                "start_step": ("INT", {"default": 0, "min": 0, "max": 10000, "step": 1, "tooltip": "Start step for the sampling, 0 means full sampling, otherwise samples only from this step"}),
                "end_step": ("INT", {"default": -1, "min": -1, "max": 10000, "step": 1, "tooltip": "End step for the sampling, -1 means full sampling, otherwise samples only until this step"}),
            }
        }

    RETURN_TYPES = ("LATENT",)
    RETURN_NAMES = ("samples",)
    FUNCTION = "process"
    CATEGORY = "WanVideoWrapper"
    DESCRIPTION = "A sampler specifically for the FramePack algorithm for long video generation."

    def process(self, model, steps, cfg, shift, seed, scheduler, text_embeds, frame_num, context_scale,
                image_width, image_height, src_video=None, src_mask=None, src_ref_images=None, force_offload=True,
                denoise_strength=1.0, freeinit_args=None, start_step=0, end_step=-1):
        """Run FramePack sampling and return the generated latent.

        Returns a ComfyUI LATENT dict: {"samples": tensor of shape (1, C, T, H, W)} on CPU.
        Raises TypeError if the wrapped model is not a FramepackVace instance.
        """
        # Ensure the provided model is an instance of FramepackVace
        if not isinstance(model.model, FramepackVace):
            raise TypeError("This sampler requires a FramepackVace model. Please check your model loader node.")

        # Get the FramepackVace instance and the current device
        framepack_vace = model.model
        current_device = mm.get_torch_device()

        # Extract positive and negative prompts from the text_embeds dictionary.
        # NOTE(review): positive uses [0] but negative uses [0][0] — the embed dict
        # schema isn't visible here; confirm both entries have the same nesting.
        prompt = text_embeds["prompt_embeds"][0] if text_embeds["prompt_embeds"] else ""
        n_prompt = text_embeds["negative_prompt_embeds"][0][0] if text_embeds["negative_prompt_embeds"] else ""

        # ComfyUI's VIDEO format is (B, F, H, W, C). The model expects (B, C, F, H, W).
        # Handle src_video: (B, F, H, W, C) -> (B, C, F, H, W)
        input_frames_list = [src_video.permute(0, 4, 1, 2, 3)] if src_video is not None else [None]

        # Handle src_mask: ComfyUI MASK is (F, H, W); adding batch and channel dims
        # yields (1, 1, F, H, W) directly — no permute needed (the previous
        # permute(0, 1, 4, 2, 3) scrambled the axes to (1, 1, W, F, H)).
        input_masks_list = [src_mask.unsqueeze(0).unsqueeze(0)] if src_mask is not None else [None]

        # Handle src_ref_images: (B, H, W, C) -> (B, C, 1, H, W).
        # permute gives (B, C, H, W); the frame axis must be inserted at dim 2
        # (unsqueeze(1) would produce (B, 1, C, H, W) instead).
        input_ref_images_list = [src_ref_images.permute(0, 3, 1, 2).unsqueeze(2)] if src_ref_images is not None else [None]

        # Calling the FramepackVace instance.
        src_video_prepared, src_mask_prepared, src_ref_images_prepared = framepack_vace.prepare_source(
            input_frames_list,
            input_masks_list,
            input_ref_images_list,
            frame_num,
            (image_height, image_width),
            current_device
        )

        # Calling FramePack generation method.
        log.info(f"Starting FramePack generation for {frame_num} frames.")
        final_video_latent = framepack_vace.generate_with_framepack(
            input_prompt=prompt,
            input_frames=src_video_prepared,
            input_masks=src_mask_prepared,
            input_ref_images=src_ref_images_prepared,
            size=(image_width, image_height),
            frame_num=frame_num,
            sample_solver=scheduler,
            sampling_steps=steps,
            guide_scale=cfg,
            n_prompt=n_prompt,
            seed=seed,
            offload_model=force_offload
        )

        log.info(f"FramePack generation complete. Output latent tensor shape: {final_video_latent.shape}")

        # The output of generate_with_framepack is a single tensor (C, T, H, W);
        # add the batch dim and move to CPU for the LATENT output.
        return ({"samples": final_video_latent.unsqueeze(0).cpu(),},)

#region VideoDecode
class WanVideoDecode:
@classmethod
Expand Down Expand Up @@ -3432,6 +3531,7 @@ def encode(self, samples, direction):
"WanVideoTextEncodeCached": WanVideoTextEncodeCached,
"WanVideoAddExtraLatent": WanVideoAddExtraLatent,
"WanVideoLatentReScale": WanVideoLatentReScale,
"WanVACEVideoFramepackSampler": WanVACEVideoFramepackSampler,
}
NODE_DISPLAY_NAME_MAPPINGS = {
"WanVideoSampler": "WanVideo Sampler",
Expand Down Expand Up @@ -3463,4 +3563,5 @@ def encode(self, samples, direction):
"WanVideoTextEncodeCached": "WanVideo TextEncode Cached",
"WanVideoAddExtraLatent": "WanVideo Add Extra Latent",
"WanVideoLatentReScale": "WanVideo Latent ReScale",
"WanVACEVideoFramepackSampler": "WanVideo Framepack Sampler",
}
Loading