Skip to content

Commit 28eaab6

Browse files
Diffusion model part of Qwen Image Layered. (#11408)
Only thing missing after this is some nodes to make using it easier.
1 parent 6a2678a commit 28eaab6

File tree

2 files changed

+42
-24
lines changed

2 files changed

+42
-24
lines changed

comfy/ldm/qwen_image/model.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def apply_rotary_emb(x, freqs_cis):
6161

6262

6363
class QwenTimestepProjEmbeddings(nn.Module):
64-
def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None, operations=None):
64+
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
6565
super().__init__()
6666
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000)
6767
self.timestep_embedder = TimestepEmbedding(
@@ -72,9 +72,19 @@ def __init__(self, embedding_dim, pooled_projection_dim, dtype=None, device=None
7272
operations=operations
7373
)
7474

75-
def forward(self, timestep, hidden_states):
75+
self.use_additional_t_cond = use_additional_t_cond
76+
if self.use_additional_t_cond:
77+
self.addition_t_embedding = operations.Embedding(2, embedding_dim, device=device, dtype=dtype)
78+
79+
def forward(self, timestep, hidden_states, addition_t_cond=None):
7680
timesteps_proj = self.time_proj(timestep)
7781
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_states.dtype))
82+
83+
if self.use_additional_t_cond:
84+
if addition_t_cond is None:
85+
addition_t_cond = torch.zeros((timesteps_emb.shape[0]), device=timesteps_emb.device, dtype=torch.long)
86+
timesteps_emb += self.addition_t_embedding(addition_t_cond, out_dtype=timesteps_emb.dtype)
87+
7888
return timesteps_emb
7989

8090

@@ -320,11 +330,11 @@ def __init__(
320330
num_attention_heads: int = 24,
321331
joint_attention_dim: int = 3584,
322332
pooled_projection_dim: int = 768,
323-
guidance_embeds: bool = False,
324333
axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
325334
default_ref_method="index",
326335
image_model=None,
327336
final_layer=True,
337+
use_additional_t_cond=False,
328338
dtype=None,
329339
device=None,
330340
operations=None,
@@ -342,6 +352,7 @@ def __init__(
342352
self.time_text_embed = QwenTimestepProjEmbeddings(
343353
embedding_dim=self.inner_dim,
344354
pooled_projection_dim=pooled_projection_dim,
355+
use_additional_t_cond=use_additional_t_cond,
345356
dtype=dtype,
346357
device=device,
347358
operations=operations
@@ -375,36 +386,42 @@ def process_img(self, x, index=0, h_offset=0, w_offset=0):
375386
patch_size = self.patch_size
376387
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (1, self.patch_size, self.patch_size))
377388
orig_shape = hidden_states.shape
378-
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
379-
hidden_states = hidden_states.permute(0, 2, 4, 1, 3, 5)
380-
hidden_states = hidden_states.reshape(orig_shape[0], (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
389+
hidden_states = hidden_states.view(orig_shape[0], orig_shape[1], orig_shape[-3], orig_shape[-2] // 2, 2, orig_shape[-1] // 2, 2)
390+
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
391+
hidden_states = hidden_states.reshape(orig_shape[0], orig_shape[-3] * (orig_shape[-2] // 2) * (orig_shape[-1] // 2), orig_shape[1] * 4)
392+
t_len = t
381393
h_len = ((h + (patch_size // 2)) // patch_size)
382394
w_len = ((w + (patch_size // 2)) // patch_size)
383395

384396
h_offset = ((h_offset + (patch_size // 2)) // patch_size)
385397
w_offset = ((w_offset + (patch_size // 2)) // patch_size)
386398

387-
img_ids = torch.zeros((h_len, w_len, 3), device=x.device)
388-
img_ids[:, :, 0] = img_ids[:, :, 1] + index
389-
img_ids[:, :, 1] = img_ids[:, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1) - (h_len // 2)
390-
img_ids[:, :, 2] = img_ids[:, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0) - (w_len // 2)
391-
return hidden_states, repeat(img_ids, "h w c -> b (h w) c", b=bs), orig_shape
399+
img_ids = torch.zeros((t_len, h_len, w_len, 3), device=x.device)
400+
401+
if t_len > 1:
402+
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(0, t_len - 1, steps=t_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(1)
403+
else:
404+
img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + index
405+
406+
img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_offset, h_len - 1 + h_offset, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1).unsqueeze(0) - (h_len // 2)
407+
img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_offset, w_len - 1 + w_offset, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0).unsqueeze(0) - (w_len // 2)
408+
return hidden_states, repeat(img_ids, "t h w c -> b (t h w) c", b=bs), orig_shape
392409

393-
def forward(self, x, timestep, context, attention_mask=None, guidance=None, ref_latents=None, transformer_options={}, **kwargs):
410+
def forward(self, x, timestep, context, attention_mask=None, ref_latents=None, additional_t_cond=None, transformer_options={}, **kwargs):
394411
return comfy.patcher_extension.WrapperExecutor.new_class_executor(
395412
self._forward,
396413
self,
397414
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, transformer_options)
398-
).execute(x, timestep, context, attention_mask, guidance, ref_latents, transformer_options, **kwargs)
415+
).execute(x, timestep, context, attention_mask, ref_latents, additional_t_cond, transformer_options, **kwargs)
399416

400417
def _forward(
401418
self,
402419
x,
403420
timesteps,
404421
context,
405422
attention_mask=None,
406-
guidance: torch.Tensor = None,
407423
ref_latents=None,
424+
additional_t_cond=None,
408425
transformer_options={},
409426
control=None,
410427
**kwargs
@@ -423,12 +440,17 @@ def _forward(
423440
index = 0
424441
ref_method = kwargs.get("ref_latents_method", self.default_ref_method)
425442
index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
443+
negative_ref_method = ref_method == "negative_index"
426444
timestep_zero = ref_method == "index_timestep_zero"
427445
for ref in ref_latents:
428446
if index_ref_method:
429447
index += 1
430448
h_offset = 0
431449
w_offset = 0
450+
elif negative_ref_method:
451+
index -= 1
452+
h_offset = 0
453+
w_offset = 0
432454
else:
433455
index = 1
434456
h_offset = 0
@@ -458,14 +480,7 @@ def _forward(
458480
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
459481
encoder_hidden_states = self.txt_in(encoder_hidden_states)
460482

461-
if guidance is not None:
462-
guidance = guidance * 1000
463-
464-
temb = (
465-
self.time_text_embed(timestep, hidden_states)
466-
if guidance is None
467-
else self.time_text_embed(timestep, guidance, hidden_states)
468-
)
483+
temb = self.time_text_embed(timestep, hidden_states, additional_t_cond)
469484

470485
patches_replace = transformer_options.get("patches_replace", {})
471486
patches = transformer_options.get("patches", {})
@@ -513,6 +528,6 @@ def block_wrap(args):
513528
hidden_states = self.norm_out(hidden_states, temb)
514529
hidden_states = self.proj_out(hidden_states)
515530

516-
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
517-
hidden_states = hidden_states.permute(0, 3, 1, 4, 2, 5)
531+
hidden_states = hidden_states[:, :num_embeds].view(orig_shape[0], orig_shape[-3], orig_shape[-2] // 2, orig_shape[-1] // 2, orig_shape[1], 2, 2)
532+
hidden_states = hidden_states.permute(0, 4, 1, 2, 5, 3, 6)
518533
return hidden_states.reshape(orig_shape)[:, :, :, :x.shape[-2], :x.shape[-1]]

comfy/model_detection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
620620
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}transformer_blocks.'.format(key_prefix) + '{}.')
621621
if "{}__index_timestep_zero__".format(key_prefix) in state_dict_keys: # 2511
622622
dit_config["default_ref_method"] = "index_timestep_zero"
623+
if "{}time_text_embed.addition_t_embedding.weight".format(key_prefix) in state_dict_keys: # Layered
624+
dit_config["use_additional_t_cond"] = True
625+
dit_config["default_ref_method"] = "negative_index"
623626
return dit_config
624627

625628
if '{}visual_transformer_blocks.0.cross_attention.key_norm.weight'.format(key_prefix) in state_dict_keys: # Kandinsky 5

0 commit comments

Comments
 (0)