Skip to content

Commit a6fed84

Browse files
committed
Merge branch 'master' into assets-redo
2 parents 604b00c + 6592bff commit a6fed84

File tree

23 files changed

+963
-99
lines changed

23 files changed

+963
-99
lines changed

.ci/update_windows/update.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,16 @@ def pull(repo, remote_name='origin', branch='master'):
5353
repo.stash(ident)
5454
except KeyError:
5555
print("nothing to stash") # noqa: T201
56+
except:
57+
print("Could not stash, cleaning index and trying again.") # noqa: T201
58+
repo.state_cleanup()
59+
repo.index.read_tree(repo.head.peel().tree)
60+
repo.index.write()
61+
try:
62+
repo.stash(ident)
63+
except KeyError:
64+
print("nothing to stash.") # noqa: T201
65+
5666
backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S'))
5767
print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201
5868
try:

api_server/routes/internal/internal_routes.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,13 @@ async def get_files(request: web.Request) -> web.Response:
5858
return web.json_response({"error": "Invalid directory type"}, status=400)
5959

6060
directory = get_directory_by_type(directory_type)
61+
62+
def is_visible_file(entry: os.DirEntry) -> bool:
63+
"""Filter out hidden files (e.g., .DS_Store on macOS)."""
64+
return entry.is_file() and not entry.name.startswith('.')
65+
6166
sorted_files = sorted(
62-
(entry for entry in os.scandir(directory) if entry.is_file()),
67+
(entry for entry in os.scandir(directory) if is_visible_file(entry)),
6368
key=lambda entry: -entry.stat().st_mtime
6469
)
6570
return web.json_response([entry.name for entry in sorted_files], status=200)

comfy/k_diffusion/sampling.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1557,10 +1557,13 @@ def default_er_sde_noise_scaler(x):
15571557

15581558

15591559
@torch.no_grad()
1560-
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5):
1560+
def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r=0.5, solver_type="phi_1"):
15611561
"""SEEDS-2 - Stochastic Explicit Exponential Derivative-free Solvers (VP Data Prediction) stage 2.
15621562
arXiv: https://arxiv.org/abs/2305.14267 (NeurIPS 2023)
15631563
"""
1564+
if solver_type not in {"phi_1", "phi_2"}:
1565+
raise ValueError("solver_type must be 'phi_1' or 'phi_2'")
1566+
15641567
extra_args = {} if extra_args is None else extra_args
15651568
seed = extra_args.get("seed", None)
15661569
noise_sampler = default_noise_sampler(x, seed=seed) if noise_sampler is None else noise_sampler
@@ -1600,8 +1603,14 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
16001603
denoised_2 = model(x_2, sigma_s_1 * s_in, **extra_args)
16011604

16021605
# Step 2
1603-
denoised_d = torch.lerp(denoised, denoised_2, fac)
1604-
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
1606+
if solver_type == "phi_1":
1607+
denoised_d = torch.lerp(denoised, denoised_2, fac)
1608+
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * ei_h_phi_1(-h_eta) * denoised_d
1609+
elif solver_type == "phi_2":
1610+
b2 = ei_h_phi_2(-h_eta) / r
1611+
b1 = ei_h_phi_1(-h_eta) - b2
1612+
x = sigmas[i + 1] / sigmas[i] * (-h * eta).exp() * x - alpha_t * (b1 * denoised + b2 * denoised_2)
1613+
16051614
if inject_noise:
16061615
segment_factor = (r - 1) * h * eta
16071616
sde_noise = sde_noise * segment_factor.exp()

comfy/ldm/hunyuan_video/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ class HunyuanVideoParams:
4343
meanflow: bool
4444
use_cond_type_embedding: bool
4545
vision_in_dim: int
46+
meanflow_sum: bool
4647

4748

4849
class SelfAttentionRef(nn.Module):
@@ -317,7 +318,7 @@ def forward_orig(
317318
timesteps_r = transformer_options['sample_sigmas'][w[0] + 1]
318319
timesteps_r = timesteps_r.unsqueeze(0).to(device=timesteps.device, dtype=timesteps.dtype)
319320
vec_r = self.time_r_in(timestep_embedding(timesteps_r, 256, time_factor=1000.0).to(img.dtype))
320-
vec = (vec + vec_r) / 2
321+
vec = (vec + vec_r) if self.params.meanflow_sum else (vec + vec_r) / 2
321322

322323
if ref_latent is not None:
323324
ref_latent_ids = self.img_ids(ref_latent)

comfy/ldm/lumina/controlnet.py

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ def __init__(
4141
ffn_dim_multiplier: float = (8.0 / 3.0),
4242
norm_eps: float = 1e-5,
4343
qk_norm: bool = True,
44+
n_control_layers=6,
45+
control_in_dim=16,
46+
additional_in_dim=0,
47+
broken=False,
48+
refiner_control=False,
4449
dtype=None,
4550
device=None,
4651
operations=None,
@@ -49,10 +54,11 @@ def __init__(
4954
super().__init__()
5055
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
5156

52-
self.additional_in_dim = 0
53-
self.control_in_dim = 16
57+
self.broken = broken
58+
self.additional_in_dim = additional_in_dim
59+
self.control_in_dim = control_in_dim
5460
n_refiner_layers = 2
55-
self.n_control_layers = 6
61+
self.n_control_layers = n_control_layers
5662
self.control_layers = nn.ModuleList(
5763
[
5864
ZImageControlTransformerBlock(
@@ -74,28 +80,49 @@ def __init__(
7480
all_x_embedder = {}
7581
patch_size = 2
7682
f_patch_size = 1
77-
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
83+
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
7884
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
7985

86+
self.refiner_control = refiner_control
87+
8088
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
81-
self.control_noise_refiner = nn.ModuleList(
82-
[
83-
JointTransformerBlock(
84-
layer_id,
85-
dim,
86-
n_heads,
87-
n_kv_heads,
88-
multiple_of,
89-
ffn_dim_multiplier,
90-
norm_eps,
91-
qk_norm,
92-
modulation=True,
93-
z_image_modulation=True,
94-
operation_settings=operation_settings,
95-
)
96-
for layer_id in range(n_refiner_layers)
97-
]
98-
)
89+
if self.refiner_control:
90+
self.control_noise_refiner = nn.ModuleList(
91+
[
92+
ZImageControlTransformerBlock(
93+
layer_id,
94+
dim,
95+
n_heads,
96+
n_kv_heads,
97+
multiple_of,
98+
ffn_dim_multiplier,
99+
norm_eps,
100+
qk_norm,
101+
block_id=layer_id,
102+
operation_settings=operation_settings,
103+
)
104+
for layer_id in range(n_refiner_layers)
105+
]
106+
)
107+
else:
108+
self.control_noise_refiner = nn.ModuleList(
109+
[
110+
JointTransformerBlock(
111+
layer_id,
112+
dim,
113+
n_heads,
114+
n_kv_heads,
115+
multiple_of,
116+
ffn_dim_multiplier,
117+
norm_eps,
118+
qk_norm,
119+
modulation=True,
120+
z_image_modulation=True,
121+
operation_settings=operation_settings,
122+
)
123+
for layer_id in range(n_refiner_layers)
124+
]
125+
)
99126

100127
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
101128
patch_size = 2
@@ -105,9 +132,29 @@ def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
105132
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
106133

107134
x_attn_mask = None
108-
for layer in self.control_noise_refiner:
109-
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
135+
if not self.refiner_control:
136+
for layer in self.control_noise_refiner:
137+
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
138+
110139
return control_context
111140

141+
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
142+
if self.refiner_control:
143+
if self.broken:
144+
if layer_id == 0:
145+
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
146+
if layer_id > 0:
147+
out = None
148+
for i in range(1, len(self.control_layers)):
149+
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
150+
if out is None:
151+
out = o
152+
153+
return (out, control_context)
154+
else:
155+
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
156+
else:
157+
return (None, control_context)
158+
112159
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
113160
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)

comfy/ldm/lumina/model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,7 @@ def patchify_and_embed(
536536
bsz = len(x)
537537
pH = pW = self.patch_size
538538
device = x[0].device
539+
orig_x = x
539540

540541
if self.pad_tokens_multiple is not None:
541542
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
@@ -572,13 +573,21 @@ def patchify_and_embed(
572573

573574
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
574575

576+
patches = transformer_options.get("patches", {})
577+
575578
# refine context
576579
for layer in self.context_refiner:
577580
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
578581

579582
padded_img_mask = None
580-
for layer in self.noise_refiner:
583+
x_input = x
584+
for i, layer in enumerate(self.noise_refiner):
581585
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
586+
if "noise_refiner" in patches:
587+
for p in patches["noise_refiner"]:
588+
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
589+
if "img" in out:
590+
x = out["img"]
582591

583592
padded_full_embed = torch.cat((cap_feats, x), dim=1)
584593
mask = None
@@ -622,14 +631,15 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, trans
622631

623632
patches = transformer_options.get("patches", {})
624633
x_is_tensor = isinstance(x, torch.Tensor)
625-
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
634+
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
626635
freqs_cis = freqs_cis.to(img.device)
627636

637+
img_input = img
628638
for i, layer in enumerate(self.layers):
629639
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
630640
if "double_block" in patches:
631641
for p in patches["double_block"]:
632-
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
642+
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
633643
if "img" in out:
634644
img[:, cap_size[0]:] = out["img"]
635645
if "txt" in out:

comfy/model_detection.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
180180
dit_config["use_cond_type_embedding"] = False
181181
if '{}vision_in.proj.0.weight'.format(key_prefix) in state_dict_keys:
182182
dit_config["vision_in_dim"] = state_dict['{}vision_in.proj.0.weight'.format(key_prefix)].shape[0]
183+
dit_config["meanflow_sum"] = True
183184
else:
184185
dit_config["vision_in_dim"] = None
186+
dit_config["meanflow_sum"] = False
185187
return dit_config
186188

187189
if '{}double_blocks.0.img_attn.norm.key_norm.scale'.format(key_prefix) in state_dict_keys and ('{}img_in.weight'.format(key_prefix) in state_dict_keys or f"{key_prefix}distilled_guidance_layer.norms.0.scale" in state_dict_keys): #Flux, Chroma or Chroma Radiance (has no img_in.weight)
@@ -257,8 +259,10 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
257259
dit_config["nerf_tile_size"] = 512
258260
dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
259261
dit_config["nerf_embedder_dtype"] = torch.float32
260-
if "__x0__" in state_dict_keys: # x0 pred
261-
dit_config["use_x0"] = True
262+
if "__x0__" in state_dict_keys: # x0 pred
263+
dit_config["use_x0"] = True
264+
else:
265+
dit_config["use_x0"] = False
262266
else:
263267
dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
264268
dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys

comfy/model_patcher.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,9 @@ def set_model_double_block_patch(self, patch):
454454
def set_model_post_input_patch(self, patch):
455455
self.set_model_patch(patch, "post_input")
456456

457+
def set_model_noise_refiner_patch(self, patch):
458+
self.set_model_patch(patch, "noise_refiner")
459+
457460
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
458461
rope_options = self.model_options["transformer_options"].get("rope_options", {})
459462
rope_options["scale_x"] = scale_x

comfy/ops.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -497,15 +497,14 @@ def __init__(
497497
) -> None:
498498
super().__init__()
499499

500-
self.factory_kwargs = {"device": device, "dtype": MixedPrecisionOps._compute_dtype}
501-
# self.factory_kwargs = {"device": device, "dtype": dtype}
500+
if dtype is None:
501+
dtype = MixedPrecisionOps._compute_dtype
502+
503+
self.factory_kwargs = {"device": device, "dtype": dtype}
502504

503505
self.in_features = in_features
504506
self.out_features = out_features
505-
if bias:
506-
self.bias = torch.nn.Parameter(torch.empty(out_features, **self.factory_kwargs))
507-
else:
508-
self.register_parameter("bias", None)
507+
self._has_bias = bias
509508

510509
self.tensor_class = None
511510
self._full_precision_mm = MixedPrecisionOps._full_precision_mm
@@ -530,7 +529,14 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
530529
layer_conf = json.loads(layer_conf.numpy().tobytes())
531530

532531
if layer_conf is None:
533-
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False)
532+
dtype = self.factory_kwargs["dtype"]
533+
self.weight = torch.nn.Parameter(weight.to(device=device, dtype=dtype), requires_grad=False)
534+
if dtype != MixedPrecisionOps._compute_dtype:
535+
self.comfy_cast_weights = True
536+
if self._has_bias:
537+
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=dtype))
538+
else:
539+
self.register_parameter("bias", None)
534540
else:
535541
self.quant_format = layer_conf.get("format", None)
536542
if not self._full_precision_mm:
@@ -560,6 +566,11 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata,
560566
requires_grad=False
561567
)
562568

569+
if self._has_bias:
570+
self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device, dtype=MixedPrecisionOps._compute_dtype))
571+
else:
572+
self.register_parameter("bias", None)
573+
563574
for param_name in qconfig["parameters"]:
564575
param_key = f"{prefix}{param_name}"
565576
_v = state_dict.pop(param_key, None)
@@ -581,7 +592,7 @@ def state_dict(self, *args, destination=None, prefix="", **kwargs):
581592
quant_conf = {"format": self.quant_format}
582593
if self._full_precision_mm:
583594
quant_conf["full_precision_matrix_mult"] = True
584-
sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8)
595+
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
585596
return sd
586597

587598
def _forward(self, input, weight, bias):

comfy/quant_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,10 @@ def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_roun
399399
orig_dtype = tensor.dtype
400400

401401
if isinstance(scale, str) and scale == "recalculate":
402-
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
402+
scale = torch.amax(tensor.abs()).to(dtype=torch.float32) / torch.finfo(dtype).max
403+
if tensor.dtype not in [torch.float32, torch.bfloat16]: # Prevent scale from being too small
404+
tensor_info = torch.finfo(tensor.dtype)
405+
scale = (1.0 / torch.clamp((1.0 / scale), min=tensor_info.min, max=tensor_info.max))
403406

404407
if scale is not None:
405408
if not isinstance(scale, torch.Tensor):

0 commit comments

Comments
 (0)