|
22 | 22 | from diffusers.utils.torch_utils import maybe_allow_in_graph |
23 | 23 | from torch import nn |
24 | 24 |
|
25 | | -from videosys.core.comm import all_to_all_with_pad, gather_sequence, get_pad, set_pad, split_sequence |
| 25 | +from videosys.core.comm import all_to_all_comm, gather_sequence, get_pad, set_pad, split_sequence |
26 | 26 | from videosys.core.pab_mgr import enable_pab, if_broadcast_spatial |
27 | 27 | from videosys.core.parallel_mgr import ParallelManager |
28 | 28 | from videosys.models.modules.embeddings import apply_rotary_emb |
@@ -52,9 +52,26 @@ def _remove_extra_encoder(self, hidden_states, text_seq_length, attn): |
52 | 52 | for i in range(sp_size): |
53 | 53 | new_seq.append(split_seq[i][:, :, text_seq_length:]) |
54 | 54 | hidden_states = torch.cat(new_seq, dim=2) |
| 55 | + |
| 56 | + # remove the padding that was added for all_to_all
| 57 | + # if the pad were removed earlier than this,
| 58 | + # the split size would be wrong
| 59 | + pad = get_pad("pad") |
| 60 | + if pad > 0: |
| 61 | + hidden_states = hidden_states.narrow(2, 0, hidden_states.size(2) - pad) |
55 | 62 | return hidden_states |
56 | 63 |
|
57 | 64 | def _add_extra_encoder(self, hidden_states, text_seq_length, attn): |
| 65 | + # add padding before the split and the later all_to_all
| 66 | + # if the pad were added later than this,
| 67 | + # the split size would be wrong
| 68 | + pad = get_pad("pad") |
| 69 | + if pad > 0: |
| 70 | + pad_shape = list(hidden_states.shape) |
| 71 | + pad_shape[1] = pad |
| 72 | + pad_tensor = torch.zeros(pad_shape, device=hidden_states.device, dtype=hidden_states.dtype) |
| 73 | + hidden_states = torch.cat([hidden_states, pad_tensor], dim=1) |
| 74 | + |
58 | 75 | # current layout is [text, seq] |
59 | 76 | # we want to add the extra encoder info [text, 1/n seq, text, 1/n seq, ...] |
60 | 77 | sp_size = attn.parallel_manager.sp_size |
@@ -97,10 +114,10 @@ def __call__( |
97 | 114 | attn.heads % attn.parallel_manager.sp_size == 0 |
98 | 115 | ), f"Number of heads {attn.heads} must be divisible by sequence parallel size {attn.parallel_manager.sp_size}" |
99 | 116 | attn_heads = attn.heads // attn.parallel_manager.sp_size |
| 117 | + # normally we would pad/unpad around every all_to_all, but for a more convenient implementation
| 118 | + # we move the pad handling to the encoder add/remove steps in cogvideo
100 | 119 | query, key, value = map( |
101 | | - lambda x: all_to_all_with_pad( |
102 | | - x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1, gather_pad=get_pad("pad") |
103 | | - ), |
| 120 | + lambda x: all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1), |
104 | 121 | [query, key, value], |
105 | 122 | ) |
106 | 123 | else: |
@@ -145,13 +162,7 @@ def __call__( |
145 | 162 | if attn.parallel_manager.sp_size > 1: |
146 | 163 | # add extra encoder for all_to_all |
147 | 164 | hidden_states = self._add_extra_encoder(hidden_states, text_seq_length, attn) |
148 | | - hidden_states = all_to_all_with_pad( |
149 | | - hidden_states, |
150 | | - attn.parallel_manager.sp_group, |
151 | | - scatter_dim=1, |
152 | | - gather_dim=2, |
153 | | - scatter_pad=get_pad("pad"), |
154 | | - ) |
| 165 | + hidden_states = all_to_all_comm(hidden_states, attn.parallel_manager.sp_group, scatter_dim=1, gather_dim=2) |
155 | 166 |
|
156 | 167 | # linear proj |
157 | 168 | hidden_states = attn.to_out[0](hidden_states) |
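
The comments added in `_add_extra_encoder` / `_remove_extra_encoder` encode one invariant: the sequence has to be padded to a multiple of `sp_size` before it is split for the all_to_all, and the pad may only be stripped after the shards are gathered back, otherwise the per-rank shard sizes disagree. Below is a minimal single-process sketch of that invariant; `pad_to_multiple` and `remove_pad` are illustrative helpers (not videosys APIs), and plain `chunk`/`cat` stand in for `split_sequence` / `gather_sequence` / the all_to_all.

```python
import torch


def pad_to_multiple(x: torch.Tensor, dim: int, sp_size: int):
    """Right-pad `x` along `dim` so its length is divisible by sp_size."""
    pad = (-x.size(dim)) % sp_size
    if pad > 0:
        pad_shape = list(x.shape)
        pad_shape[dim] = pad
        x = torch.cat([x, torch.zeros(pad_shape, dtype=x.dtype, device=x.device)], dim=dim)
    return x, pad


def remove_pad(x: torch.Tensor, dim: int, pad: int) -> torch.Tensor:
    """Strip the right padding, but only after the full sequence is gathered back."""
    return x.narrow(dim, 0, x.size(dim) - pad) if pad > 0 else x


if __name__ == "__main__":
    sp_size = 4
    hidden_states = torch.randn(2, 10, 64)  # (batch, seq, dim); 10 is not divisible by 4

    padded, pad = pad_to_multiple(hidden_states, dim=1, sp_size=sp_size)
    shards = padded.chunk(sp_size, dim=1)   # every rank now holds an equally sized shard
    assert all(s.size(1) == padded.size(1) // sp_size for s in shards)

    gathered = torch.cat(shards, dim=1)     # stands in for gather_sequence / all_to_all
    restored = remove_pad(gathered, dim=1, pad=pad)
    assert torch.equal(restored, hidden_states)  # pad is stripped only after the gather
```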
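For the query/key/value path, the call is now `all_to_all_comm(x, attn.parallel_manager.sp_group, scatter_dim=2, gather_dim=1)`: each rank trades its full set of heads over a partial sequence for a subset of heads over the full sequence. The single-process emulation below is an illustrative sketch (`emulate_all_to_all` and the shapes are assumptions; no real process group is used) of that re-sharding.

```python
import torch


def emulate_all_to_all(shards, scatter_dim: int, gather_dim: int):
    """Return what each rank would hold after an all_to_all over len(shards) ranks."""
    sp_size = len(shards)
    outputs = []
    for dst in range(sp_size):
        # rank `dst` receives chunk `dst` of every source rank's scatter_dim,
        # concatenated along gather_dim in source-rank order
        pieces = [src.chunk(sp_size, dim=scatter_dim)[dst] for src in shards]
        outputs.append(torch.cat(pieces, dim=gather_dim))
    return outputs


if __name__ == "__main__":
    sp_size, batch, seq, heads, head_dim = 2, 1, 8, 4, 16
    full_q = torch.randn(batch, seq, heads, head_dim)

    # before: each rank holds all heads but only seq // sp_size tokens
    shards = list(full_q.chunk(sp_size, dim=1))
    # after: each rank holds the full sequence but only heads // sp_size heads
    outs = emulate_all_to_all(shards, scatter_dim=2, gather_dim=1)

    assert outs[0].shape == (batch, seq, heads // sp_size, head_dim)
    assert torch.equal(torch.cat(outs, dim=2), full_q)  # nothing is lost, only re-sharded
```

The pad handling sketched above is what guarantees that the per-rank chunks are equally sized, which a real all_to_all requires.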