Skip to content

Commit 70541d4

Browse files
Support the new qwen edit 2511 reference method. (#11340)
index_timestep_zero can be selected in the FluxKontextMultiReferenceLatentMethod now with the display name set to the more generic "Edit Model Reference Method" node.
1 parent 77b2f7c commit 70541d4

File tree

2 files changed

+41
-9
lines changed

2 files changed

+41
-9
lines changed

comfy/ldm/qwen_image/model.py

Lines changed: 39 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -218,9 +218,24 @@ def __init__(
218218
operations=operations,
219219
)
220220

221-
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
221+
def _apply_gate(self, x, y, gate, timestep_zero_index=None):
222+
if timestep_zero_index is not None:
223+
return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
224+
else:
225+
return torch.addcmul(y, gate, x)
226+
227+
def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
222228
shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
223-
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
229+
if timestep_zero_index is not None:
230+
actual_batch = shift.size(0) // 2
231+
shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
232+
scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
233+
gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
234+
reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
235+
zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
236+
return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
237+
else:
238+
return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
224239

225240
def forward(
226241
self,
@@ -229,14 +244,19 @@ def forward(
229244
encoder_hidden_states_mask: torch.Tensor,
230245
temb: torch.Tensor,
231246
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
247+
timestep_zero_index=None,
232248
transformer_options={},
233249
) -> Tuple[torch.Tensor, torch.Tensor]:
234250
img_mod_params = self.img_mod(temb)
251+
252+
if timestep_zero_index is not None:
253+
temb = temb.chunk(2, dim=0)[0]
254+
235255
txt_mod_params = self.txt_mod(temb)
236256
img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
237257
txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
238258

239-
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
259+
img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
240260
del img_mod1
241261
txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
242262
del txt_mod1
@@ -251,15 +271,15 @@ def forward(
251271
del img_modulated
252272
del txt_modulated
253273

254-
hidden_states = hidden_states + img_gate1 * img_attn_output
274+
hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
255275
encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
256276
del img_attn_output
257277
del txt_attn_output
258278
del img_gate1
259279
del txt_gate1
260280

261-
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
262-
hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
281+
img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
282+
hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
263283

264284
txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
265285
encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
@@ -391,11 +411,14 @@ def _forward(
391411
hidden_states, img_ids, orig_shape = self.process_img(x)
392412
num_embeds = hidden_states.shape[1]
393413

414+
timestep_zero_index = None
394415
if ref_latents is not None:
395416
h = 0
396417
w = 0
397418
index = 0
398-
index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
419+
ref_method = kwargs.get("ref_latents_method", "index")
420+
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
421+
timestep_zero = ref_method == "index_timestep_zero"
399422
for ref in ref_latents:
400423
if index_ref_method:
401424
index += 1
@@ -415,6 +438,10 @@ def _forward(
415438
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
416439
hidden_states = torch.cat([hidden_states, kontext], dim=1)
417440
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
441+
if timestep_zero:
442+
if index > 0:
443+
timestep = torch.cat([timestep, timestep * 0], dim=0)
444+
timestep_zero_index = num_embeds
418445

419446
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
420447
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
@@ -446,7 +473,7 @@ def _forward(
446473
if ("double_block", i) in blocks_replace:
447474
def block_wrap(args):
448475
out = {}
449-
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
476+
out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
450477
return out
451478
out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
452479
hidden_states = out["img"]
@@ -458,6 +485,7 @@ def block_wrap(args):
458485
encoder_hidden_states_mask=encoder_hidden_states_mask,
459486
temb=temb,
460487
image_rotary_emb=image_rotary_emb,
488+
timestep_zero_index=timestep_zero_index,
461489
transformer_options=transformer_options,
462490
)
463491

@@ -474,6 +502,9 @@ def block_wrap(args):
474502
if add is not None:
475503
hidden_states[:, :add.shape[1]] += add
476504

505+
if timestep_zero_index is not None:
506+
temb = temb.chunk(2, dim=0)[0]
507+
477508
hidden_states = self.norm_out(hidden_states, temb)
478509
hidden_states = self.proj_out(hidden_states)
479510

comfy_extras/nodes_flux.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,12 +154,13 @@ class FluxKontextMultiReferenceLatentMethod(io.ComfyNode):
154154
def define_schema(cls):
155155
return io.Schema(
156156
node_id="FluxKontextMultiReferenceLatentMethod",
157+
display_name="Edit Model Reference Method",
157158
category="advanced/conditioning/flux",
158159
inputs=[
159160
io.Conditioning.Input("conditioning"),
160161
io.Combo.Input(
161162
"reference_latents_method",
162-
options=["offset", "index", "uxo/uno"],
163+
options=["offset", "index", "uxo/uno", "index_timestep_zero"],
163164
),
164165
],
165166
outputs=[

0 commit comments

Comments
 (0)