import numpy as np
import torch
import torch.nn as nn
- import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin
from ...utils import is_torchvision_available
from ..attention import FeedForward
+ from ..attention_dispatch import dispatch_attention_fn
from ..attention_processor import Attention
from ..embeddings import Timesteps
from ..modeling_outputs import Transformer2DModelOutput
@@ -152,10 +152,10 @@ def forward(

class CosmosAttnProcessor2_0:
    def __init__(self):
-         if not hasattr(F, "scaled_dot_product_attention"):
+         if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise ImportError("CosmosAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

-     def compute_attn(
+     def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
@@ -199,70 +199,26 @@ def compute_attn(
        value = value.repeat_interleave(query_idx // value_idx, dim=3)

        # 5. Attention
-         hidden_states = F.scaled_dot_product_attention(
-             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-         )
-         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3).type_as(query)
-         return hidden_states
-
-     def __call__(
-         self,
-         attn: Attention,
-         hidden_states: torch.Tensor,
-         encoder_hidden_states: Optional[torch.Tensor] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-         image_rotary_emb: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         hidden_states = self.compute_attn(
-             attn=attn,
-             hidden_states=hidden_states,
-             encoder_hidden_states=encoder_hidden_states,
-             attention_mask=attention_mask,
-             image_rotary_emb=image_rotary_emb,
+         hidden_states = dispatch_attention_fn(
+             query.transpose(1, 2),
+             key.transpose(1, 2),
+             value.transpose(1, 2),
+             attn_mask=attention_mask,
+             dropout_p=0.0,
+             is_causal=False,
        )
+         hidden_states = hidden_states.flatten(2, 3).type_as(query)
        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states


- class CosmosAttnProcessor2_5(CosmosAttnProcessor2_0):
+ class CosmosAttnProcessor2_5:
    def __init__(self):
        if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            raise ImportError("CosmosAttnProcessor2_5 requires PyTorch 2.0. Please upgrade PyTorch to 2.0 or newer.")

-     def compute_attn_i2v(
-         self,
-         attn: Attention,
-         hidden_states: torch.Tensor,
-         img_context=None,
-         attention_mask=None,
-     ):
-         q_img = attn.q_img(hidden_states)
-         k_img = attn.k_img(img_context)
-         v_img = attn.v_img(img_context)
-
-         batch_size = hidden_states.shape[0]
-
-         dim_head = attn.out_dim // attn.heads
-         q_img = q_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
-         k_img = k_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
-         v_img = v_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
-
-         q_img = attn.q_img_norm(q_img)
-         k_img = attn.k_img_norm(k_img)
-
-         q_img_idx = q_img.size(3)
-         k_img_idx = k_img.size(3)
-         v_img_idx = v_img.size(3)
-         k_img = k_img.repeat_interleave(q_img_idx // k_img_idx, dim=3)
-         v_img = v_img.repeat_interleave(q_img_idx // v_img_idx, dim=3)
-         img_out = torch.nn.functional.scaled_dot_product_attention(
-             q_img, k_img, v_img, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-         )
-         img_out = img_out.transpose(1, 2).flatten(2, 3).type_as(q_img)
-         return img_out
-
    def __call__(
        self,
        attn: Attention,
@@ -277,21 +233,77 @@ def __call__(
        text_context, img_context = encoder_hidden_states if encoder_hidden_states else (None, None)
        text_mask, img_mask = attention_mask if attention_mask else (None, None)

-         attn_out = self.compute_attn(
-             attn=attn,
-             hidden_states=hidden_states,
-             encoder_hidden_states=text_context,
-             attention_mask=text_mask,
-             image_rotary_emb=image_rotary_emb,
+         if text_context is None:
+             text_context = hidden_states
+
+         query = attn.to_q(hidden_states)
+         key = attn.to_k(text_context)
+         value = attn.to_v(text_context)
+
+         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+         query = attn.norm_q(query)
+         key = attn.norm_k(key)
+
+         if image_rotary_emb is not None:
+             from ..embeddings import apply_rotary_emb
+
+             query = apply_rotary_emb(query, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
+             key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
+
+         if torch.onnx.is_in_onnx_export():
+             query_idx = torch.tensor(query.size(3), device=query.device)
+             key_idx = torch.tensor(key.size(3), device=key.device)
+             value_idx = torch.tensor(value.size(3), device=value.device)
+         else:
+             query_idx = query.size(3)
+             key_idx = key.size(3)
+             value_idx = value.size(3)
+         key = key.repeat_interleave(query_idx // key_idx, dim=3)
+         value = value.repeat_interleave(query_idx // value_idx, dim=3)
+
+         attn_out = dispatch_attention_fn(
+             query.transpose(1, 2),
+             key.transpose(1, 2),
+             value.transpose(1, 2),
+             attn_mask=text_mask,
+             dropout_p=0.0,
+             is_causal=False,
        )
+         attn_out = attn_out.flatten(2, 3).type_as(query)

        if img_context is not None:
-             img_out = self.compute_attn_i2v(
-                 attn=attn,
-                 hidden_states=hidden_states,
-                 img_context=img_context,
-                 attention_mask=img_mask,
+             q_img = attn.q_img(hidden_states)
+             k_img = attn.k_img(img_context)
+             v_img = attn.v_img(img_context)
+
+             batch_size = hidden_states.shape[0]
+             dim_head = attn.out_dim // attn.heads
+
+             q_img = q_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
+             k_img = k_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
+             v_img = v_img.view(batch_size, -1, attn.heads, dim_head).transpose(1, 2)
+
+             q_img = attn.q_img_norm(q_img)
+             k_img = attn.k_img_norm(k_img)
+
+             q_img_idx = q_img.size(3)
+             k_img_idx = k_img.size(3)
+             v_img_idx = v_img.size(3)
+             k_img = k_img.repeat_interleave(q_img_idx // k_img_idx, dim=3)
+             v_img = v_img.repeat_interleave(q_img_idx // v_img_idx, dim=3)
+
+             img_out = dispatch_attention_fn(
+                 q_img.transpose(1, 2),
+                 k_img.transpose(1, 2),
+                 v_img.transpose(1, 2),
+                 attn_mask=img_mask,
+                 dropout_p=0.0,
+                 is_causal=False,
            )
+             img_out = img_out.flatten(2, 3).type_as(q_img)
            hidden_states = attn_out + img_out
        else:
            hidden_states = attn_out
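
Note on the layout change (not part of the diff): `F.scaled_dot_product_attention` takes and returns tensors shaped (batch, heads, seq, head_dim), which is why the removed code transposed before flattening the heads. As used above, `dispatch_attention_fn` takes (batch, seq, heads, head_dim) inputs and returns that same layout, so the processors transpose before the call and flatten directly afterwards. A minimal sketch of the equivalence, assuming `dispatch_attention_fn` is importable from `diffusers.models.attention_dispatch` and accepts the keyword arguments shown in the diff:

import torch
import torch.nn.functional as F

from diffusers.models.attention_dispatch import dispatch_attention_fn  # import path assumed from the relative import above

batch, heads, seq, head_dim = 2, 8, 16, 64
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, heads, seq, head_dim)
value = torch.randn(batch, heads, seq, head_dim)

# Old path: (B, H, S, D) in and out, so transpose before flattening heads into the channel dim.
out_old = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
out_old = out_old.transpose(1, 2).flatten(2, 3)

# New path: (B, S, H, D) in and out, so flatten directly after the call.
out_new = dispatch_attention_fn(
    query.transpose(1, 2),
    key.transpose(1, 2),
    value.transpose(1, 2),
    attn_mask=None,
    dropout_p=0.0,
    is_causal=False,
)
out_new = out_new.flatten(2, 3)

print(out_old.shape, out_new.shape)  # both torch.Size([2, 16, 512])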