Commit f3c27d6

Merge branch 'master' into v3-improvements

2 parents dd7c045 + 683569d

31 files changed: +989 −753 lines

comfy/cli_args.py
Lines changed: 7 additions & 0 deletions

@@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
     Latent2RGB = "latent2rgb"
     TAESD = "taesd"
 
+    @classmethod
+    def from_string(cls, value: str):
+        for member in cls:
+            if member.value == value:
+                return member
+        return None
+
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)
 
 parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")

comfy/context_windows.py
Lines changed: 13 additions & 0 deletions

@@ -87,6 +87,7 @@ class IndexListCallbacks:
     COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
     EXECUTE_START = "execute_start"
     EXECUTE_CLEANUP = "execute_cleanup"
+    RESIZE_COND_ITEM = "resize_cond_item"
 
     def init_callbacks(self):
         return {}
@@ -166,6 +167,18 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
             new_cond_item = cond_item.copy()
             # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
             for cond_key, cond_value in new_cond_item.items():
+                # Allow callbacks to handle custom conditioning items
+                handled = False
+                for callback in comfy.patcher_extension.get_all_callbacks(
+                    IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
+                ):
+                    result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
+                    if result is not None:
+                        new_cond_item[cond_key] = result
+                        handled = True
+                        break
+                if handled:
+                    continue
                 if isinstance(cond_value, torch.Tensor):
                     if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
                             (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
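
The new hook lets custom conditioning entries opt out of the default tensor slicing. A minimal sketch of a callback matching the call site above; the "my_custom_cond" key is hypothetical, registration into the handler's callbacks dict is elided, and window.index_list is assumed to hold the window's frame indices:

    def resize_my_cond(cond_key, cond_value, window, x_in, device, new_cond_item):
        if cond_key != "my_custom_cond":  # hypothetical key
            return None  # None means "not handled"; the default tensor logic runs
        # Slice the custom tensor down to this context window's frames.
        return cond_value[:, window.index_list].to(device)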

comfy/ldm/lumina/model.py
Lines changed: 3 additions & 0 deletions

@@ -634,8 +634,11 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, trans
         img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
         freqs_cis = freqs_cis.to(img.device)
 
+        transformer_options["total_blocks"] = len(self.layers)
+        transformer_options["block_type"] = "double"
         img_input = img
         for i, layer in enumerate(self.layers):
+            transformer_options["block_index"] = i
             img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
             if "double_block" in patches:
                 for p in patches["double_block"]:

comfy/ldm/qwen_image/model.py
Lines changed: 39 additions & 8 deletions

@@ -218,9 +218,24 @@ def __init__(
             operations=operations,
         )
 
-    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _apply_gate(self, x, y, gate, timestep_zero_index=None):
+        if timestep_zero_index is not None:
+            return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
+        else:
+            return torch.addcmul(y, gate, x)
+
+    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
         shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
-        return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
+        if timestep_zero_index is not None:
+            actual_batch = shift.size(0) // 2
+            shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
+            scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
+            gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
+            reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
+            zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
+            return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
+        else:
+            return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
 
     def forward(
         self,
@@ -229,14 +244,19 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        timestep_zero_index=None,
         transformer_options={},
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_params = self.img_mod(temb)
+
+        if timestep_zero_index is not None:
+            temb = temb.chunk(2, dim=0)[0]
+
         txt_mod_params = self.txt_mod(temb)
         img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
 
-        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
+        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
         del img_mod1
         txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
         del txt_mod1
@@ -251,15 +271,15 @@ def forward(
         del img_modulated
         del txt_modulated
 
-        hidden_states = hidden_states + img_gate1 * img_attn_output
+        hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
         encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
         del img_attn_output
         del txt_attn_output
         del img_gate1
         del txt_gate1
 
-        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
-        hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
+        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
+        hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
 
         txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
         encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@@ -391,11 +411,14 @@ def _forward(
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]
 
+        timestep_zero_index = None
         if ref_latents is not None:
             h = 0
             w = 0
             index = 0
-            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
+            ref_method = kwargs.get("ref_latents_method", "index")
+            index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
+            timestep_zero = ref_method == "index_timestep_zero"
             for ref in ref_latents:
                 if index_ref_method:
                     index += 1
@@ -415,6 +438,10 @@
                 kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
                 hidden_states = torch.cat([hidden_states, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+            if timestep_zero:
+                if index > 0:
+                    timestep = torch.cat([timestep, timestep * 0], dim=0)
+                    timestep_zero_index = num_embeds
 
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
@@ -446,7 +473,7 @@
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
+                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
                     return out
                 out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 hidden_states = out["img"]
@@ -458,6 +485,7 @@ def block_wrap(args):
                     encoder_hidden_states_mask=encoder_hidden_states_mask,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    timestep_zero_index=timestep_zero_index,
                     transformer_options=transformer_options,
                 )
 
@@ -474,6 +502,9 @@ def block_wrap(args):
         if add is not None:
             hidden_states[:, :add.shape[1]] += add
 
+        if timestep_zero_index is not None:
+            temb = temb.chunk(2, dim=0)[0]
+
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
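
The "index_timestep_zero" path doubles the timestep batch (real timesteps concatenated with zeros), so temb carries two modulation parameter sets: _modulate applies the regular set to the tokens before timestep_zero_index and the zero-timestep set to the reference tokens after it, while _apply_gate gates the two ranges separately. A standalone shape sketch of that split, using toy sizes rather than the model's:

    import torch

    B, N, D = 2, 10, 8
    tzi = 6                          # timestep_zero_index: first 6 tokens are "real"
    x = torch.randn(B, N, D)
    mod = torch.randn(2 * B, 3 * D)  # doubled batch: rows 0..B-1 real t, rows B.. are t=0
    shift, scale, gate = torch.chunk(mod, 3, dim=-1)
    shift, shift_0 = shift[:B], shift[B:]
    scale, scale_0 = scale[:B], scale[B:]
    reg = torch.addcmul(shift.unsqueeze(1), x[:, :tzi], 1 + scale.unsqueeze(1))
    zero = torch.addcmul(shift_0.unsqueeze(1), x[:, tzi:], 1 + scale_0.unsqueeze(1))
    out = torch.cat((reg, zero), dim=1)
    assert out.shape == (B, N, D)    # same tokens, two modulation regimes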

comfy/ldm/wan/model.py
Lines changed: 18 additions & 3 deletions

@@ -568,7 +568,10 @@ def forward_orig(
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -763,7 +766,10 @@ def forward_orig(
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -862,7 +868,10 @@ def forward_orig(
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -1326,16 +1335,19 @@ def forward_orig(
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"])
+                    out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], transformer_options=args["transformer_options"])
                     return out
-                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs}, {"original_block": block_wrap})
+                out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 x = out["img"]
             else:
-                x = block(x, e=e0, freqs=freqs, context=context)
+                x = block(x, e=e0, freqs=freqs, context=context, transformer_options=transformer_options)
             if audio_emb is not None:
                 x = self.audio_injector(x, i, audio_emb, audio_emb_global, seq_len)
         # head
@@ -1574,7 +1586,10 @@ def forward_orig(
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
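
Because the replacement call sites now pass transformer_options through args, a "double_block" replacement can key off its position in the stack. A sketch following the argument names at the call sites above (model_animate.py below gains the same keys); the tap itself is hypothetical:

    def final_block_tap(args, extra):
        opts = args["transformer_options"]
        out = extra["original_block"](args)
        if opts["block_index"] == opts["total_blocks"] - 1:
            pass  # e.g. stash out["img"] from the last double block here
        return out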

comfy/ldm/wan/model_animate.py
Lines changed: 3 additions & 0 deletions

@@ -523,7 +523,10 @@ def forward_orig(
 
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}

comfy/supported_models.py
Lines changed: 8 additions & 1 deletion

@@ -28,6 +28,7 @@
 from . import latent_formats
 
 from . import diffusers_convert
+import comfy.model_management
 
 class SD15(supported_models_base.BASE):
     unet_config = {
@@ -1028,7 +1029,13 @@ class ZImage(Lumina2):
 
     memory_usage_factor = 2.0
 
-    supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
+    def __init__(self, unet_config):
+        super().__init__(unet_config)
+        if comfy.model_management.extended_fp16_support():
+            self.supported_inference_dtypes = self.supported_inference_dtypes.copy()
+            self.supported_inference_dtypes.insert(1, torch.float16)
 
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
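
ZImage.__init__ copies the class-level dtype list before inserting torch.float16, so enabling fp16 on one instance does not mutate the default shared across instances. The resulting behavior, sketched with the same gate:

    import torch
    import comfy.model_management

    dtypes = [torch.bfloat16, torch.float32]
    if comfy.model_management.extended_fp16_support():
        dtypes.insert(1, torch.float16)  # -> [bfloat16, float16, float32]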

comfy/utils.py
Lines changed: 1 addition & 1 deletion

@@ -53,7 +53,7 @@ def scalar(*args, **kwargs):
     ALWAYS_SAFE_LOAD = True
     logging.info("Checkpoint files will always be loaded safely.")
 else:
-    logging.info("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended.")
+    logging.warning("Warning, you are using an old pytorch version and some ckpt/pt files might be loaded unsafely. Upgrading to 2.4 or above is recommended as older versions of pytorch are no longer supported.")
 
 def load_torch_file(ckpt, safe_load=False, device=None, return_metadata=False):
     if device is None:

comfy_api/feature_flags.py
Lines changed: 5 additions & 5 deletions

@@ -5,20 +5,20 @@
 allowing graceful protocol evolution while maintaining backward compatibility.
 """
 
-from typing import Any, Dict
+from typing import Any
 
 from comfy.cli_args import args
 
 # Default server capabilities
-SERVER_FEATURE_FLAGS: Dict[str, Any] = {
+SERVER_FEATURE_FLAGS: dict[str, Any] = {
     "supports_preview_metadata": True,
     "max_upload_size": args.max_upload_size * 1024 * 1024,  # Convert MB to bytes
     "extension": {"manager": {"supports_v4": True}},
 }
 
 
 def get_connection_feature(
-    sockets_metadata: Dict[str, Dict[str, Any]],
+    sockets_metadata: dict[str, dict[str, Any]],
     sid: str,
     feature_name: str,
     default: Any = False
@@ -42,7 +42,7 @@ def get_connection_feature(
 
 
 def supports_feature(
-    sockets_metadata: Dict[str, Dict[str, Any]],
+    sockets_metadata: dict[str, dict[str, Any]],
     sid: str,
     feature_name: str
 ) -> bool:
@@ -60,7 +60,7 @@ def supports_feature(
     return get_connection_feature(sockets_metadata, sid, feature_name, False) is True
 
 
-def get_server_features() -> Dict[str, Any]:
+def get_server_features() -> dict[str, Any]:
     """
     Get the server's feature flags.
 
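This file and api_registry.py below make the same modernization: typing.Dict, typing.List, and typing.Type give way to the PEP 585 builtin generics, which work directly in annotations on Python 3.9 and later, e.g.:

    flags: dict[str, bool] = {"supports_preview_metadata": True}  # no typing import needed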

comfy_api/internal/api_registry.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Type, List, NamedTuple
1+
from typing import NamedTuple
22
from comfy_api.internal.singleton import ProxiedSingleton
33
from packaging import version as packaging_version
44

@@ -10,7 +10,7 @@ def __init__(self):
1010

1111
class ComfyAPIWithVersion(NamedTuple):
1212
version: str
13-
api_class: Type[ComfyAPIBase]
13+
api_class: type[ComfyAPIBase]
1414

1515

1616
def parse_version(version_str: str) -> packaging_version.Version:
@@ -23,16 +23,16 @@ def parse_version(version_str: str) -> packaging_version.Version:
2323
return packaging_version.parse(version_str)
2424

2525

26-
registered_versions: List[ComfyAPIWithVersion] = []
26+
registered_versions: list[ComfyAPIWithVersion] = []
2727

2828

29-
def register_versions(versions: List[ComfyAPIWithVersion]):
29+
def register_versions(versions: list[ComfyAPIWithVersion]):
3030
versions.sort(key=lambda x: parse_version(x.version))
3131
global registered_versions
3232
registered_versions = versions
3333

3434

35-
def get_all_versions() -> List[ComfyAPIWithVersion]:
35+
def get_all_versions() -> list[ComfyAPIWithVersion]:
3636
"""
3737
Returns a list of all registered ComfyAPI versions.
3838
"""
