@@ -250,15 +250,21 @@ def forward(
         hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
         joint_attention_kwargs = joint_attention_kwargs or {}
+
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )

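A minimal sketch (not from this diff) of what the broadcast above does: a per-token keep mask of shape (batch, seq_len), with 1 for real tokens and 0 for padding, is expanded into a pairwise mask of shape (batch, 1, seq_len, seq_len) whose entry (i, j) is 1 only when both token i and token j are kept, so padded positions are excluded both as queries and as keys.

import torch

# Illustrative values only; mirrors the attention_mask expansion added above.
keep = torch.tensor([[1, 1, 1, 0]])                          # (batch=1, seq_len=4), 0 marks padding
pairwise = keep[:, None, None, :] * keep[:, None, :, None]   # broadcasts to (1, 1, 4, 4)
print(pairwise.shape)    # torch.Size([1, 1, 4, 4])
print(pairwise[0, 0])    # row 3 and column 3 are all zeros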
@@ -312,6 +318,7 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         temb_img, temb_txt = temb[:, :6], temb[:, 6:]
@@ -321,11 +328,15 @@ def forward(
             encoder_hidden_states, emb=temb_txt
         )
         joint_attention_kwargs = joint_attention_kwargs or {}
+        if attention_mask is not None:
+            attention_mask = attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+
         # Attention.
         attention_outputs = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
             **joint_attention_kwargs,
         )

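This diff does not show how the attention processor consumes the resulting (batch, 1, seq_len, seq_len) multiplicative mask. A hedged sketch of one common approach, assuming the processor hands it to torch.nn.functional.scaled_dot_product_attention as a boolean mask (True = may attend):

import torch
import torch.nn.functional as F

# Hedged sketch, not the processor from this PR: converts the 0/1 mask built
# above into a boolean attn_mask. Fully masked query rows (e.g. padded queries)
# would need extra handling in practice, since SDPA returns NaN for them.
def masked_sdpa(query, key, value, attention_mask=None):
    attn_mask = attention_mask.to(torch.bool) if attention_mask is not None else None
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)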
@@ -570,6 +581,7 @@ def forward(
         timestep: torch.LongTensor = None,
         img_ids: torch.Tensor = None,
         txt_ids: torch.Tensor = None,
+        attention_mask: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_block_samples=None,
         controlnet_single_block_samples=None,
@@ -659,11 +671,7 @@ def forward(
             )
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
-                    block,
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask
                 )

             else:
@@ -672,6 +680,7 @@ def forward(
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )

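In the checkpointed branch above the mask is threaded positionally, since checkpoint wrappers replay the wrapped callable with the saved positional arguments during the backward pass. A hedged sketch with torch.utils.checkpoint (the actual _gradient_checkpointing_func wrapper may differ), using the same argument order as the block's forward signature:

import torch.utils.checkpoint

# Hedged sketch: attention_mask must be in the argument list handed to the
# checkpoint wrapper so the block receives it again when re-run for backward.
encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
    block,
    hidden_states,
    encoder_hidden_states,
    temb,
    image_rotary_emb,
    attention_mask,
    use_reentrant=False,
)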
@@ -704,6 +713,7 @@ def forward(
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    attention_mask=attention_mask,
                     joint_attention_kwargs=joint_attention_kwargs,
                 )

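A hedged caller-side sketch of what might be passed as attention_mask to the transformer: a 1-D keep mask over the full joint sequence, which each block then expands to the pairwise form shown earlier. The variable names, sequence lengths, and the text-before-image ordering below are illustrative assumptions, not taken from this diff.

import torch

# Hypothetical construction: text padding mask from the tokenizer, plus an
# all-ones mask for image tokens, which are never padded.
text_mask = torch.tensor([[1, 1, 1, 0, 0]])                              # (batch, txt_len)
image_mask = torch.ones(text_mask.shape[0], 16, dtype=text_mask.dtype)   # (batch, img_len); img_len=16 here
attention_mask = torch.cat([text_mask, image_mask], dim=1)               # (batch, txt_len + img_len)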