Commit f99ad20

[hotfix] fix cogvideo parallel bug (#218)
* update fix cog
* update fix
1 parent cd7ac04 commit f99ad20

File tree

1 file changed (+46 -72 lines)


videosys/models/transformers/cogvideox_transformer_3d.py

Lines changed: 46 additions & 72 deletions
@@ -22,7 +22,7 @@
 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from torch import nn
 
-from videosys.core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence
+from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
 from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
 from videosys.core.parallel_mgr import ParallelManager
 from videosys.models.modules.embeddings import apply_rotary_emb
@@ -42,6 +42,32 @@ def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
+    def _remove_extra_encoder(self, hidden_states, text_seq_length, attn):
+        # current layout is [text, 1/n seq, text, 1/n seq, ...]
+        # we want to remove all the text info -> [text, seq]
+        sp_size = attn.parallel_manager.sp_size
+        split_seq = hidden_states.split(hidden_states.size(2) // sp_size, dim=2)
+        encoder_hidden_states = split_seq[0][:, :, :text_seq_length]
+        new_seq = [encoder_hidden_states]
+        for i in range(sp_size):
+            new_seq.append(split_seq[i][:, :, text_seq_length:])
+        hidden_states = torch.cat(new_seq, dim=2)
+        return hidden_states
+
+    def _add_extra_encoder(self, hidden_states, text_seq_length, attn):
+        # current layout is [text, seq]
+        # we want to add the extra encoder info back -> [text, 1/n seq, text, 1/n seq, ...]
+        sp_size = attn.parallel_manager.sp_size
+        encoder = hidden_states[:, :text_seq_length]
+        seq = hidden_states[:, text_seq_length:]
+        seq = seq.split(seq.size(1) // sp_size, dim=1)
+        new_seq = []
+        for i in range(sp_size):
+            new_seq.append(encoder)
+            new_seq.append(seq[i])
+        hidden_states = torch.cat(new_seq, dim=1)
+        return hidden_states
+
     def __call__(
         self,
         attn: Attention,
@@ -72,7 +98,9 @@ def __call__(
             ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
             attn_heads = attn.heads // attn.parallel_manager.sp_size
             query, key, value = map(
-                lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
+                lambda x: all_to_all_with_pad(
+                    x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1, gather_pad=get_pad("pad")
+                ),
                 [query, key, value],
             )
         else:
@@ -90,6 +118,13 @@ def __call__(
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
+        if attn.parallel_manager.sp_size > 1:
+            # remove extra encoder for attention
+            query, key, value = map(
+                lambda x: self._remove_extra_encoder(x, text_seq_length, attn),
+                [query, key, value],
+            )
+
         # Apply RoPE if needed
         if image_rotary_emb is not None:
             emb_len = image_rotary_emb[0].shape[0]
@@ -108,76 +143,15 @@ def __call__(
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn_heads * head_dim)
 
         if attn.parallel_manager.sp_size > 1:
-            hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        encoder_hidden_states, hidden_states = hidden_states.split(
-            [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
-        )
-        return hidden_states, encoder_hidden_states
-
-
-class FusedCogVideoXAttnProcessor2_0:
-    r"""
-    Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
-    query and key vectors, but does not include spatial normalization.
-    """
-
-    def __init__(self):
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn: Attention,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        text_seq_length = encoder_hidden_states.size(1)
-
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        qkv = attn.to_qkv(hidden_states)
-        split_size = qkv.shape[-1] // 3
-        query, key, value = torch.split(qkv, split_size, dim=-1)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Apply RoPE if needed
-        if image_rotary_emb is not None:
-            query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
-            if not attn.is_cross_attention:
-                key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
-
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+            # add extra encoder for all_to_all
+            hidden_states = self._add_extra_encoder(hidden_states, text_seq_length, attn)
+            hidden_states = all_to_all_with_pad(
+                hidden_states,
+                attn.parallel_manager.sp_group,
+                scatter_dim=1,
+                gather_dim=2,
+                scatter_pad=get_pad("pad"),
+            )
 
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
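
Note on the layout change: the following is a standalone toy sketch (not part of the commit) that mirrors what the new _remove_extra_encoder / _add_extra_encoder helpers do, using a flat [batch, tokens] tensor instead of the real attention tensors. The sizes sp_size, text_len, and chunk are illustrative assumptions.

# Toy illustration (not from the repo): round-trip of the interleaved layout that
# the new _remove_extra_encoder / _add_extra_encoder helpers handle.
import torch

sp_size = 2    # assumed sequence-parallel group size
text_len = 3   # assumed number of text (encoder) tokens
chunk = 4      # assumed number of video tokens held per rank

# After gathering across ranks, every rank's slice still carries its own copy of
# the text tokens: layout is [text, 1/n seq, text, 1/n seq, ...].
text = torch.zeros(1, text_len)
parts = [torch.full((1, chunk), float(r + 1)) for r in range(sp_size)]
interleaved = torch.cat([torch.cat([text, p], dim=1) for p in parts], dim=1)

# "remove extra encoder": keep one text copy, concatenate all video chunks -> [text, seq].
splits = interleaved.split(interleaved.size(1) // sp_size, dim=1)
dedup = torch.cat([splits[0][:, :text_len]] + [s[:, text_len:] for s in splits], dim=1)

# "add extra encoder": put a text copy back in front of every rank's chunk before
# the all_to_all that scatters the sequence again -> [text, 1/n seq, text, 1/n seq, ...].
enc, seq = dedup[:, :text_len], dedup[:, text_len:]
seq_parts = seq.split(seq.size(1) // sp_size, dim=1)
restored = torch.cat([t for p in seq_parts for t in (enc, p)], dim=1)

assert torch.equal(restored, interleaved)       # the round trip is lossless
print(dedup.shape, restored.shape)              # torch.Size([1, 11]) torch.Size([1, 14])

The switch from all_to_all_comm to all_to_all_with_pad with get_pad("pad") as the scatter/gather pad presumably also covers sequences whose length is not evenly divisible by the sequence-parallel size; the padding bookkeeping itself lives in videosys.core.comm and is not shown here.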
