HunyuanImage2.1: Implement Hunyuan Mixed APG #9882
The first hunk (@@ -1,12 +1,14 @@) updates the imports and changes the `INPUT_TYPES` parameter of `CLIPTextEncodeHunyuanDiT` from `s` to `cls`:

```python
from numpy import arccos
import nodes
import node_helpers
import torch
import re
import comfy.model_management


class CLIPTextEncodeHunyuanDiT:
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "clip": ("CLIP", ),
            "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
```
The second hunk (@@ -23,6 +25,220 @@, anchored in `def encode(self, clip, bert, mt5xl)`) inserts the APG implementation after the existing encode method. The additions are shown in sections below, starting with the momentum accumulator:

```python
        return (clip.encode_from_tokens_scheduled(tokens), )


class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average
```
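Each `update` adds the incoming value to `momentum` times the previous running average, so with the negative momentum the node defaults to, successive updates are damped rather than smoothed. A quick trace (illustrative only, not part of the PR):

```python
# Illustrative only: trace MomentumBuffer with the node's default momentum.
buf = MomentumBuffer(momentum=-0.5)
buf.update(torch.tensor(1.0))  # running_average = 1.0 + (-0.5 * 0)   = 1.0
buf.update(torch.tensor(1.0))  # running_average = 1.0 + (-0.5 * 1.0) = 0.5
buf.update(torch.tensor(1.0))  # running_average = 1.0 + (-0.5 * 0.5) = 0.75
```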
The guidance function then combines momentum smoothing, norm clipping, and the projection:

```python
def normalized_guidance_apg(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
    use_original_formulation: bool = False,
):
    diff = pred_cond - pred_uncond
    dim = [-i for i in range(1, len(diff.shape))]

    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor

    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=dim)
    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)

    normalized_update = diff_orthogonal + eta * diff_parallel
    pred = pred_cond if use_original_formulation else pred_uncond
    pred = pred + guidance_scale * normalized_update

    return pred
```
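The core of APG is that projection step: the (momentum-smoothed, norm-clipped) guidance difference is split into a component parallel to the conditional prediction and an orthogonal remainder, and `eta` rescales only the parallel part (with `eta=0`, the default these nodes use, only the orthogonal component survives). A self-contained sanity check of the decomposition (illustrative only, not part of the PR):

```python
import torch

# Illustrative only: verify the parallel/orthogonal split used above.
torch.manual_seed(0)
pred_cond = torch.randn(2, 4, 8, 8)
pred_uncond = torch.randn(2, 4, 8, 8)

diff = (pred_cond - pred_uncond).double()
dim = [-i for i in range(1, len(diff.shape))]  # all dims except batch

v1 = torch.nn.functional.normalize(pred_cond.double(), dim=dim)
parallel = (diff * v1).sum(dim=dim, keepdim=True) * v1
orthogonal = diff - parallel

# The two components reassemble the difference and are mutually orthogonal.
assert torch.allclose(parallel + orthogonal, diff)
assert (parallel * orthogonal).sum(dim=dim).abs().max() < 1e-9
```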
`AdaptiveProjectedGuidance` wraps the function with per-run state:

```python
class AdaptiveProjectedGuidance:
    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum=None,
        adaptive_projected_guidance_rescale: float = 15.0,
        # eta: float = 1.0,
        eta: float = 0.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
        start: float = 0.0,
        stop: float = 1.0,
    ):
        super().__init__()

        self.guidance_scale = guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        self.momentum_buffer = None

    def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, step=None) -> torch.Tensor:
        if step == 0 and self.adaptive_projected_guidance_momentum is not None:
            self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)

        pred = normalized_guidance_apg(
            pred_cond,
            pred_uncond,
            self.guidance_scale,
            self.momentum_buffer,
            self.eta,
            self.adaptive_projected_guidance_rescale,
            self.use_original_formulation,
        )

        return pred
```
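Note that the momentum buffer is only re-created when `__call__` receives `step == 0`; the caller is responsible for passing a correct step index. A minimal driving loop (illustrative only, not part of the PR):

```python
# Illustrative only: drive the guidance object across a few fake steps.
apg = AdaptiveProjectedGuidance(
    guidance_scale=9.0,
    adaptive_projected_guidance_momentum=-0.5,
    adaptive_projected_guidance_rescale=10.0,
)

for step in range(4):
    pred_cond = torch.randn(1, 4, 8, 8)
    pred_uncond = torch.randn(1, 4, 8, 8)
    # step == 0 resets the momentum buffer; later steps accumulate into it.
    guided = apg(pred_cond, pred_uncond, step=step)
    assert guided.shape == pred_cond.shape
```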
`HunyuanMixModeAPG` is the new node. It builds two APG configurations, one for general prompts and one for OCR-style prompts containing quoted text, and registers a CFG function that switches between them:

```python
class HunyuanMixModeAPG:

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "has_quoted_text": ("HAS_QUOTED_TEXT", ),

                "guidance_scale": ("FLOAT", {"default": 9.0, "min": 1.0, "max": 30.0, "step": 0.1}),

                "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
                "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
                "general_start_step": ("INT", {"default": 10, "min": -1, "max": 1000}),

                "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
                "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
                "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
                "ocr_start_step": ("INT", {"default": 75, "min": -1, "max": 1000}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_mix_mode_apg"
    CATEGORY = "sampling/custom_sampling/hunyuan"

    @classmethod
    def IS_CHANGED(cls, model):
        return True

    def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale,
                           general_eta, general_norm_threshold, general_momentum, general_start_step,
                           ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step):

        general_apg = AdaptiveProjectedGuidance(
            guidance_scale=guidance_scale,
            eta=general_eta,
            adaptive_projected_guidance_rescale=general_norm_threshold,
            adaptive_projected_guidance_momentum=general_momentum
        )

        ocr_apg = AdaptiveProjectedGuidance(
            eta=ocr_eta,
            adaptive_projected_guidance_rescale=ocr_norm_threshold,
            adaptive_projected_guidance_momentum=ocr_momentum
        )

        current_step = {"step": 0}

        def cfg_function(args):
            cond = args["cond"]
            uncond = args["uncond"]
            cond_scale = args["cond_scale"]

            step = current_step["step"]
            current_step["step"] += 1

            if not has_quoted_text:
                if step > general_start_step:
                    modified_cond = general_apg(cond, uncond, step).to(torch.bfloat16)
                    return modified_cond
                else:
                    if cond_scale > 1:
                        _ = general_apg(cond, uncond, step)  # track momentum
                    return uncond + (cond - uncond) * cond_scale
            else:
                if step > ocr_start_step:
                    modified_cond = ocr_apg(cond, uncond, step)
                    return modified_cond
                else:
                    if cond_scale > 1:
                        _ = ocr_apg(cond, uncond, step)  # track momentum
                    return uncond + (cond - uncond) * cond_scale

            return cond

        m = model.clone()
        m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
        return (m,)
```
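Per call, `cfg_function` therefore makes a three-way choice: once past the relevant start step it returns the APG-guided prediction, and before that it falls back to vanilla CFG, `uncond + (cond - uncond) * cond_scale`, while still invoking the APG object (when `cond_scale > 1`) so its momentum buffer stays warm. The decision, extracted as a pure helper (illustrative only, not part of the PR):

```python
# Illustrative only: the branch taken by cfg_function for a given call.
def apg_branch(has_quoted_text: bool, step: int,
               general_start_step: int = 10, ocr_start_step: int = 75) -> str:
    if not has_quoted_text:
        return "general_apg" if step > general_start_step else "cfg_fallback"
    return "ocr_apg" if step > ocr_start_step else "cfg_fallback"

assert apg_branch(False, 11) == "general_apg"
assert apg_branch(True, 11) == "cfg_fallback"   # OCR APG kicks in later
assert apg_branch(True, 76) == "ocr_apg"
```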
`CLIPTextEncodeHunyuanDiTWithTextDetection` encodes the prompt and reports whether it contains quoted text:

```python
class CLIPTextEncodeHunyuanDiTWithTextDetection:

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {
            "clip": ("CLIP", ),
            "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
        }}

    RETURN_TYPES = ("CONDITIONING", "HAS_QUOTED_TEXT")
    RETURN_NAMES = ("conditioning", "has_quoted_text")
    FUNCTION = "encode"

    CATEGORY = "advanced/conditioning/hunyuan"

    def detect_quoted_text(self, text):
        """Detect quoted text in the prompt"""
        text_prompt_texts = []

        # Patterns to match different quote styles
        pattern_quote_double = r'\"(.*?)\"'
        pattern_quote_chinese_single = r'‘(.*?)’'
        pattern_quote_chinese_double = r'“(.*?)”'

        matches_quote_double = re.findall(pattern_quote_double, text)
        matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
        matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)

        text_prompt_texts.extend(matches_quote_double)
        text_prompt_texts.extend(matches_quote_chinese_single)
        text_prompt_texts.extend(matches_quote_chinese_double)

        return len(text_prompt_texts) > 0

    def encode(self, clip, text):
        tokens = clip.tokenize(text)
        has_quoted_text = self.detect_quoted_text(text)

        conditioning = clip.encode_from_tokens_scheduled(tokens)

        c = []
        for t in conditioning:
            n = [t[0], t[1].copy()]
            n[1]['has_quoted_text'] = has_quoted_text
            c.append(n)

        return (c, has_quoted_text)
```
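The detection flags a prompt as OCR-style as soon as any straight double quotes or CJK-style quotation marks enclose text. For example (illustrative only, not part of the PR):

```python
import re

# Illustrative only: the three patterns used by detect_quoted_text.
patterns = [r'\"(.*?)\"', r'‘(.*?)’', r'“(.*?)”']

def has_quoted(text: str) -> bool:
    return any(re.findall(p, text) for p in patterns)

assert has_quoted('a neon sign that says "OPEN"')  # straight double quotes
assert has_quoted('招牌上写着“欢迎光临”')            # fullwidth quotes ("Welcome")
assert not has_quoted('a quiet street at night')    # no quoted text
```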
The big hunk's trailing context is the start of the existing `EmptyHunyuanLatentVideo` class:

```python
class EmptyHunyuanLatentVideo:
    @classmethod
    def INPUT_TYPES(s):
```

The final hunk (@@ -151,8 +367,16 @@, anchored in `def execute(self, positive, negative, latent, noise_augmentation)`) adds the display-name mapping and registers the new nodes:

```python
        return (positive, negative, out_latent)


NODE_DISPLAY_NAME_MAPPINGS = {
    "HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
    "HunyuanStepBasedAPG": "Hunyuan Step Based APG",
}

NODE_CLASS_MAPPINGS = {
    "HunyuanMixModeAPG": HunyuanMixModeAPG,
    "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
    "CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
    "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
    "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
    "HunyuanImageToVideo": HunyuanImageToVideo,
```
Review comment:

IS_CHANGED still needs to be removed, as per my last comment on the first PR. Node application is not on a per-step basis; always returning True from IS_CHANGED will break proper node execution behavior. The only thing that runs on a per-step basis is the code registered on the model patcher.

I think there may be an issue you are experiencing that IS_CHANGED is incorrectly papering over. Can you describe what was happening when you didn't have IS_CHANGED? You may need to instead add a wrapper function to initialize/reset the dictionary you're using to count steps at sample time.
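The reviewer's suggestion is a sampler-time wrapper; an alternative lightweight heuristic is to reset the counter inside the CFG function itself. The sketch below is hypothetical and untested, and assumes the args dict passed to a sampler CFG function exposes the current sigma under args["sigma"] (as current ComfyUI samplers do); since sigmas decrease monotonically within a run, a non-decreasing sigma marks the start of a new run:

```python
# Hypothetical sketch, not the PR's code: reset per-run state without IS_CHANGED.
state = {"step": 0, "last_sigma": None}

def cfg_function(args):
    sigma = args["sigma"].max().item()
    # Sigmas only decrease within one sampling run, so a sigma that did not
    # decrease means a new run has started: reset the step counter here
    # (and, in the real node, recreate the APG objects as well).
    if state["last_sigma"] is None or sigma >= state["last_sigma"]:
        state["step"] = 0
    state["last_sigma"] = sigma

    step = state["step"]
    state["step"] += 1
    # ... branch on `step` exactly as apply_mix_mode_apg does above ...
    return args["uncond"] + (args["cond"] - args["uncond"]) * args["cond_scale"]
```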