11import torch
22import torch .nn as nn
3- from typing import Any , Dict , Tuple , Union , Optional
3+ from typing import Any , Dict , List , Tuple , Union , Optional
44from einops import rearrange
55
66from diffsynth_engine .models .base import StateDictConverter , PreTrainedModel
@@ -190,7 +190,8 @@ def forward(
190190 self ,
191191 image : torch .FloatTensor ,
192192 text : torch .FloatTensor ,
193- image_rotary_emb : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
193+ rotary_emb : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
194+ attn_mask : Optional [torch .Tensor ] = None ,
194195 ) -> Tuple [torch .FloatTensor , torch .FloatTensor ]:
195196 img_q , img_k , img_v = self .to_q (image ), self .to_k (image ), self .to_v (image )
196197 txt_q , txt_k , txt_v = self .add_q_proj (text ), self .add_k_proj (text ), self .add_v_proj (text )
@@ -206,8 +207,8 @@ def forward(
206207 img_q , img_k = self .norm_q (img_q ), self .norm_k (img_k )
207208 txt_q , txt_k = self .norm_added_q (txt_q ), self .norm_added_k (txt_k )
208209
209- if image_rotary_emb is not None :
210- img_freqs , txt_freqs = image_rotary_emb
210+ if rotary_emb is not None :
211+ img_freqs , txt_freqs = rotary_emb
211212 img_q = apply_rotary_emb_qwen (img_q , img_freqs )
212213 img_k = apply_rotary_emb_qwen (img_k , img_freqs )
213214 txt_q = apply_rotary_emb_qwen (txt_q , txt_freqs )
@@ -221,7 +222,7 @@ def forward(
221222 joint_k = joint_k .transpose (1 , 2 )
222223 joint_v = joint_v .transpose (1 , 2 )
223224
224- joint_attn_out = attention_ops .attention (joint_q , joint_k , joint_v , ** self .attn_kwargs )
225+ joint_attn_out = attention_ops .attention (joint_q , joint_k , joint_v , attn_mask = attn_mask , ** self .attn_kwargs )
225226
226227 joint_attn_out = rearrange (joint_attn_out , "b s h d -> b s (h d)" ).to (joint_q .dtype )
227228
@@ -285,7 +286,8 @@ def forward(
285286 image : torch .Tensor ,
286287 text : torch .Tensor ,
287288 temb : torch .Tensor ,
288- image_rotary_emb : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
289+ rotary_emb : Optional [Tuple [torch .Tensor , torch .Tensor ]] = None ,
290+ attn_mask : Optional [torch .Tensor ] = None ,
289291 ) -> Tuple [torch .Tensor , torch .Tensor ]:
290292 img_mod_attn , img_mod_mlp = self .img_mod (temb ).chunk (2 , dim = - 1 ) # [B, 3*dim] each
291293 txt_mod_attn , txt_mod_mlp = self .txt_mod (temb ).chunk (2 , dim = - 1 ) # [B, 3*dim] each
@@ -299,7 +301,8 @@ def forward(
299301 img_attn_out , txt_attn_out = self .attn (
300302 image = img_modulated ,
301303 text = txt_modulated ,
302- image_rotary_emb = image_rotary_emb ,
304+ rotary_emb = rotary_emb ,
305+ attn_mask = attn_mask ,
303306 )
304307
305308 image = image + img_gate * img_attn_out
@@ -368,13 +371,74 @@ def unpatchify(self, hidden_states, height, width):
368371 )
369372 return hidden_states
370373
def process_entity_masks(
    self,
    text: torch.Tensor,
    text_seq_lens: torch.LongTensor,
    rotary_emb: Tuple[torch.Tensor, torch.Tensor],
    video_fhw: List[Tuple[int, int, int]],
    entity_text: List[torch.Tensor],
    entity_seq_lens: List[torch.LongTensor],
    entity_masks: List[torch.Tensor],
    device: str,
    dtype: torch.dtype,
):
    """Fold per-entity prompts and spatial masks into the joint text stream.

    Builds the extended text embedding (entity prompts first, global prompt
    last), extends the text rotary frequencies to match, and constructs an
    additive attention bias so that each text segment only attends to the
    image patches covered by its spatial mask, while distinct text segments
    are isolated from one another.

    Returns:
        (text, rotary_emb, attn_mask) where attn_mask is 0 where attention
        is allowed and -inf where it is blocked, shaped for broadcasting
        over attention heads (extra dim at position 1).
    """
    # Per-segment token counts: one entry per entity prompt, global prompt last.
    seg_lens = [lens.max().item() for lens in entity_seq_lens]
    seg_lens.append(text_seq_lens.max().item())

    # Embed each (truncated) entity prompt and prepend it to the global text.
    embedded_entities = []
    for ent, length in zip(entity_text, seg_lens[:-1]):
        embedded_entities.append(self.txt_in(self.txt_norm(ent[:, :length])))
    text = torch.cat(embedded_entities + [text], dim=1)

    # Extend the text rotary frequencies so every segment has its own table.
    img_freqs, txt_freqs = rotary_emb
    per_segment_freqs = [self.pos_embed(video_fhw, length, device)[1] for length in seg_lens[:-1]]
    per_segment_freqs.append(txt_freqs)
    rotary_emb = (img_freqs, torch.cat(per_segment_freqs, dim=0))

    # Patchify the spatial masks like the latents; the global prompt's mask
    # is all-ones so it can attend to every image patch.
    full_mask = torch.ones_like(entity_masks[0], device=device, dtype=dtype)
    patched = [self.patchify(m) for m in entity_masks] + [self.patchify(full_mask)]
    batch_size, image_seq_len = patched[0].shape[:2]

    # Segment boundaries inside the joint [entity texts | global text | image] sequence.
    bounds = [0]
    for length in seg_lens:
        bounds.append(bounds[-1] + length)
    img_start = bounds[-1]
    total_len = img_start + image_seq_len

    allow = torch.ones((batch_size, total_len, total_len), device=device, dtype=torch.bool)

    # Couple each text segment with exactly the image patches its mask covers.
    for idx, patched_mask in enumerate(patched):
        covered = patched_mask.sum(dim=-1) > 0  # (B, image_seq_len): patch overlaps mask
        covered = covered.unsqueeze(1).expand(-1, seg_lens[idx], -1)
        allow[:, bounds[idx]:bounds[idx + 1], img_start:total_len] = covered
        allow[:, img_start:total_len, bounds[idx]:bounds[idx + 1]] = covered.transpose(1, 2)

    # Distinct text segments must not attend to one another.
    for a in range(len(seg_lens)):
        for b in range(len(seg_lens)):
            if a != b:
                allow[:, bounds[a]:bounds[a + 1], bounds[b]:bounds[b + 1]] = False

    # Convert the boolean map to an additive bias: 0 where allowed, -inf elsewhere,
    # with a broadcastable head dimension.
    bias = torch.zeros_like(allow, device=device, dtype=dtype)
    bias[~allow] = -torch.inf
    return text, rotary_emb, bias.unsqueeze(1)
430+
371431 def forward (
372432 self ,
373433 image : torch .Tensor ,
374434 edit : torch .Tensor = None ,
375- text : torch .Tensor = None ,
376435 timestep : torch .LongTensor = None ,
377- txt_seq_lens : torch .LongTensor = None ,
436+ text : torch .Tensor = None ,
437+ text_seq_lens : torch .LongTensor = None ,
438+ context_latents : Optional [torch .Tensor ] = None ,
439+ entity_text : Optional [List [torch .Tensor ]] = None ,
440+ entity_seq_lens : Optional [List [torch .LongTensor ]] = None ,
441+ entity_masks : Optional [List [torch .Tensor ]] = None ,
378442 ):
379443 h , w = image .shape [- 2 :]
380444 fp8_linear_enabled = getattr (self , "fp8_linear_enabled" , False )
@@ -386,35 +450,59 @@ def forward(
386450 (
387451 image ,
388452 edit ,
389- text ,
390453 timestep ,
391- txt_seq_lens ,
454+ text ,
455+ text_seq_lens ,
456+ * (entity_text if entity_text is not None else ()),
457+ * (entity_seq_lens if entity_seq_lens is not None else ()),
458+ * (entity_masks if entity_masks is not None else ()),
459+ context_latents ,
392460 ),
393461 use_cfg = use_cfg ,
394462 ),
395463 ):
396464 conditioning = self .time_text_embed (timestep , image .dtype )
397465 video_fhw = [(1 , h // 2 , w // 2 )] # frame, height, width
398- max_length = txt_seq_lens .max ().item ()
466+ text_seq_len = text_seq_lens .max ().item ()
399467 image = self .patchify (image )
400468 image_seq_len = image .shape [1 ]
469+ if context_latents is not None :
470+ context_latents = context_latents .to (dtype = image .dtype )
471+ context_latents = self .patchify (context_latents )
472+ image = torch .cat ([image , context_latents ], dim = 1 )
473+ video_fhw += [(1 , h // 2 , w // 2 )]
401474 if edit is not None :
402475 edit = edit .to (dtype = image .dtype )
403476 edit = self .patchify (edit )
404477 image = torch .cat ([image , edit ], dim = 1 )
405- video_fhw += video_fhw
478+ video_fhw += [( 1 , h // 2 , w // 2 )]
406479
407- image_rotary_emb = self .pos_embed (video_fhw , max_length , image .device )
480+ rotary_emb = self .pos_embed (video_fhw , text_seq_len , image .device )
408481
409482 image = self .img_in (image )
410- text = self .txt_in (self .txt_norm (text [:, :max_length ]))
483+ text = self .txt_in (self .txt_norm (text [:, :text_seq_len ]))
484+
485+ attn_mask = None
486+ if entity_text is not None :
487+ text , rotary_emb , attn_mask = self .process_entity_masks (
488+ text ,
489+ text_seq_lens ,
490+ rotary_emb ,
491+ video_fhw ,
492+ entity_text ,
493+ entity_seq_lens ,
494+ entity_masks ,
495+ image .device ,
496+ image .dtype ,
497+ )
411498
412499 for block in self .transformer_blocks :
413- text , image = block (image = image , text = text , temb = conditioning , image_rotary_emb = image_rotary_emb )
500+ text , image = block (
501+ image = image , text = text , temb = conditioning , rotary_emb = rotary_emb , attn_mask = attn_mask
502+ )
414503 image = self .norm_out (image , conditioning )
415504 image = self .proj_out (image )
416- if edit is not None :
417- image = image [:, :image_seq_len ]
505+ image = image [:, :image_seq_len ]
418506
419507 image = self .unpatchify (image , h , w )
420508
0 commit comments