Skip to content

Commit fe7899d

Browse files
committed
Add OUTER_SAMPLE wrapper
1 parent 9cb2eda commit fe7899d

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

comfy_extras/nodes_hunyuan.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
from numpy import arccos
21
import nodes
32
import node_helpers
43
import torch
54
import re
65
import comfy.model_management
6+
import comfy.patcher_extension
77

88

99
class CLIPTextEncodeHunyuanDiT:
@@ -137,10 +137,6 @@ def INPUT_TYPES(s):
137137
CATEGORY = "sampling/custom_sampling/hunyuan"
138138

139139

140-
@classmethod
141-
def IS_CHANGED(cls, model):
142-
return True
143-
144140
def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_eta, general_norm_threshold, general_momentum, general_start_step,
145141
ocr_eta, ocr_norm_threshold, ocr_momentum, ocr_start_step):
146142

@@ -157,16 +153,22 @@ def apply_mix_mode_apg(self, model, has_quoted_text, guidance_scale, general_et
157153
adaptive_projected_guidance_momentum=ocr_momentum
158154
)
159155

160-
current_step = {"step": 0}
156+
m = model.clone()
157+
step_tracker = {"step": 0}
158+
159+
def hunyuan_apg_outer_sample_wrapper(executor, *args, **kwargs):
160+
step_tracker['step'] = 0
161+
return executor(*args, **kwargs)
162+
161163

162164
def cfg_function(args):
163165
sigma = args["sigma"].to(torch.float32)
164166
cond = args["cond"]
165167
uncond = args["uncond"]
166168
cond_scale = args["cond_scale"]
167169

168-
step = current_step["step"]
169-
current_step["step"] += 1
170+
step = step_tracker['step']
171+
step_tracker['step'] += 1
170172

171173
if not has_quoted_text:
172174
if step >= general_start_step:
@@ -187,8 +189,7 @@ def cfg_function(args):
187189

188190
return cond
189191

190-
191-
m = model.clone()
192+
m.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "hunyuan_apg", hunyuan_apg_outer_sample_wrapper)
192193
m.set_model_sampler_cfg_function(cfg_function, disable_cfg1_optimization=True)
193194
return (m,)
194195

@@ -239,6 +240,7 @@ def encode(self, clip, text):
239240

240241
return (c, has_quoted_text, text)
241242

243+
242244
class CLIPTextEncodeHunyuanImageRefiner:
243245
@classmethod
244246
def INPUT_TYPES(cls):

0 commit comments

Comments
 (0)