
Commit ff918ec

gttiankaioahzxl and Xuanlei Zhao authored

[hotfix] fix CogVideoX parallel bug with 4 gpus (#221)

* [UPD] 1. fix 4 process bug
* fix pad problem

Co-authored-by: Xuanlei Zhao <[email protected]>

1 parent e48a642 · commit ff918ec

File tree

1 file changed, +22 -11 lines changed


videosys/models/transformers/cogvideox_transformer_3d.py

Lines changed: 22 additions & 11 deletions
@@ -22,7 +22,7 @@
 from diffusers.utils.torch_utils import maybe_allow_in_graph
 from torch import nn

-from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence
+from videosys.core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence
 from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial
 from videosys.core.parallel_mgr import ParallelManager
 from videosys.models.modules.embeddings import apply_rotary_emb
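
The only import change swaps the padded all-to-all helper for the plain all_to_all_comm; padding is instead handled explicitly around the encoder add/remove steps in the later hunks. For context, below is a rough single-function sketch of what an all-to-all with a scatter_dim/gather_dim signature typically does, built on torch.distributed.all_to_all. The function name and body are illustrative assumptions, not the actual videosys.core.comm implementation.

import torch
import torch.distributed as dist

def all_to_all_sketch(x: torch.Tensor, group, scatter_dim: int, gather_dim: int) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    # split the scatter dim into world_size equal chunks; this only works when the
    # dim divides evenly, which is exactly why the commit pads the sequence first
    inputs = [t.contiguous() for t in torch.chunk(x, world_size, dim=scatter_dim)]
    outputs = [torch.empty_like(inputs[0]) for _ in range(world_size)]
    dist.all_to_all(outputs, inputs, group=group)
    # every rank now holds one chunk from each peer; stitch them along the gather dim
    return torch.cat(outputs, dim=gather_dim)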
@@ -52,9 +52,26 @@ def _remove_extra_encoder(self, hidden_states, text_seq_length, attn):
         for i in range(sp_size):
             new_seq.append(split_seq[i][:, :, text_seq_length:])
         hidden_states = torch.cat(new_seq, dim=2)
+
+        # remove padding added when all2all
+        # if pad is removed earlier than this
+        # the split size will be wrong
+        pad = get_pad("pad")
+        if pad > 0:
+            hidden_states = hidden_states.narrow(2, 0, hidden_states.size(2) - pad)
         return hidden_states

     def _add_extra_encoder(self, hidden_states, text_seq_length, attn):
+        # add padding for split and later all2all
+        # if pad is removed later than this
+        # the split size will be wrong
+        pad = get_pad("pad")
+        if pad > 0:
+            pad_shape = list(hidden_states.shape)
+            pad_shape[1] = pad
+            pad_tensor = torch.zeros(pad_shape, device=hidden_states.device, dtype=hidden_states.dtype)
+            hidden_states = torch.cat([hidden_states, pad_tensor], dim=1)
+
         # current layout is [text, seq]
         # we want to add the extra encoder info [text, 1/n seq, text, 1/n seq, ...]
         sp_size = attn.parallel_manager.sp_size
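
The new comments stress the ordering: padding has to be added before the per-rank split and stripped only after the chunks are concatenated back, otherwise the chunk sizes diverge across ranks. A toy, single-process illustration of why an uneven split would break the subsequent all-to-all (the sizes below are made up for illustration, not CogVideoX's real sequence lengths):

import torch

sp_size, seq_len = 4, 17
pad = (sp_size - seq_len % sp_size) % sp_size            # 3 padding tokens needed

# with the pad in place, every rank gets an equal 5-token chunk
padded = torch.arange(seq_len + pad)
assert [c.numel() for c in torch.chunk(padded, sp_size)] == [5, 5, 5, 5]

# if the pad were stripped before splitting, the chunks would be uneven,
# and the all-to-all would then exchange mismatched shapes across ranks
unpadded = padded[:seq_len]
assert [c.numel() for c in torch.chunk(unpadded, sp_size)] == [5, 5, 5, 2]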
@@ -97,10 +114,10 @@ def __call__(
                 attn.heads % attn.parallel_manager.sp_size == 0
             ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}"
             attn_heads = attn.heads // attn.parallel_manager.sp_size
+            # normally we operate pad for every all2all. but for more convient implementation
+            # we move pad operation to encoder add and remove in cogvideo
             query, key, value = map(
-                lambda x: all_to_all_with_pad(
-                    x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1, gather_pad=get_pad("pad")
-                ),
+                lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1),
                 [query, key, value],
             )
         else:
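
The replacement keeps the same scatter_dim=2 / gather_dim=1 layout change for query, key and value: each rank trades its local slice of the (already padded) sequence for a slice of the heads. A single-process sketch of the resulting shapes, assuming a [batch, sequence, heads, head_dim] layout (inferred from the dims in the call, not stated in the diff):

import torch

B, S_padded, H, D, sp_size = 1, 8, 4, 16, 4
full = torch.randn(B, S_padded, H, D)

# before the all-to-all: each rank holds 1/sp_size of the padded sequence, all heads
local_in = full[:, : S_padded // sp_size]        # (1, 2, 4, 16)
# after the all-to-all: each rank holds the full sequence, 1/sp_size of the heads,
# so attention on a rank sees every token but only its subset of heads
local_out = full[:, :, : H // sp_size]           # (1, 8, 1, 16)

assert local_in.shape == (B, S_padded // sp_size, H, D)
assert local_out.shape == (B, S_padded, H // sp_size, D)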
@@ -145,13 +162,7 @@ def __call__(
         if attn.parallel_manager.sp_size > 1:
             # add extra encoder for all_to_all
             hidden_states = self._add_extra_encoder(hidden_states, text_seq_length, attn)
-            hidden_states = all_to_all_with_pad(
-                hidden_states,
-                attn.parallel_manager.sp_group,
-                scatter_dim=1,
-                gather_dim=2,
-                scatter_pad=get_pad("pad"),
-            )
+            hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2)

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
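
After attention, the inverse all_to_all_comm (scatter_dim=1, gather_dim=2) restores the per-rank sequence split, and _remove_extra_encoder later strips the pad. A minimal single-process check that the pad/strip pair introduced by this commit is lossless; sizes and dims are simplified (the real code pads along dim 1 before the all-to-all and narrows along dim 2 after the layout has changed), and the pad computation here is only an assumption about what get_pad("pad") returns for sp_size=4:

import torch

x = torch.randn(2, 17, 64)                   # (batch, seq, hidden), seq not divisible by 4
pad = (4 - x.size(1) % 4) % 4                # assumed equivalent of get_pad("pad")

pad_tensor = torch.zeros(2, pad, 64, dtype=x.dtype)
padded = torch.cat([x, pad_tensor], dim=1)   # mirrors the torch.cat in _add_extra_encoder
assert padded.size(1) % 4 == 0               # now safe to split / all-to-all across 4 ranks

restored = padded.narrow(1, 0, padded.size(1) - pad)   # mirrors _remove_extra_encoder
assert torch.equal(restored, x)              # stripping the pad recovers the original tokens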
