
Commit d3fd217

Merge branch 'master' into assets-redo
2 parents: 07e85ce + 6a2678a


58 files changed: +3089 / -1241 lines

.github/workflows/test-ci.yml

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@ on:
   push:
     branches:
       - master
+      - release/**
     paths-ignore:
       - 'app/**'
       - 'input/**'

.github/workflows/test-execution.yml

Lines changed: 2 additions & 2 deletions
@@ -2,9 +2,9 @@ name: Execution Tests

 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
   pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]

 jobs:
   test:

.github/workflows/test-launch.yml

Lines changed: 2 additions & 2 deletions
@@ -2,9 +2,9 @@ name: Test server launches without errors

 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
   pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]

 jobs:
   test:

.github/workflows/test-unit.yml

Lines changed: 2 additions & 2 deletions
@@ -2,9 +2,9 @@ name: Unit Tests

 on:
   push:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]
   pull_request:
-    branches: [ main, master ]
+    branches: [ main, master, release/** ]

 jobs:
   test:

.github/workflows/update-version.yml

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@ on:
       - "pyproject.toml"
     branches:
       - master
+      - release/**

 jobs:
   update-version:

comfy/cli_args.py

Lines changed: 7 additions & 0 deletions
@@ -97,6 +97,13 @@ class LatentPreviewMethod(enum.Enum):
     Latent2RGB = "latent2rgb"
     TAESD = "taesd"

+    @classmethod
+    def from_string(cls, value: str):
+        for member in cls:
+            if member.value == value:
+                return member
+        return None
+
 parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.NoPreviews, help="Default preview method for sampler nodes.", action=EnumAction)

 parser.add_argument("--preview-size", type=int, default=512, help="Sets the maximum preview size for sampler nodes.")
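The new from_string helper resolves a raw preview-method string to its enum member and returns None instead of raising on unknown values. A minimal usage sketch, assuming ComfyUI is importable; the fallback to NoPreviews is the caller's choice, not part of this commit:

from comfy.cli_args import LatentPreviewMethod

method = LatentPreviewMethod.from_string("taesd")   # -> LatentPreviewMethod.TAESD
if method is None:                                   # unknown strings yield None rather than an exception
    method = LatentPreviewMethod.NoPreviews          # illustrative fallback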

comfy/context_windows.py

Lines changed: 13 additions & 0 deletions
@@ -87,6 +87,7 @@ class IndexListCallbacks:
     COMBINE_CONTEXT_WINDOW_RESULTS = "combine_context_window_results"
     EXECUTE_START = "execute_start"
     EXECUTE_CLEANUP = "execute_cleanup"
+    RESIZE_COND_ITEM = "resize_cond_item"

     def init_callbacks(self):
         return {}
@@ -166,6 +167,18 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
             new_cond_item = cond_item.copy()
             # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
             for cond_key, cond_value in new_cond_item.items():
+                # Allow callbacks to handle custom conditioning items
+                handled = False
+                for callback in comfy.patcher_extension.get_all_callbacks(
+                    IndexListCallbacks.RESIZE_COND_ITEM, self.callbacks
+                ):
+                    result = callback(cond_key, cond_value, window, x_in, device, new_cond_item)
+                    if result is not None:
+                        new_cond_item[cond_key] = result
+                        handled = True
+                        break
+                if handled:
+                    continue
                 if isinstance(cond_value, torch.Tensor):
                     if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
                             (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
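The new RESIZE_COND_ITEM hook lets extensions resize conditioning entries that the built-in tensor handling does not recognize: a callback that returns a non-None value claims the item, otherwise the loop falls through to the default logic. A sketch of a callback matching the call made above, assuming it has been registered through comfy.patcher_extension so it ends up in the handler's callbacks; the "my_custom_cond" key is hypothetical:

def resize_my_custom_cond(cond_key, cond_value, window, x_in, device, new_cond_item):
    # Only claim the hypothetical custom entry; returning None leaves everything else
    # to the default tensor handling in get_resized_cond.
    if cond_key != "my_custom_cond":
        return None
    # Illustrative resize: pass the value through unchanged (a real callback might
    # slice it to the indices covered by `window`).
    return cond_value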

comfy/k_diffusion/sampling.py

Lines changed: 13 additions & 2 deletions
@@ -1618,6 +1618,17 @@ def sample_seeds_2(model, x, sigmas, extra_args=None, callback=None, disable=Non
         x = x + sde_noise * sigmas[i + 1] * s_noise
     return x

+@torch.no_grad()
+def sample_exp_heun_2_x0(model, x, sigmas, extra_args=None, callback=None, disable=None, solver_type="phi_2"):
+    """Deterministic exponential Heun second order method in data prediction (x0) and logSNR time."""
+    return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=0.0, s_noise=0.0, noise_sampler=None, r=1.0, solver_type=solver_type)
+
+
+@torch.no_grad()
+def sample_exp_heun_2_x0_sde(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, solver_type="phi_2"):
+    """Stochastic exponential Heun second order method in data prediction (x0) and logSNR time."""
+    return sample_seeds_2(model, x, sigmas, extra_args=extra_args, callback=callback, disable=disable, eta=eta, s_noise=s_noise, noise_sampler=noise_sampler, r=1.0, solver_type=solver_type)
+

 @torch.no_grad()
 def sample_seeds_3(model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1., s_noise=1., noise_sampler=None, r_1=1./3, r_2=2./3):
@@ -1765,7 +1776,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
         # Predictor
         if sigmas[i + 1] == 0:
             # Denoising step
-            x = denoised
+            x_pred = denoised
         else:
             tau_t = tau_func(sigmas[i + 1])
             curr_lambdas = lambdas[i - predictor_order_used + 1:i + 1]
@@ -1786,7 +1797,7 @@ def sample_sa_solver(model, x, sigmas, extra_args=None, callback=None, disable=F
             if tau_t > 0 and s_noise > 0:
                 noise = noise_sampler(sigmas[i], sigmas[i + 1]) * sigmas[i + 1] * (-2 * tau_t ** 2 * h).expm1().neg().sqrt() * s_noise
                 x_pred = x_pred + noise
-    return x
+    return x_pred


 @torch.no_grad()
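Both new samplers are thin wrappers around sample_seeds_2 with r=1.0, so the second stage evaluates at the step endpoint, giving an exponential Heun step in x0/logSNR time; the _x0 variant disables noise (eta = s_noise = 0) while the _sde variant keeps it. A minimal call sketch, assuming a k-diffusion style denoiser callable; the schedule and the dummy denoiser below are placeholders, not from this commit:

import torch
from comfy.k_diffusion import sampling

sigmas = torch.tensor([14.6, 7.0, 3.0, 1.0, 0.0])        # example schedule, purely illustrative
x = torch.randn(1, 4, 64, 64) * sigmas[0]                 # start from noise at the first sigma
denoiser = lambda x, sigma: torch.zeros_like(x)           # stand-in for a real denoiser(x, sigma) -> x0 estimate
out = sampling.sample_exp_heun_2_x0(denoiser, x, sigmas)  # deterministic variant (eta = s_noise = 0)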

comfy/ldm/lumina/model.py

Lines changed: 3 additions & 0 deletions
@@ -634,8 +634,11 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, trans
         img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
         freqs_cis = freqs_cis.to(img.device)

+        transformer_options["total_blocks"] = len(self.layers)
+        transformer_options["block_type"] = "double"
         img_input = img
         for i, layer in enumerate(self.layers):
+            transformer_options["block_index"] = i
             img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
             if "double_block" in patches:
                 for p in patches["double_block"]:
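The three new transformer_options keys expose where the model is in its layer stack, so per-block patches can act on specific blocks. A sketch of how patch code that receives this dict might consult them; the logging and last-block check are illustrative only:

def report_block_position(transformer_options):
    # Keys written by the Lumina forward pass above.
    i = transformer_options.get("block_index")
    n = transformer_options.get("total_blocks")
    if transformer_options.get("block_type") == "double" and i is not None:
        print(f"running double block {i + 1} of {n}")   # e.g. a patch could act only when i == n - 1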

comfy/ldm/qwen_image/model.py

Lines changed: 44 additions & 8 deletions
@@ -218,9 +218,24 @@ def __init__(
             operations=operations,
         )

-    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _apply_gate(self, x, y, gate, timestep_zero_index=None):
+        if timestep_zero_index is not None:
+            return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
+        else:
+            return torch.addcmul(y, gate, x)
+
+    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
         shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
-        return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
+        if timestep_zero_index is not None:
+            actual_batch = shift.size(0) // 2
+            shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
+            scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
+            gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
+            reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
+            zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
+            return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
+        else:
+            return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)

     def forward(
         self,
@@ -229,14 +244,19 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        timestep_zero_index=None,
         transformer_options={},
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_params = self.img_mod(temb)
+
+        if timestep_zero_index is not None:
+            temb = temb.chunk(2, dim=0)[0]
+
         txt_mod_params = self.txt_mod(temb)
         img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)

-        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
+        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
         del img_mod1
         txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
         del txt_mod1
@@ -251,15 +271,15 @@ def forward(
         del img_modulated
         del txt_modulated

-        hidden_states = hidden_states + img_gate1 * img_attn_output
+        hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
         encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
         del img_attn_output
         del txt_attn_output
         del img_gate1
         del txt_gate1

-        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
-        hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
+        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
+        hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)

         txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
         encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@@ -302,6 +322,7 @@ def __init__(
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
         axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
+        default_ref_method="index",
         image_model=None,
         final_layer=True,
         dtype=None,
@@ -314,6 +335,7 @@ def __init__(
         self.in_channels = in_channels
         self.out_channels = out_channels or in_channels
         self.inner_dim = num_attention_heads * attention_head_dim
+        self.default_ref_method = default_ref_method

         self.pe_embedder = EmbedND(dim=attention_head_dim, theta=10000, axes_dim=list(axes_dims_rope))

@@ -341,6 +363,9 @@ def __init__(
             for _ in range(num_layers)
         ])

+        if self.default_ref_method == "index_timestep_zero":
+            self.register_buffer("__index_timestep_zero__", torch.tensor([]))
+
         if final_layer:
             self.norm_out = LastLayer(self.inner_dim, self.inner_dim, dtype=dtype, device=device, operations=operations)
             self.proj_out = operations.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True, dtype=dtype, device=device)
@@ -391,11 +416,14 @@ def _forward(
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]

+        timestep_zero_index = None
         if ref_latents is not None:
             h = 0
             w = 0
             index = 0
-            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
+            ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
+            index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
+            timestep_zero = ref_method == "index_timestep_zero"
             for ref in ref_latents:
                 if index_ref_method:
                     index += 1
@@ -415,6 +443,10 @@ def _forward(
                 kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
                 hidden_states = torch.cat([hidden_states, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+            if timestep_zero:
+                if index > 0:
+                    timestep = torch.cat([timestep, timestep * 0], dim=0)
+                    timestep_zero_index = num_embeds

         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
@@ -446,7 +478,7 @@ def _forward(
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
+                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
                     return out
                 out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 hidden_states = out["img"]
@@ -458,6 +490,7 @@ def block_wrap(args):
                     encoder_hidden_states_mask=encoder_hidden_states_mask,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    timestep_zero_index=timestep_zero_index,
                     transformer_options=transformer_options,
                 )
@@ -474,6 +507,9 @@ def block_wrap(args):
         if add is not None:
             hidden_states[:, :add.shape[1]] += add

+        if timestep_zero_index is not None:
+            temb = temb.chunk(2, dim=0)[0]
+
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
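Taken together, the qwen_image changes add an "index_timestep_zero" reference-latent mode: the timestep batch is doubled with a zeroed copy, timestep_zero_index marks where reference tokens start in the sequence, and _modulate/_apply_gate split each tensor so image tokens use the regular modulation while reference tokens use the timestep-zero modulation. A small shape sketch of that split, assuming batch 2, 100 image tokens and 60 reference tokens (all numbers arbitrary and illustrative):

import torch

batch, img_tokens, ref_tokens, dim = 2, 100, 60, 16
timestep_zero_index = img_tokens                       # reference tokens start here
x = torch.randn(batch, img_tokens + ref_tokens, dim)   # hidden states: image tokens then reference tokens
mod_params = torch.randn(2 * batch, 3 * dim)           # modulation from the doubled (regular + zeroed) timestep batch

shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
actual_batch = shift.size(0) // 2                      # first half modulates image tokens, second half reference tokens
reg = torch.addcmul(shift[:actual_batch].unsqueeze(1), x[:, :timestep_zero_index], 1 + scale[:actual_batch].unsqueeze(1))
zero = torch.addcmul(shift[actual_batch:].unsqueeze(1), x[:, timestep_zero_index:], 1 + scale[actual_batch:].unsqueeze(1))
out = torch.cat((reg, zero), dim=1)                    # same shape as x: (2, 160, 16)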