@@ -657,7 +657,7 @@ def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False,
         device = x.device
         embeds, freqs_cis, cap_feats_len = self.embed_cap(cap_feats, offset=offset, bsz=bsz, device=device, dtype=x.dtype)
 
-        if not omni:
+        if (not omni) or self.siglip_embedder is None:
             cap_feats_len = embeds[0].shape[1] + offset
             embeds += (None,)
             freqs_cis += (None,)
@@ -675,8 +675,9 @@ def embed_all(self, x, cap_feats=None, siglip_feats=None, offset=0, omni=False,
                 siglip_feats, pad_extra = pad_zimage(siglip_feats, self.siglip_pad_token, self.pad_tokens_multiple)  # TODO: double check
                 siglip_pos_ids = torch.nn.functional.pad(siglip_pos_ids, (0, 0, 0, pad_extra))
             else:
-                siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
-                siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
+                if self.siglip_pad_token is not None:
+                    siglip_feats = self.siglip_pad_token.to(device=device, dtype=x.dtype, copy=True).unsqueeze(0).repeat(bsz, self.pad_tokens_multiple, 1)
+                    siglip_pos_ids = torch.zeros((bsz, siglip_feats.shape[1], 3), dtype=torch.float32, device=device)
 
             if siglip_feats is None:
                 embeds += (None,)
@@ -724,8 +725,9 @@ def patchify_and_embed(
 
             out = self.embed_all(ref, ref_con, sig_feat, offset=start_t, omni=omni, transformer_options=transformer_options)
             for i, e in enumerate(out[0]):
-                embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
-                freqs_cis[i].append(out[1][i])
+                if e is not None:
+                    embeds[i].append(comfy.utils.repeat_to_batch_size(e, bsz))
+                    freqs_cis[i].append(out[1][i])
             start_t = out[2]
         leftover_cap = ref_contexts[len(ref_latents):]
@@ -759,7 +761,7 @@ def patchify_and_embed(
         feats = (cap_feats,)
         fc = (cap_freqs_cis,)
 
-        if omni:
+        if omni and len(embeds[1]) > 0:
             siglip_mask = None
             siglip_feats_combined = torch.cat(embeds[1], dim=1)
             siglip_feats_freqs_cis = torch.cat(freqs_cis[1], dim=1)
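
A minimal standalone sketch of the guard pattern these hunks introduce, not the repository's code: `collect` and the simplified `repeat_to_batch_size` below are hypothetical stand-ins. The idea is that entries which come back as None (e.g. when no SigLIP embedder or pad token exists) are skipped before batch-repeating, and concatenation only happens when at least one entry survived, mirroring the `len(embeds[1]) > 0` check.

import torch

def repeat_to_batch_size(t, bsz):
    # simplified stand-in for comfy.utils.repeat_to_batch_size
    return t.repeat(bsz, *([1] * (t.ndim - 1))) if t.shape[0] == 1 else t

def collect(outputs, bsz):
    # outputs: list of (embed_or_None, freqs_or_None) pairs, e.g. one per reference latent
    embeds, freqs = [], []
    for e, f in outputs:
        if e is not None:  # skip missing streams instead of calling repeat on None
            embeds.append(repeat_to_batch_size(e, bsz))
            freqs.append(f)
    # only concatenate when something survived; otherwise signal "no stream" with None
    if len(embeds) > 0:
        return torch.cat(embeds, dim=1), torch.cat(freqs, dim=1)
    return None, None

# example: one real stream and one missing stream
combined, fc = collect([(torch.randn(1, 4, 8), torch.randn(1, 4, 3)), (None, None)], bsz=2)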