 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from ...loaders import PeftAdapterMixin
+
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models.attention import FeedForward
 from ...models.attention_processor import Attention
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous
 from ...utils import logging
-from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
+from ...loaders import PeftAdapterMixin
+from ..cache_utils import CacheMixin


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -123,10 +124,11 @@ def __call__(
         attn: Attention,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         # 1. QKV projections
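As a side note on the hunk above: the processor now unpacks both sequence lengths because text and image tokens are concatenated along the sequence dimension, with text tokens first. A minimal sketch on made-up shapes (all sizes here are illustrative, not the model's real dimensions):

```python
import torch

# Hypothetical shapes: 2 samples, 5 text tokens, 16 image tokens, hidden size 64.
encoder_hidden_states = torch.randn(2, 5, 64)   # (batch_size, text_seq_length, embed_dim)
hidden_states = torch.randn(2, 16, 64)          # (batch_size, image_seq_length, embed_dim)

# Text tokens come first in the joint sequence, which is why the mask construction
# below fills attention_mask[:, :text_seq_length] with the text mask.
joint = torch.cat([encoder_hidden_states, hidden_states], dim=1)
print(joint.shape)  # torch.Size([2, 21, 64])
```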
@@ -156,8 +158,19 @@ def __call__(
             )

         # 4. Attention
+        attention_mask_matrix = None  # no text mask provided -> fall back to full attention
+        if attention_mask is not None:
+            # construct a boolean attention mask for the concatenated (text + image) sequence
+            text_attention_mask = attention_mask.float().to(query.device)
+            attention_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+            attention_mask[:, :text_seq_length] = text_attention_mask
+            attention_mask = attention_mask.unsqueeze(2)
+            attention_mask_matrix = attention_mask @ attention_mask.mT
+            attention_mask_matrix = attention_mask_matrix == 1
+            attention_mask_matrix = attention_mask_matrix.unsqueeze(1)
+
         hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+            query, key, value, attn_mask=attention_mask_matrix, dropout_p=0.0, is_causal=False
         )
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
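To make the new masking path easier to follow, here is a self-contained sketch of the same construction on toy tensors. It assumes a hypothetical padded 2D text mask (1 = real token, 0 = padding); the shapes and variable names are illustrative only:

```python
import torch
import torch.nn.functional as F

batch_size, num_heads, head_dim = 2, 4, 8
text_seq_length, image_seq_length = 5, 16
total_length = text_seq_length + image_seq_length

query = torch.randn(batch_size, num_heads, total_length, head_dim)
key = torch.randn(batch_size, num_heads, total_length, head_dim)
value = torch.randn(batch_size, num_heads, total_length, head_dim)

# Hypothetical padded text mask: the second sample has 2 padding tokens.
text_attention_mask = torch.tensor([[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]])

# Extend the text mask over the concatenated (text + image) sequence; image tokens
# are never padded, so their positions stay 1.
mask = torch.ones((batch_size, total_length))
mask[:, :text_seq_length] = text_attention_mask.float()

# The outer product turns the 1D keep-mask into a 2D matrix: entry (i, j) is True
# only if both token i and token j are kept; unsqueeze(1) broadcasts it over heads.
mask = mask.unsqueeze(2)             # (batch, seq, 1)
mask_matrix = (mask @ mask.mT) == 1  # (batch, seq, seq), boolean
mask_matrix = mask_matrix.unsqueeze(1)  # (batch, 1, seq, seq)

out = F.scaled_dot_product_attention(query, key, value, attn_mask=mask_matrix)
print(out.shape)  # torch.Size([2, 4, 21, 8])

# Note: rows for fully padded query positions are entirely masked, so their outputs
# come back as NaN and need to be ignored downstream.
```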
@@ -203,6 +215,8 @@ def forward(
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> torch.Tensor:
         # 1. Timestep conditioning
         (
@@ -223,6 +237,8 @@ def forward(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            attention_mask=attention_mask,
+            **kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -233,8 +249,8 @@ def forward(
             1 + c_scale_mlp.unsqueeze(1)
         ) + c_shift_mlp.unsqueeze(1)

-        ff_output = self.ff(norm_hidden_states)
-        ff_output_context = self.ff(norm_encoder_hidden_states)
+        ff_output = self.ff(norm_hidden_states, **kwargs)
+        ff_output_context = self.ff(norm_encoder_hidden_states, **kwargs)
         hidden_states = hidden_states + ff_output * gate_mlp.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + ff_output_context * c_gate_mlp.unsqueeze(1)

@@ -381,6 +397,8 @@ def forward(
         target_size: torch.Tensor,
         crop_coords: torch.Tensor,
         return_dict: bool = True,
+        attention_mask: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         batch_size, num_channels, height, width = hidden_states.shape

@@ -391,6 +409,7 @@ def forward(
         p = self.config.patch_size
         post_patch_height = height // p
         post_patch_width = width // p
+
         hidden_states, encoder_hidden_states = self.patch_embed(hidden_states, encoder_hidden_states)
         temb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
@@ -400,11 +419,11 @@ def forward(
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                    block, hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, temb, image_rotary_emb
+                    hidden_states, encoder_hidden_states, temb, image_rotary_emb, attention_mask, **kwargs
                 )

         # 4. Output norm & projection
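One way to sanity-check the change end to end, as a sketch under the assumption that callers pass the 2D text mask the processor above expects: when the text mask contains no padding, the constructed boolean matrix is all True, so scaled_dot_product_attention should reproduce the previous unmasked behavior.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
batch_size, num_heads, head_dim = 2, 4, 8
text_seq_length, image_seq_length = 5, 16
total_length = text_seq_length + image_seq_length

query, key, value = (torch.randn(batch_size, num_heads, total_length, head_dim) for _ in range(3))

# All-ones text mask (no padding) -> every entry of the outer-product matrix is True.
mask = torch.ones((batch_size, total_length)).unsqueeze(2)
mask_matrix = ((mask @ mask.mT) == 1).unsqueeze(1)

with_mask = F.scaled_dot_product_attention(query, key, value, attn_mask=mask_matrix)
without_mask = F.scaled_dot_product_attention(query, key, value, attn_mask=None)
print(torch.allclose(with_mask, without_mask, atol=1e-5))  # expected: True
```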