 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from torch import nn

-from videosys.core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
 from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
 from videosys.core.parallel_mgr import ParallelManager
 from videosys.models.modules.embeddings import apply_rotary_emb
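Note: `all_to_all_comm` is swapped for `all_to_all_with_pad`, which also accounts for the padding that was added when the sequence was split across the sequence-parallel group. The real helper lives in `videosys/core/comm.py` and is not part of this diff; the snippet below is only a rough sketch of what a padded all-to-all can look like, inferred from the call sites further down. Only the `scatter_pad`/`gather_pad` keyword names come from this file; the function name and internals here are assumptions.

import torch
import torch.distributed as dist

def all_to_all_with_pad_sketch(x, group, scatter_dim, gather_dim, scatter_pad=0, gather_pad=0):
    # Sketch only: pad the scatter dim so it splits evenly across ranks, run a
    # regular all-to-all, then drop the pad that ends up at the tail of the
    # gather dim (assuming the pad sits on the last rank's chunk).
    world_size = dist.get_world_size(group)
    if scatter_pad > 0:
        pad_shape = list(x.shape)
        pad_shape[scatter_dim] = scatter_pad
        x = torch.cat([x, x.new_zeros(pad_shape)], dim=scatter_dim)
    inputs = [t.contiguous() for t in x.chunk(world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(t) for t in inputs]
    dist.all_to_all(outputs, inputs, group=group)
    out = torch.cat(outputs, dim=gather_dim)
    if gather_pad > 0:
        out = out.narrow(gather_dim, 0, out.size(gather_dim) - gather_pad)
    return out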
@@ -42,6 +42,32 @@ def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

+    def _remove_extra_encoder(self, hidden_states, text_seq_length, attn):
+        # current layout is [text, 1/n seq, text, 1/n seq, ...]
+        # we want to remove the extra text copies: [text, seq]
+        sp_size = attn.parallel_manager.sp_size
+        split_seq = hidden_states.split(hidden_states.size(2) // sp_size, dim=2)
+        encoder_hidden_states = split_seq[0][:, :, :text_seq_length]
+        new_seq = [encoder_hidden_states]
+        for i in range(sp_size):
+            new_seq.append(split_seq[i][:, :, text_seq_length:])
+        hidden_states = torch.cat(new_seq, dim=2)
+        return hidden_states
+
+    def _add_extra_encoder(self, hidden_states, text_seq_length, attn):
+        # current layout is [text, seq]
+        # we want to add the extra encoder copies back: [text, 1/n seq, text, 1/n seq, ...]
+        sp_size = attn.parallel_manager.sp_size
+        encoder = hidden_states[:, :text_seq_length]
+        seq = hidden_states[:, text_seq_length:]
+        seq = seq.split(seq.size(1) // sp_size, dim=1)
+        new_seq = []
+        for i in range(sp_size):
+            new_seq.append(encoder)
+            new_seq.append(seq[i])
+        hidden_states = torch.cat(new_seq, dim=1)
+        return hidden_states
+
     def __call__(
         self,
         attn: Attention,
@@ -72,7 +98,9 @@ def __call__(
             ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
             attn_heads = attn.heads // attn.parallel_manager.sp_size
             query, key, value = map(
-                lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
+                lambda x: all_to_all_with_pad(
+                    x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1, gather_pad=get_pad("pad")
+                ),
                 [query, key, value],
             )
         else:
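At this point each rank holds `query`/`key`/`value` for its local sequence slice, shaped `[batch, seq_len / sp_size, heads * head_dim]`. The all-to-all above swaps the partitioned axes: afterwards every rank sees the full (padded) sequence but only `heads / sp_size` of the heads, which is why `attn_heads` is divided by the sequence-parallel size. A single-process toy that mimics only the resulting shapes, not the real cross-rank data movement (all sizes here are made up):

import torch

sp_size, batch, local_seq, heads, head_dim = 2, 1, 8, 4, 16

x = torch.randn(batch, local_seq, heads * head_dim)  # per-rank input: [B, S/n, H*D]

# mimic only the shape change of the all-to-all: split the channel (head) dim,
# concatenate along the sequence dim
y = torch.cat(x.chunk(sp_size, dim=2), dim=1)

print(x.shape)  # torch.Size([1, 8, 64])
print(y.shape)  # torch.Size([1, 16, 32])  -> full sequence, heads / sp_size channels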
@@ -90,6 +118,13 @@ def __call__(
         if attn.norm_k is not None:
             key = attn.norm_k(key)

+        if attn.parallel_manager.sp_size > 1:
+            # remove extra encoder for attention
+            query, key, value = map(
+                lambda x: self._remove_extra_encoder(x, text_seq_length, attn),
+                [query, key, value],
+            )
+
         # Apply RoPE if needed
         if image_rotary_emb is not None:
             emb_len = image_rotary_emb[0].shape[0]
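Because every rank's local slice was laid out as [text, video_chunk], the gathered sequence now repeats the text tokens sp_size times: [text, v_0, text, v_1, ...]. `_remove_extra_encoder` (added above) keeps one text block plus all video chunks so attention runs over the deduplicated sequence. A 1-D toy of that reshuffle with hypothetical sizes (the real method works on dim 2 of `[batch, heads, seq, head_dim]` tensors):

import torch

sp_size, text_len, chunk_len = 2, 3, 4

text = torch.full((text_len,), -1)                      # text tokens marked as -1
v0, v1 = torch.arange(chunk_len), torch.arange(chunk_len, 2 * chunk_len)
gathered = torch.cat([text, v0, text, v1])              # layout after the all-to-all

# same logic as _remove_extra_encoder, on a 1-D toy
chunks = gathered.split(gathered.size(0) // sp_size, dim=0)
dedup = [chunks[0][:text_len]] + [c[text_len:] for c in chunks]
print(torch.cat(dedup).tolist())
# [-1, -1, -1, 0, 1, 2, 3, 4, 5, 6, 7]  -> one text block, then all video chunks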
@@ -108,76 +143,15 @@ def __call__(
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)

         if attn.parallel_manager.sp_size > 1:
-            hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        encoder_hidden_states, hidden_states = hidden_states.split(
-            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-        )
-        return hidden_states, encoder_hidden_states
-
-
-class FusedCogVideoXAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
-    query and key vectors, but does not include spatial normalization.
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
-
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Apply RoPE if needed
-        if image_rotary_emb is not None:
-            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
-            if not attn.is_cross_attention:
-                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            # add extra encoder for all_to_all
+            hidden_states = self._add_extra_encoder(hidden_states, text_seq_length, attn)
+            hidden_states = all_to_all_with_pad(
+                hidden_states,
+                attn.parallel_manager.sp_group,
+                scatter_dim=1,
+                gather_dim=2,
+                scatter_pad=get_pad("pad"),
+            )

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
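After attention, `_add_extra_encoder` re-inserts a text copy in front of each video chunk so that the reverse all-to-all (scatter along the sequence dim, gather along the head dim) hands each rank a [text, video_chunk] slice with all heads again; the `FusedCogVideoXAttnProcessor2_0` class removed by this hunk had no sequence-parallel path. Continuing the 1-D toy above with the same hypothetical sizes, the add step is the inverse of the remove step:

import torch

sp_size, text_len, chunk_len = 2, 3, 4

text = torch.full((text_len,), -1)
flat = torch.cat([text, torch.arange(2 * chunk_len)])   # [text, v0, v1] after attention

# same splitting logic as _add_extra_encoder, on a 1-D toy
video = flat[text_len:].split((flat.size(0) - text_len) // sp_size, dim=0)
interleaved = torch.cat([torch.cat([text, v]) for v in video])

# what each rank gets back once the sequence dim is scattered again
per_rank = interleaved.split(interleaved.size(0) // sp_size, dim=0)
print([r.tolist() for r in per_rank])
# [[-1, -1, -1, 0, 1, 2, 3], [-1, -1, -1, 4, 5, 6, 7]]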