diff --git a/comfy/sd.py b/comfy/sd.py
index 2df340739f4e..3fd128e528df 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -836,7 +836,7 @@ class CLIPType(Enum):
     OMNIGEN2 = 17
     QWEN_IMAGE = 18
     HUNYUAN_IMAGE = 19
-
+    HUNYUAN_IMAGE_REFINER = 20
 
 def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}):
     clip_data = []
@@ -995,6 +995,9 @@ class EmptyClass:
             if clip_type == CLIPType.HUNYUAN_IMAGE:
                 clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, **llama_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageTokenizer
+            elif clip_type == CLIPType.HUNYUAN_IMAGE_REFINER:
+                clip_target.clip = comfy.text_encoders.hunyuan_image.te(byt5=False, refiner=True, **llama_detect(clip_data))
+                clip_target.tokenizer = comfy.text_encoders.hunyuan_image.HunyuanImageRefinerTokenizer
             else:
                 clip_target.clip = comfy.text_encoders.qwen_image.te(**llama_detect(clip_data))
                 clip_target.tokenizer = comfy.text_encoders.qwen_image.QwenImageTokenizer
diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py
index 699eddc3336f..fcaa06e08a6f 100644
--- a/comfy/text_encoders/hunyuan_image.py
+++ b/comfy/text_encoders/hunyuan_image.py
@@ -4,6 +4,8 @@ from transformers import ByT5Tokenizer
 import os
 import re
+import torch
+import numbers
 
 
 class ByT5SmallTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, embedding_directory=None, tokenizer_data={}):
@@ -38,6 +40,13 @@ def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
         out['byt5'] = self.byt5.tokenize_with_weights(''.join(map(lambda a: 'Text "{}". '.format(a), text_prompt_texts)), return_word_ids, **kwargs)
         return out
 
+class HunyuanImageRefinerTokenizer(HunyuanImageTokenizer):
+    def __init__(self, embedding_directory=None, tokenizer_data={}):
+        super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
+        self.llama_template = "<|start_header_id|>system<|end_header_id|>\n\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|eot_id|>\n<|start_header_id|>user<|end_header_id|>\n\n{}<|eot_id|>"
+
+
+
 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}):
         llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
@@ -53,9 +62,9 @@ def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
 
 
-class HunyuanImageTEModel(QwenImageTEModel):
+class HunyuanImageTEModel(sd1_clip.SD1ClipModel):
     def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
-        super(QwenImageTEModel, self).__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+        super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
 
         if byt5:
             self.byt5_small = ByT5SmallModel(device=device, dtype=dtype, model_options=model_options)
@@ -63,11 +72,35 @@ def __init__(self, byt5=True, device="cpu", dtype=None, model_options={}):
             self.byt5_small = None
 
     def encode_token_weights(self, token_weight_pairs):
-        cond, p, extra = super().encode_token_weights(token_weight_pairs)
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
+        count_im_start = 0
+        for i, v in enumerate(tok_pairs):
+            elem = v[0]
+            if not torch.is_tensor(elem):
+                if isinstance(elem, numbers.Integral):
+                    if elem == 151644 and count_im_start < 2:
+                        template_end = i
+                        count_im_start += 1
+
+        if out.shape[1] > (template_end + 3):
+            if tok_pairs[template_end + 1][0] == 872:
+                if tok_pairs[template_end + 2][0] == 198:
+                    template_end += 3
+
+        out = out[:, template_end:]
+
+        extra["attention_mask"] = extra["attention_mask"][:, template_end:]
+        if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
+            extra.pop("attention_mask")  # attention mask is useless if no masked elements
+
         if self.byt5_small is not None and "byt5" in token_weight_pairs:
-            out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
-            extra["conditioning_byt5small"] = out[0]
-        return cond, p, extra
+            byt5_out = self.byt5_small.encode_token_weights(token_weight_pairs["byt5"])
+            extra["conditioning_byt5small"] = byt5_out[0]
+        return out, pooled, extra
+
+
 
     def set_clip_options(self, options):
         super().set_clip_options(options)
@@ -84,9 +117,33 @@ def load_sd(self, sd):
             return self.byt5_small.load_sd(sd)
         else:
             return super().load_sd(sd)
 
+class HunyuanImageRefinerTEModel(sd1_clip.SD1ClipModel):
+    def __init__(self, device="cpu", dtype=None, model_options={}):
+        super().__init__(device=device, dtype=dtype, name="qwen25_7b", clip_model=Qwen25_7BVLIModel, model_options=model_options)
+
+    def encode_token_weights(self, token_weight_pairs):
+        out, pooled, extra = super().encode_token_weights(token_weight_pairs)
+        tok_pairs = token_weight_pairs["qwen25_7b"][0]
+        for i, v in enumerate(tok_pairs):
+            elem = v[0]
+            if not torch.is_tensor(elem):
+                if isinstance(elem, numbers.Integral):
+                    if elem == 6171:
+                        template_end = i
+                        break
+
+        out = out[:, template_end-1:]
+
+        extra["attention_mask"] = extra["attention_mask"][:, template_end-1:]
+        if extra["attention_mask"].sum() == torch.numel(extra["attention_mask"]):
+            extra.pop("attention_mask")  # attention mask is useless if no masked elements
+
+        return out, pooled, extra
+
+
+def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None, refiner=False):
+    class HunyuanImageTEModel_(HunyuanImageTEModel):
-def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None):
-    class QwenImageTEModel_(HunyuanImageTEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
             if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
                 model_options = model_options.copy()
@@ -94,4 +151,14 @@ def __init__(self, device="cpu", dtype=None, model_options={}):
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options)
-    return QwenImageTEModel_
+    class HunyuanImageTEModel_refiner(HunyuanImageRefinerTEModel):
+        def __init__(self, device="cpu", dtype=None, model_options={}):
+            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+                model_options = model_options.copy()
+                model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+            if dtype_llama is not None:
+                dtype = dtype_llama
+            assert refiner, "refiner must be True"
+            assert not byt5, "byt5 must be False"
+            super().__init__(device=device, dtype=dtype, model_options=model_options)
+    return HunyuanImageTEModel_refiner if refiner else HunyuanImageTEModel_
diff --git a/comfy_extras/nodes_hunyuan.py b/comfy_extras/nodes_hunyuan.py
index db398cdf14a6..1a6f4569d67f 100644
--- a/comfy_extras/nodes_hunyuan.py
+++ b/comfy_extras/nodes_hunyuan.py
@@ -1,12 +1,15 @@
+import math
 import nodes
 import node_helpers
 import torch
+import re
 import comfy.model_management
+import comfy.patcher_extension
 
 
 class CLIPTextEncodeHunyuanDiT:
     @classmethod
-    def INPUT_TYPES(s):
+    def INPUT_TYPES(cls):
         return {"required": {
             "clip": ("CLIP", ),
             "bert": ("STRING", {"multiline": True, "dynamicPrompts": True}),
@@ -23,6 +26,249 @@ def encode(self, clip, bert, mt5xl):
         return (clip.encode_from_tokens_scheduled(tokens), )
 
 
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+def normalized_guidance_apg(
+    pred_cond: torch.Tensor,
+    pred_uncond: torch.Tensor,
+    guidance_scale: float,
+    momentum_buffer,
+    eta: float = 1.0,
+    norm_threshold: float = 0.0,
+    use_original_formulation: bool = False,
+):
+    diff = pred_cond - pred_uncond
+    dim = [-i for i in range(1, len(diff.shape))]
+
+    if momentum_buffer is not None:
+        momentum_buffer.update(diff)
+        diff = momentum_buffer.running_average
+
+    if norm_threshold > 0:
+        ones = torch.ones_like(diff)
+        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
+        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+        diff = diff * scale_factor
+
+    v0, v1 = diff.double(), pred_cond.double()
+    v1 = torch.nn.functional.normalize(v1, dim=dim)
+    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
+    v0_orthogonal = v0 - v0_parallel
+    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
+
+    normalized_update = diff_orthogonal + eta * diff_parallel
+    pred = pred_cond if use_original_formulation else pred_uncond
+    pred = pred + guidance_scale * normalized_update
+
+    return pred
+
+class AdaptiveProjectedGuidance:
+    def __init__(
+        self,
+        guidance_scale: float = 7.5,
+        adaptive_projected_guidance_momentum=None,
+        adaptive_projected_guidance_rescale: float = 15.0,
+        # eta: float = 1.0,
+        eta: float = 0.0,
+        guidance_rescale: float = 0.0,
+        use_original_formulation: bool = False,
+        start: float = 0.0,
+        stop: float = 1.0,
+    ):
+        super().__init__()
+
+        self.guidance_scale = guidance_scale
+        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
+        self.eta = eta
+        self.guidance_rescale = guidance_rescale
+        self.use_original_formulation = use_original_formulation
+        self.momentum_buffer = None
+
+    def __call__(self, pred_cond: torch.Tensor, pred_uncond=None, is_first_step=False) -> torch.Tensor:
+
+        if is_first_step and self.adaptive_projected_guidance_momentum is not None:
+            self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
+
+        pred = normalized_guidance_apg(
+            pred_cond,
+            pred_uncond,
+            self.guidance_scale,
+            self.momentum_buffer,
+            self.eta,
+            self.adaptive_projected_guidance_rescale,
+            self.use_original_formulation,
+        )
+
+        return pred
+
+class HunyuanMixModeAPG:
+
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "model": ("MODEL", ),
+                "has_quoted_text": ("BOOLEAN", ),
+
+                "guidance_scale": ("FLOAT", {"default": 10.0, "min": 1.0, "max": 30.0, "step": 0.1}),
+
+                "general_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "general_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "general_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "general_start_percent": ("FLOAT", {"default": 0.10, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The relative sampling step to begin use of general APG."}),
+
+                "ocr_eta": ("FLOAT", {"default": 0.0, "min": -10.0, "max": 10.0, "step": 0.01}),
+                "ocr_norm_threshold": ("FLOAT", {"default": 10.0, "min": 0.0, "max": 50.0, "step": 0.1}),
+                "ocr_momentum": ("FLOAT", {"default": -0.5, "min": -5.0, "max": 1.0, "step": 0.01}),
+                "ocr_start_percent": ("FLOAT", {"default": 0.75, "min": 0.0, "max": 1.0, "step": 0.01, "tooltip": "The relative sampling step to begin use of OCR APG."}),
+
+            }
+        }
+
+    RETURN_TYPES = ("MODEL",)
+    FUNCTION = "apply_mix_mode_apg"
+    CATEGORY = "sampling/custom_sampling/hunyuan"
+
+
+    def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_percent,
+                           ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_percent):
+
+        general_apg = AdaptiveProjectedGuidance(
+            guidance_scale=guidance_scale,
+            eta=general_eta,
+            adaptive_projected_guidance_rescale=general_norm_threshold,
+            adaptive_projected_guidance_momentum=general_momentum
+        )
+
+        ocr_apg = AdaptiveProjectedGuidance(
+            eta=ocr_eta,
+            adaptive_projected_guidance_rescale=ocr_norm_threshold,
+            adaptive_projected_guidance_momentum=ocr_momentum
+        )
+
+        m = model.clone()
+
+
+        model_sampling = m.model.model_sampling
+        general_start_t = model_sampling.percent_to_sigma(general_start_percent)
+        ocr_start_t = model_sampling.percent_to_sigma(ocr_start_percent)
+
+
+        def cfg_function(args):
+            sigma = args["sigma"].to(torch.float32)
+            is_first_step = math.isclose(sigma.item(), args['model_options']['transformer_options']['sample_sigmas'][0].item())
+            cond = args["cond"]
+            uncond = args["uncond"]
+            cond_scale = args["cond_scale"]
+
+            sigma = sigma[:, None, None, None]
+
+
+            if not has_quoted_text:
+                if sigma[0] <= general_start_t:
+                    modified_cond = general_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)
+                    return modified_cond * sigma
+                else:
+                    if cond_scale > 1:
+                        _ = general_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)  # track momentum
+                        return uncond + (cond - uncond) * cond_scale
+            else:
+                if sigma[0] <= ocr_start_t:
+                    modified_cond = ocr_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)
+                    return modified_cond * sigma
+                else:
+                    if cond_scale > 1:
+                        _ = ocr_apg(cond / sigma, uncond / sigma, is_first_step=is_first_step)  # track momentum
+                        return uncond + (cond - uncond) * cond_scale
+
+            return cond
+
+        m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
+        return (m,)
+class CLIPTextEncodeHunyuanDiTWithTextDetection:
+
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "text": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+        }}
+
+    RETURN_TYPES = ("CONDITIONING", "BOOLEAN", "STRING")
+    RETURN_NAMES = ("conditioning", "has_quoted_text", "text")
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning/hunyuan"
+
+    def detect_quoted_text(self, text):
+        """Detect quoted text in the prompt"""
+        text_prompt_texts = []
+
+        # Patterns to match different quote styles
+        pattern_quote_double = r'\"(.*?)\"'
+        pattern_quote_chinese_single = r'‘(.*?)’'
+        pattern_quote_chinese_double = r'“(.*?)”'
+
+        matches_quote_double = re.findall(pattern_quote_double, text)
+        matches_quote_chinese_single = re.findall(pattern_quote_chinese_single, text)
+        matches_quote_chinese_double = re.findall(pattern_quote_chinese_double, text)
+
+        text_prompt_texts.extend(matches_quote_double)
+        text_prompt_texts.extend(matches_quote_chinese_single)
+        text_prompt_texts.extend(matches_quote_chinese_double)
+
+        return len(text_prompt_texts) > 0
+
+    def encode(self, clip, text):
+        tokens = clip.tokenize(text)
+        has_quoted_text = self.detect_quoted_text(text)
+
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+
+        c = []
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            n[1]['has_quoted_text'] = has_quoted_text
+            c.append(n)
+
+        return (c, has_quoted_text, text)
+
+
+class CLIPTextEncodeHunyuanImageRefiner:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {"required": {
+            "clip": ("CLIP", ),
+            "text": ("STRING", ),
+        }}
+    RETURN_TYPES = ("CONDITIONING",)
+    RETURN_NAMES = ("conditioning",)
+    FUNCTION = "encode"
+
+    CATEGORY = "advanced/conditioning/hunyuan"
+
+
+    def encode(self, clip, text):
+        tokens = clip.tokenize(text)
+
+        conditioning = clip.encode_from_tokens_scheduled(tokens)
+
+        c = []
+        for t in conditioning:
+            n = [t[0], t[1].copy()]
+            c.append(n)
+
+        return (c, )
+
 class EmptyHunyuanLatentVideo:
     @classmethod
     def INPUT_TYPES(s):
@@ -151,8 +397,16 @@ def execute(self, positive, negative, latent, noise_augmentation):
         return (positive, negative, out_latent)
 
 
+
+NODE_DISPLAY_NAME_MAPPINGS = {
+    "HunyuanMixModeAPG": "Hunyuan Mix Mode APG",
+}
+
 NODE_CLASS_MAPPINGS = {
+    "HunyuanMixModeAPG": HunyuanMixModeAPG,
     "CLIPTextEncodeHunyuanDiT": CLIPTextEncodeHunyuanDiT,
+    "CLIPTextEncodeHunyuanDiTWithTextDetection": CLIPTextEncodeHunyuanDiTWithTextDetection,
+    "CLIPTextEncodeHunyuanImageRefiner": CLIPTextEncodeHunyuanImageRefiner,
     "TextEncodeHunyuanVideo_ImageToVideo": TextEncodeHunyuanVideo_ImageToVideo,
     "EmptyHunyuanLatentVideo": EmptyHunyuanLatentVideo,
     "HunyuanImageToVideo": HunyuanImageToVideo,
diff --git a/nodes.py b/nodes.py
index 5a5fdcb8ee4e..3efb2f411ef0 100644
--- a/nodes.py
+++ b/nodes.py
@@ -929,7 +929,7 @@ class CLIPLoader:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ),
-                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image"], ),
+                              "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart", "cosmos", "lumina2", "wan", "hidream", "chroma", "ace", "omnigen2", "qwen_image", "hunyuan_image", "hunyuan_image_refiner"], ),
                               },
                 "optional": {
                               "device": (["default", "cpu"], {"advanced": True}),
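
Reviewer note, not part of the patch: the least self-explanatory piece of the diff is the projection step inside `normalized_guidance_apg`. The sketch below reproduces just that step on dummy tensors to show what the parallel/orthogonal split does; the tensor shapes, seed, and variable names are illustrative assumptions, while `guidance_scale=10.0` and `eta=0.0` mirror the defaults used by `HunyuanMixModeAPG` and `AdaptiveProjectedGuidance` above.

```python
# Minimal standalone sketch of the APG projection step (illustrative only).
import torch

torch.manual_seed(0)
pred_cond = torch.randn(2, 4, 8, 8)    # stand-in for the conditional prediction
pred_uncond = torch.randn(2, 4, 8, 8)  # stand-in for the unconditional prediction
guidance_scale, eta = 10.0, 0.0        # mirrors the node defaults in the patch

diff = pred_cond - pred_uncond
dim = [-i for i in range(1, len(diff.shape))]  # reduce over all non-batch dims

# Split the guidance difference into components parallel and orthogonal to pred_cond.
v1 = torch.nn.functional.normalize(pred_cond.double(), dim=dim)
v0 = diff.double()
parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
orthogonal = v0 - parallel

# Sanity check: the orthogonal component has no projection onto pred_cond.
residual = (orthogonal * v1).sum(dim=dim)
assert torch.allclose(residual, torch.zeros_like(residual), atol=1e-6)

# With eta=0 only the orthogonal part is kept; the APG idea is that the parallel
# part is what drives over-saturation at high guidance scales.
update = (orthogonal + eta * parallel).type_as(diff)
pred = pred_uncond + guidance_scale * update
print(pred.shape)  # torch.Size([2, 4, 8, 8])
```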