
Commit 8ccc0c9

Make omni stuff work on regular z image for easier testing. (#11985)
1 parent 4edb87a commit 8ccc0c9

File tree

1 file changed: +8 −6 lines

comfy/ldm/lumina/model.py

Lines changed: 8 additions & 6 deletions
@@ -657,7 +657,7 @@ def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False,
         device = x.device
         embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)

-        if not omni:
+        if (not omni) or self.siglip_embedder is None:
             cap_feats_len = embeds[0].shape[1] + offset
             embeds += (None,)
             freqs_cis += (None,)
@@ -675,8 +675,9 @@ def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False,
             siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple)  # TODO: double check
             siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
         else:
-            siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
-            siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
+            if self.siglip_pad_token is not None:
+                siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
+                siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)

         if siglip_feats is None:
             embeds += (None,)
@@ -724,8 +725,9 @@ def patchify_and_embed(

             out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
             for i, e in enumerate(out[0]):
-                embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
-                freqs_cis[i].append(out[1][i])
+                if e is not None:
+                    embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
+                    freqs_cis[i].append(out[1][i])
             start_t = out[2]
         leftover_cap = ref_contexts[len(ref_latents):]

@@ -759,7 +761,7 @@ def patchify_and_embed(
         feats = (cap_feats,)
         fc = (cap_freqs_cis,)

-        if omni:
+        if omni and len(embeds[1]) > 0:
             siglip_mask = None
             siglip_feats_combined = torch.cat(embeds[1], dim=1)
             siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
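
Taken together, these hunks let the omni code path run on a checkpoint that has no SigLIP components: when self.siglip_embedder / self.siglip_pad_token are None, the SigLIP slots simply stay None, nothing is appended for them, and the final concatenation is skipped. Below is a minimal sketch of that guard pattern; the names (toy_embed, gather_embeds, a bare siglip_pad_token tensor) are invented for illustration and are not the actual model code or API.

# Hypothetical sketch of the None-guarding pattern added by this commit;
# it is not comfy/ldm/lumina/model.py, and all names are made up.
import torch


def toy_embed(cap_feats, siglip_pad_token, omni):
    # Returns one entry per stream (caption, SigLIP). The SigLIP slot stays
    # None when there is no SigLIP pad token -- the "regular z image" case.
    if (not omni) or siglip_pad_token is None:
        return (cap_feats, None)
    sig = siglip_pad_token.unsqueeze(0).expand(cap_feats.shape[0], -1, -1)
    return (cap_feats, sig)


def gather_embeds(refs, siglip_pad_token, omni):
    embeds = ([], [])  # one list per stream, like embeds/freqs_cis in the model
    for ref in refs:
        out = toy_embed(ref, siglip_pad_token, omni)
        for i, e in enumerate(out):
            if e is not None:  # skip empty SigLIP slots (mirrors hunk 3)
                embeds[i].append(e)
    cap = torch.cat(embeds[0], dim=1)
    # Only concatenate the SigLIP stream if anything was collected (mirrors hunk 4).
    sig = torch.cat(embeds[1], dim=1) if omni and len(embeds[1]) > 0 else None
    return cap, sig


if __name__ == "__main__":
    refs = [torch.randn(1, 4, 8) for _ in range(2)]
    # No SigLIP pad token: the omni path still runs, the SigLIP stream is just None.
    cap, sig = gather_embeds(refs, siglip_pad_token=None, omni=True)
    print(cap.shape, sig)  # torch.Size([1, 8, 8]) None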
