
Commit 3b18483

add head parallel to ulysses (#666)
1 parent 1a343fc commit 3b18483

6 files changed (+137, -61 lines)


configs/seko_talk/seko_talk_25_int8_dist_fp8_comm.json

Lines changed: 5 additions & 4 deletions
@@ -4,9 +4,9 @@
     "video_duration": 5,
     "audio_sr": 16000,
     "target_video_length": 81,
-    "self_attn_1_type": "sage_attn2",
-    "cross_attn_1_type": "sage_attn2",
-    "cross_attn_2_type": "sage_attn2",
+    "self_attn_1_type": "sage_attn3",
+    "cross_attn_1_type": "sage_attn3",
+    "cross_attn_2_type": "sage_attn3",
     "sample_guide_scale": 1,
     "sample_shift": 5,
     "enable_cfg": false,
@@ -35,6 +35,7 @@
     "parallel": {
         "seq_p_size": 8,
         "seq_p_fp8_comm": true,
-        "seq_p_attn_type": "ulysses-4090"
+        "seq_p_head_parallel": true,
+        "seq_p_attn_type": "ulysses"
     }
 }
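
For reference, the new "seq_p_head_parallel" key lives in the config's "parallel" section and is forwarded into the Ulysses attention kernel by the transformer infer code changed further down. A condensed sketch of that plumbing (not part of the commit; the dict below mirrors the config above, and the call in the comment mirrors the transformer_infer.py diffs rather than an exact call site):

# Sketch: how "seq_p_head_parallel" flows from the config into the attention call.
parallel_cfg = {
    "seq_p_size": 8,
    "seq_p_fp8_comm": True,
    "seq_p_head_parallel": True,
    "seq_p_attn_type": "ulysses",
}

use_fp8_comm = parallel_cfg.get("seq_p_fp8_comm", False)               # fp8-quantized all_to_all payloads
enable_head_parallel = parallel_cfg.get("seq_p_head_parallel", False)  # per-head async all_to_all

# The infer code then forwards both flags into the registered "ulysses" attention weight, roughly:
#   attn_weight.apply(q, k, v, img_qkv_len, cu_seqlens_qkv,
#                     attention_module=..., seq_p_group=...,
#                     use_fp8_comm=use_fp8_comm,
#                     enable_head_parallel=enable_head_parallel,
#                     model_cls=...)
print(use_fp8_comm, enable_head_parallel)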

lightx2v/common/ops/attn/ring_attn.py

Lines changed: 2 additions & 1 deletion
@@ -41,7 +41,7 @@ class RingAttnWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False):
+    def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False, enable_head_parallel=False):
         """
         Apply Ring attention, combining the image and text queries, keys, and values.
@@ -57,6 +57,7 @@ def apply(self, q, k, v, img_qkv_len, cu_seqlens_qkv, attention_module=None, seq
             torch.Tensor: the computed attention result
         """
         assert not use_fp8_comm, "RingAttn can't support fp8 comm now."
+        assert not enable_head_parallel, "RingAttn can't support head parallel mode."

         # Get the rank of the current process and the world size
         cur_rank = dist.get_rank(seq_p_group)

lightx2v/common/ops/attn/ulysses_attn.py

Lines changed: 119 additions & 55 deletions
@@ -7,15 +7,15 @@
 from lightx2v_platform.base.global_var import AI_DEVICE

 from .template import AttnWeightTemplate
-from .utils.all2all import all2all_head2seq, all2all_seq2head
+from .utils.all2all import all2all_head2seq


 @ATTN_WEIGHT_REGISTER("ulysses")
 class UlyssesAttnWeight(AttnWeightTemplate):
     def __init__(self):
         self.config = {}

-    def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False, img_first=True):
+    def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False, enable_head_parallel=False, img_first=True):
         """
         Apply Ulysses attention, combining the image and text queries, keys, and values.

@@ -58,6 +58,17 @@ def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, s
         _, heads, hidden_dims = q.shape
         shard_heads = heads // world_size  # number of heads handled by each process
         shard_seqlen = img_qkv_len  # sequence length handled by each process
+        global_img_seqlen = shard_seqlen * world_size  # global sequence length
+
+        # Initialize the cumulative sequence length tensor
+        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
+        s = txt_qkv_len + global_img_seqlen  # total length of text plus image
+        s1 = s  # end position of the current sample
+        cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
+        if txt_mask_len:
+            s2 = txt_mask_len + global_img_seqlen  # end position of the text mask
+            cu_seqlens_qkv = torch.cat((cu_seqlens_qkv, torch.tensor([s2], dtype=torch.int32, device=AI_DEVICE)))
+        max_seqlen_qkv = global_img_seqlen + txt_qkv_len  # maximum sequence length

         # Split the image and text queries, keys, and values
         if img_first:
@@ -67,68 +78,119 @@ def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, s
             txt_q, txt_k, txt_v = q[:txt_qkv_len, :, :].contiguous(), k[:txt_qkv_len, :, :].contiguous(), v[:txt_qkv_len, :, :].contiguous()
             img_q, img_k, img_v = q[txt_qkv_len:, :, :].contiguous(), k[txt_qkv_len:, :, :].contiguous(), v[txt_qkv_len:, :, :].contiguous()

-        # Convert the image queries, keys, and values to a head-sharded layout
-        if use_fp8_comm:
-            original_dtype = img_q.dtype
-            original_shape = img_q.shape
-            img_q_fp8, q_scale = quant_fp8_vllm(img_q.reshape(-1, original_shape[-1]))
-            img_k_fp8, k_scale = quant_fp8_vllm(img_k.reshape(-1, original_shape[-1]))
-            img_v_fp8, v_scale = quant_fp8_vllm(img_v.reshape(-1, original_shape[-1]))
-            img_q_fp8 = all2all_seq2head(img_q_fp8.reshape(original_shape), group=seq_p_group)
-            img_k_fp8 = all2all_seq2head(img_k_fp8.reshape(original_shape), group=seq_p_group)
-            img_v_fp8 = all2all_seq2head(img_v_fp8.reshape(original_shape), group=seq_p_group)
-            q_scale = all2all_seq2head(q_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
-            k_scale = all2all_seq2head(k_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
-            v_scale = all2all_seq2head(v_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
-            img_q = dequant_fp8_vllm(img_q_fp8, q_scale, original_dtype)
-            img_k = dequant_fp8_vllm(img_k_fp8, k_scale, original_dtype)
-            img_v = dequant_fp8_vllm(img_v_fp8, v_scale, original_dtype)
-        else:
-            img_q = all2all_seq2head(img_q, group=seq_p_group)
-            img_k = all2all_seq2head(img_k, group=seq_p_group)
-            img_v = all2all_seq2head(img_v, group=seq_p_group)
-
-        # Process the text queries, keys, and values, selecting this process's heads
-        txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
-        txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
-        txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+        img_qkv = torch.stack([img_q, img_k, img_v], dim=0).reshape(3, img_qkv_len, world_size, shard_heads, hidden_dims)
+        original_dtype = img_qkv.dtype
+
+        if enable_head_parallel:
+            img_qkv = img_qkv.permute(3, 2, 1, 0, 4).contiguous()  # (shard_heads, world_size, img_qkv_len, 3, hidden_dims)
+            output_qkv = torch.empty_like(img_qkv)
+
+            # Communicate the image queries, keys, and values
+            if use_fp8_comm:
+                img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims))
+                img_qkv_fp8 = img_qkv_fp8.reshape(shard_heads, world_size, img_qkv_len, 3, hidden_dims)
+                img_qkv_scale = img_qkv_scale.reshape(shard_heads, world_size, img_qkv_len, 3, 1)
+                output_qkv_fp8 = torch.empty_like(img_qkv_fp8)
+                output_qkv_scale = torch.empty_like(img_qkv_scale)
+                comm_fp8_works = []
+                comm_scale_works = []
+                for h in range(shard_heads):
+                    work_fp8 = dist.all_to_all_single(output_qkv_fp8[h], img_qkv_fp8[h], group=seq_p_group, async_op=True)
+                    work_scale = dist.all_to_all_single(output_qkv_scale[h], img_qkv_scale[h], group=seq_p_group, async_op=True)
+                    comm_fp8_works.append(work_fp8)
+                    comm_scale_works.append(work_scale)
+            else:
+                comm_works = []
+                for h in range(shard_heads):
+                    work = dist.all_to_all_single(output_qkv[h], img_qkv[h], group=seq_p_group, async_op=True)
+                    comm_works.append(work)
+
+            # Compute attention one head at a time
+            single_head = 1
+            head_attns = []
+            for h in range(shard_heads):
+                if use_fp8_comm:
+                    comm_fp8_works[h].wait()
+                    comm_scale_works[h].wait()
+                    output_qkv[h] = dequant_fp8_vllm(output_qkv_fp8[h], output_qkv_scale[h], original_dtype)
+                else:
+                    comm_works[h].wait()
+
+                qkv = output_qkv[h].reshape(global_img_seqlen, 3, single_head, hidden_dims).transpose(0, 1)
+                shard_img_q = qkv[0]  # (global_img_seqlen, single_head, hidden_dims)
+                shard_img_k = qkv[1]
+                shard_img_v = qkv[2]
+
+                # From the text queries, keys, and values, select this process's current head
+                shard_txt_q = txt_q[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
+                shard_txt_k = txt_k[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
+                shard_txt_v = txt_v[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
+
+                # Concatenate the image and text queries, keys, and values
+                q = torch.cat((shard_img_q, shard_txt_q), dim=0)
+                k = torch.cat((shard_img_k, shard_txt_k), dim=0)
+                v = torch.cat((shard_img_v, shard_txt_v), dim=0)
+
+                # Call the attention module to compute the attention result
+                head_attn = attention_module.apply(
+                    q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls
+                ).reshape(-1, single_head, hidden_dims)
+                head_attns.append(head_attn)
+
+            # Concatenate the attention results of all heads on this process
+            attn = torch.cat(head_attns, dim=1)

-        # Concatenate the image and text queries, keys, and values
-        if img_first:
-            q = torch.cat((img_q, txt_q), dim=0)
-            k = torch.cat((img_k, txt_k), dim=0)
-            v = torch.cat((img_v, txt_v), dim=0)
         else:
-            q = torch.cat((txt_q, img_q), dim=0)
-            k = torch.cat((txt_k, img_k), dim=0)
-            v = torch.cat((txt_v, img_v), dim=0)
-
-        # Initialize the cumulative sequence length tensor
-        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
-        s = txt_qkv_len + img_q.shape[0]  # total length of text plus image
-        s1 = s  # end position of the current sample
-        cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
-        if txt_mask_len:
-            s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
-            cu_seqlens_qkv = torch.cat(cu_seqlens_qkv, s2)
-        max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length
-
-        # Call the attention function to compute the attention result
-        # attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
-        attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls)
+            img_qkv = img_qkv.permute(2, 1, 0, 3, 4).contiguous()  # (world_size, img_qkv_len, 3, shard_heads, hidden_dims)
+
+            # Communicate the image queries, keys, and values
+            if use_fp8_comm:
+                img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims))
+                img_qkv_fp8 = img_qkv_fp8.reshape(world_size, img_qkv_len, shard_heads, 3, hidden_dims)
+                img_qkv_scale = img_qkv_scale.reshape(world_size, img_qkv_len, shard_heads, 3, 1)
+                output_qkv_fp8 = torch.empty_like(img_qkv_fp8)
+                output_qkv_scale = torch.empty_like(img_qkv_scale)
+                dist.all_to_all_single(output_qkv_fp8, img_qkv_fp8, group=seq_p_group)
+                dist.all_to_all_single(output_qkv_scale, img_qkv_scale, group=seq_p_group)
+                output_qkv = dequant_fp8_vllm(output_qkv_fp8, output_qkv_scale, original_dtype)
+            else:
+                output_qkv = torch.empty_like(img_qkv)
+                dist.all_to_all_single(output_qkv, img_qkv, group=seq_p_group)
+
+            # Compute attention
+            qkv = output_qkv.reshape(global_img_seqlen, 3, shard_heads, hidden_dims).transpose(0, 1)
+            shard_img_q = qkv[0]  # (global_img_seqlen, shard_heads, hidden_dims)
+            shard_img_k = qkv[1]
+            shard_img_v = qkv[2]
+
+            # Process the text queries, keys, and values, selecting this process's heads
+            shard_txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+            shard_txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+            shard_txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+
+            # Concatenate the image and text queries, keys, and values
+            q = torch.cat((shard_img_q, shard_txt_q), dim=0)
+            k = torch.cat((shard_img_k, shard_txt_k), dim=0)
+            v = torch.cat((shard_img_v, shard_txt_v), dim=0)
+
+            # Call the attention module to compute the attention result
+            attn = attention_module.apply(
+                q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls
+            ).reshape(-1, shard_heads, hidden_dims)

         # Split the attention result back into image and text parts
+        attn = attn.reshape(attn.shape[0], -1)
         if img_first:
-            img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :,]
+            img_attn, txt_attn = attn[:global_img_seqlen, :], attn[global_img_seqlen:]
         else:
-            txt_attn, img_attn = attn[: txt_q.shape[0], :], attn[txt_q.shape[0] :,]
+            txt_attn, img_attn = attn[:txt_qkv_len, :], attn[txt_qkv_len:]
+
+        # Communicate the image attention results across all processes
+        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)

         # Gather the text attention results from all processes
         gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
         dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
-
-        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)
-
         txt_attn = torch.cat(gathered_txt_attn, dim=1)  # concatenate the text attention results from all processes

         # Merge the image and text attention results
@@ -247,7 +309,7 @@ def load_balanced_all_to_all(self, shards, seq_p_group=None):

         return gathered_shards

-    def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False, img_first=True):
+    def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False, enable_head_parallel=False, img_first=True):
         """
         Apply Ulysses attention, combining the image and text queries, keys, and values.
@@ -262,6 +324,8 @@ def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, s
         Returns:
             torch.Tensor: the computed attention result
         """
+        assert not enable_head_parallel, "Ulysses-4090 can't support head parallel mode."
+
         if len(self.rounds) == 0:
             self.generate_round_robin_pairs(seq_p_group)

lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py

Lines changed: 3 additions & 0 deletions
@@ -104,9 +104,11 @@ def __init__(self, config):
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
             self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)
+            self.enable_head_parallel = self.config["parallel"].get("seq_p_head_parallel", False)
         else:
             self.seq_p_group = None
             self.seq_p_fp8_comm = False
+            self.enable_head_parallel = False
         self.infer_func = self.infer_without_offload
         if self.config.get("modulate_type", "triton") == "triton":
             self.modulate_func = fuse_scale_shift_kernel
@@ -234,6 +236,7 @@ def _infer_attn(self, weights, img_q, img_k, img_v, txt_q, txt_k, txt_v):
                 attention_module=weights.self_attention,
                 seq_p_group=self.seq_p_group,
                 use_fp8_comm=self.seq_p_fp8_comm,
+                enable_head_parallel=self.enable_head_parallel,
                 model_cls=self.config["model_cls"],
             )
         else:

lightx2v/models/networks/qwen_image/infer/transformer_infer.py

Lines changed: 5 additions & 1 deletion
@@ -27,9 +27,12 @@ def __init__(self, config):
         self.zero_cond_t = config.get("zero_cond_t", False)
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
+            self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)
+            self.enable_head_parallel = self.config["parallel"].get("seq_p_head_parallel", False)
         else:
             self.seq_p_group = None
-            self.seq_p_fp8_comm = False
+            self.seq_p_fp8_comm = False
+            self.enable_head_parallel = False
         if self.config.get("modulate_type", "triton") == "triton":
             self.modulate_func = fuse_scale_shift_kernel
         else:
@@ -140,6 +143,7 @@ def apply_attn(self, block_weight, hidden_states, encoder_hidden_states, image_r
                 attention_module=block_weight.attn.calculate,
                 seq_p_group=self.seq_p_group,
                 use_fp8_comm=self.seq_p_fp8_comm,
+                enable_head_parallel=self.enable_head_parallel,
                 model_cls=self.config["model_cls"],
                 img_first=False,
             )

lightx2v/models/networks/wan/infer/transformer_infer.py

Lines changed: 3 additions & 0 deletions
@@ -46,9 +46,11 @@ def __init__(self, config):
         if self.config["seq_parallel"]:
             self.seq_p_group = self.config.get("device_mesh").get_group(mesh_dim="seq_p")
             self.seq_p_fp8_comm = self.config["parallel"].get("seq_p_fp8_comm", False)
+            self.enable_head_parallel = self.config["parallel"].get("seq_p_head_parallel", False)
         else:
             self.seq_p_group = None
             self.seq_p_fp8_comm = False
+            self.enable_head_parallel = False
         self.infer_func = self.infer_without_offload

         self.cos_sin = None
@@ -184,6 +186,7 @@ def infer_self_attn(self, phase, x, shift_msa, scale_msa):
                 attention_module=phase.self_attn_1,
                 seq_p_group=self.seq_p_group,
                 use_fp8_comm=self.seq_p_fp8_comm,
+                enable_head_parallel=self.enable_head_parallel,
                 model_cls=self.config["model_cls"],
             )
         else:
