226 changes: 225 additions & 1 deletion comfy_extras/nodes_hunyuan.py
@@ -1,12 +1,14 @@
import nodes
import node_helpers
import torch
import re
import comfy.model_management


class CLIPTextEncodeHunyuanDiT:
@classmethod
    def INPUT_TYPES(cls):
return {"required": {
"clip": ("CLIP", ),
"bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
@@ -23,6 +25,220 @@ def encode(self, clip, bert, mt5xl):

return (clip.encode_from_tokens_scheduled(tokens), )

class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0

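    # Note: update() accumulates running_average = update_value + momentum * running_average;
    # with the negative momentum defaults used by the node below, this damps oscillation
    # in the guidance direction between steps rather than computing a standard EMA.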
def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average

def normalized_guidance_apg(
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer,
eta: float = 1.0,
norm_threshold: float = 0.0,
use_original_formulation: bool = False,
):
diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]

if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average

if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor

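    # Decompose the (momentum-adjusted) guidance direction against the normalized
    # conditional prediction: the parallel component is scaled by eta, the orthogonal
    # component is kept, so eta = 0.0 (the default used below) applies purely
    # orthogonal (projected) guidance.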
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=dim)
v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)

normalized_update = diff_orthogonal + eta * diff_parallel
pred = pred_cond if use_original_formulation else pred_uncond
pred = pred + guidance_scale * normalized_update

return pred

class AdaptiveProjectedGuidance:
def __init__(
self,
guidance_scale: float = 7.5,
adaptive_projected_guidance_momentum=None,
adaptive_projected_guidance_rescale: float = 15.0,
# eta: float = 1.0,
eta: float = 0.0,
guidance_rescale: float = 0.0,
use_original_formulation: bool = False,
start: float = 0.0,
stop: float = 1.0,
):
super().__init__()

self.guidance_scale = guidance_scale
self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
self.eta = eta
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None

def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, step=None) -> torch.Tensor:

if step == 0 and self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)

pred = normalized_guidance_apg(
pred_cond,
pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.eta,
self.adaptive_projected_guidance_rescale,
self.use_original_formulation,
)

return pred

class HunyuanMixModeAPG:

@classmethod
    def INPUT_TYPES(cls):
return {
"required": {
"model": ("MODEL", ),
"has_quoted_text": ("HAS_QUOTED_TEXT", ),

"guidance_scale": ("FLOAT", {"default": 9.0, "min": 1.0, "max": 30.0, "step": 0.1}),

"general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
"general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
"general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
"general_start_step": ("INT", {"default": 10, "min": -1, "max": 1000}),

"ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
"ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
"ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
"ocr_start_step": ("INT", {"default": 75, "min": -1, "max": 1000}),

}
}

RETURN_TYPES = ("MODEL",)
FUNCTION = "apply_mix_mode_apg"
CATEGORY = "sampling/custom_sampling/hunyuan"


@classmethod
def IS_CHANGED(cls, model):
Kosinkadink (Member) commented on Sep 15, 2025:

IS_CHANGED still needs to be removed, as per my last comment in the first PR. Node application is not on a per-step basis; always returning True from IS_CHANGED will break proper node execution behavior. The only thing that operates on a per-step basis is the code registered on the model patcher.

I think there may be an issue you are experiencing that is being incorrectly worked around with IS_CHANGED. Can you describe what was happening when you didn't have IS_CHANGED? You may need to instead add a wrapper function to initialize/reset the dictionary you're using to count steps at sample time.

return True

def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_step,
ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step):

general_apg = AdaptiveProjectedGuidance(
guidance_scale=guidance_scale,
eta=general_eta,
adaptive_projected_guidance_rescale=general_norm_threshold,
adaptive_projected_guidance_momentum=general_momentum
)

ocr_apg = AdaptiveProjectedGuidance(
eta=ocr_eta,
adaptive_projected_guidance_rescale=ocr_norm_threshold,
adaptive_projected_guidance_momentum=ocr_momentum
)

current_step = {"step": 0}
The PR author (Contributor) replied:

@Kosinkadink The reason we always return True in IS_CHANGED is that we need to track the dict current_step, which should be re-initialized before every run to reset the current sampling step to 0. I found it hard to retrieve the current step index from the KSampler node, so I track it by implementing this.

Kosinkadink (Member) replied on Sep 15, 2025:

Gotcha. In that case, what you are looking for is an OUTER_SAMPLE wrapper; an example is the easycache_sample_wrapper in the core EasyCache nodes.

IS_CHANGED does not work how you would expect: returning True does NOT make it run every time, and it is also not the way to make sure the code runs every time the sampler is run; it is separate from the sampling process entirely. The return value of IS_CHANGED is like a 'fingerprint': if this fingerprint is different from the last time the node instance was run, then the execution function of the node is permitted to run. If the fingerprint is the same, the node will NOT run and the cached value will be used instead. Always returning True from IS_CHANGED means the node will run ONLY ONCE the whole time it is on the graph (the fingerprint will always be 'True', and will match the 'True' of the previous run). In the V3 schema, this function was renamed to fingerprint_inputs to make its effect clearer. The node running is completely separate from the sampling process.

What you'll want to do is make sure the dictionary you want to access gets reinitialized in the OUTER_SAMPLE wrapper. The function that OUTER_SAMPLE wraps runs every time sampling is initiated in ComfyUI, so it's the perfect place to reset the step value back to zero. This function also encapsulates sampling before CFG, so it won't get screwed up by the amount of conditioning being passed in.

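A rough sketch of that approach (an editor's illustration, assuming the wrapper API matches what the EasyCache nodes use, i.e. ModelPatcher.add_wrapper_with_key with comfy.patcher_extension.WrappersMP.OUTER_SAMPLE; see comfy_extras/nodes_easycache.py for the exact pattern):

import comfy.patcher_extension

def apg_outer_sample_wrapper(executor, *args, **kwargs):
    # Runs once per sampling run: reset the step counter defined in
    # apply_mix_mode_apg, then defer to the wrapped sampling function.
    current_step["step"] = 0
    return executor(*args, **kwargs)

# Registered on the cloned model alongside the cfg function, e.g. after m = model.clone():
# m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE,
#                        "hunyuan_mix_mode_apg", apg_outer_sample_wrapper)

With a wrapper like this in place, IS_CHANGED can be dropped entirely, since the step reset no longer depends on the node re-executing.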

def cfg_function(args):
cond = args["cond"]
uncond = args["uncond"]
cond_scale = args["cond_scale"]

step = current_step["step"]
current_step["step"] += 1

if not has_quoted_text:
if step > general_start_step:
                    modified_cond = general_apg(cond, uncond, step).to(cond.dtype)  # match the incoming dtype instead of forcing bfloat16
return modified_cond
else:
if cond_scale > 1:
_ = general_apg(cond, uncond, step) # track momentum
return uncond + (cond - uncond) * cond_scale
else:
if step > ocr_start_step:
modified_cond = ocr_apg(cond, uncond, step)
return modified_cond
else:
if cond_scale > 1:
_ = ocr_apg(cond, uncond, step)
return uncond + (cond - uncond) * cond_scale

return cond


m = model.clone()
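        # disable_cfg1_optimization=True keeps the uncond prediction available even at
        # cond_scale == 1, which the APG paths above need for their cond - uncond diff.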
m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
return (m,)

class CLIPTextEncodeHunyuanDiTWithTextDetection:

@classmethod
def INPUT_TYPES(cls):
return {"required": {
"clip": ("CLIP", ),
"text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
}}

RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT")
RETURN_NAMES = ("conditioning", "has_quoted_text")
FUNCTION = "encode"

CATEGORY = "advanced/conditioning/hunyuan"

def detect_quoted_text(self, text):
"""Detect quoted text in the prompt"""
text_prompt_texts = []

# Patterns to match different quote styles
pattern_quote_double = r'\"(.*?)\"'
pattern_quote_chinese_single = r'‘(.*?)’'
pattern_quote_chinese_double = r'“(.*?)”'

matches_quote_double = re.findall(pattern_quote_double, text)
matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)

text_prompt_texts.extend(matches_quote_double)
text_prompt_texts.extend(matches_quote_chinese_single)
text_prompt_texts.extend(matches_quote_chinese_double)

return len(text_prompt_texts) > 0
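    # Hypothetical quick check of the detector (illustration only, not part of the node API):
    #   detect_quoted_text('a street sign that says "STOP"')  -> True
    #   detect_quoted_text('a quiet mountain lake')           -> False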

def encode(self, clip, text):
tokens = clip.tokenize(text)
has_quoted_text = self.detect_quoted_text(text)

conditioning = clip.encode_from_tokens_scheduled(tokens)

c = []
for t in conditioning:
n = [t[0], t[1].copy()]
n[1]['has_quoted_text'] = has_quoted_text
c.append(n)

return (c, has_quoted_text)

class EmptyHunyuanLatentVideo:
@classmethod
def INPUT_TYPES(s):
@@ -151,8 +367,16 @@ def execute(self, positive, negative, latent, noise_augmentation):
return (positive, negative, out_latent)



NODE_DISPLAY_NAME_MAPPINGS = {
"HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
"HunyuanStepBasedAPG": "Hunyuan Step Based APG",
}

NODE_CLASS_MAPPINGS = {
"HunyuanMixModeAPG": HunyuanMixModeAPG,
"CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
"CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
"TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
"EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
"HunyuanImageToVideo": HunyuanImageToVideo,