from lightx2v_platform.base.global_var import AI_DEVICE

from .template import AttnWeightTemplate
-from .utils.all2all import all2all_head2seq, all2all_seq2head
+from .utils.all2all import all2all_head2seq


@ATTN_WEIGHT_REGISTER("ulysses")
class UlyssesAttnWeight(AttnWeightTemplate):
    def __init__(self):
        self.config = {}

-    def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False, img_first=True):
+    def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False, enable_head_parallel=False, img_first=True):
        """
        Run Ulysses attention over the combined image and text queries, keys and values.

@@ -58,6 +58,17 @@ def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, s
        _, heads, hidden_dims = q.shape
        shard_heads = heads // world_size  # number of heads handled by each rank
        shard_seqlen = img_qkv_len  # sequence length handled by each rank
+        global_img_seqlen = shard_seqlen * world_size  # global image sequence length
+
+        # Initialize the cumulative sequence-length tensor
+        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
+        s = txt_qkv_len + global_img_seqlen  # total length of text plus image
+        s1 = s  # end position of the current sample
+        cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
+        if txt_mask_len:
+            s2 = txt_mask_len + global_img_seqlen  # end position of the text mask
+            cu_seqlens_qkv = torch.cat((cu_seqlens_qkv, torch.tensor([s2], dtype=torch.int32, device=AI_DEVICE)))
+        max_seqlen_qkv = global_img_seqlen + txt_qkv_len  # maximum sequence length
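+        # Note: lengths here use global_img_seqlen because, after the all-to-all
+        # exchange below, each rank holds the full image sequence (for its subset of
+        # heads), so cu_seqlens/max_seqlen must cover the global length, not the shard.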

        # Split the queries, keys and values into image and text parts
        if img_first:
@@ -67,68 +78,119 @@ def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, s
            txt_q, txt_k, txt_v = q[:txt_qkv_len, :, :].contiguous(), k[:txt_qkv_len, :, :].contiguous(), v[:txt_qkv_len, :, :].contiguous()
            img_q, img_k, img_v = q[txt_qkv_len:, :, :].contiguous(), k[txt_qkv_len:, :, :].contiguous(), v[txt_qkv_len:, :, :].contiguous()

-        # Convert the image queries, keys and values to the head-sharded layout
-        if use_fp8_comm:
-            original_dtype = img_q.dtype
-            original_shape = img_q.shape
-            img_q_fp8, q_scale = quant_fp8_vllm(img_q.reshape(-1, original_shape[-1]))
-            img_k_fp8, k_scale = quant_fp8_vllm(img_k.reshape(-1, original_shape[-1]))
-            img_v_fp8, v_scale = quant_fp8_vllm(img_v.reshape(-1, original_shape[-1]))
-            img_q_fp8 = all2all_seq2head(img_q_fp8.reshape(original_shape), group=seq_p_group)
-            img_k_fp8 = all2all_seq2head(img_k_fp8.reshape(original_shape), group=seq_p_group)
-            img_v_fp8 = all2all_seq2head(img_v_fp8.reshape(original_shape), group=seq_p_group)
-            q_scale = all2all_seq2head(q_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
-            k_scale = all2all_seq2head(k_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
-            v_scale = all2all_seq2head(v_scale.reshape(original_shape[0], original_shape[1], 1), group=seq_p_group)
-            img_q = dequant_fp8_vllm(img_q_fp8, q_scale, original_dtype)
-            img_k = dequant_fp8_vllm(img_k_fp8, k_scale, original_dtype)
-            img_v = dequant_fp8_vllm(img_v_fp8, v_scale, original_dtype)
-        else:
-            img_q = all2all_seq2head(img_q, group=seq_p_group)
-            img_k = all2all_seq2head(img_k, group=seq_p_group)
-            img_v = all2all_seq2head(img_v, group=seq_p_group)
-
-        # For the text queries, keys and values, select the heads owned by this rank
-        txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
-        txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
-        txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+        img_qkv = torch.stack([img_q, img_k, img_v], dim=0).reshape(3, img_qkv_len, world_size, shard_heads, hidden_dims)
+        original_dtype = img_qkv.dtype
+
+        if enable_head_parallel:
+            img_qkv = img_qkv.permute(3, 2, 1, 0, 4).contiguous()  # (shard_heads, world_size, img_qkv_len, 3, hidden_dims)
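+            # Layout: one slice per local head; within each slice dim 0 is world_size,
+            # the dimension dist.all_to_all_single splits and exchanges, so after the
+            # collective each rank holds the full image sequence for the heads it owns.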
+            output_qkv = torch.empty_like(img_qkv)
+
+            # Exchange the image queries, keys and values across ranks
+            if use_fp8_comm:
+                img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims))
+                img_qkv_fp8 = img_qkv_fp8.reshape(shard_heads, world_size, img_qkv_len, 3, hidden_dims)
+                img_qkv_scale = img_qkv_scale.reshape(shard_heads, world_size, img_qkv_len, 3, 1)
+                output_qkv_fp8 = torch.empty_like(img_qkv_fp8)
+                output_qkv_scale = torch.empty_like(img_qkv_scale)
+                comm_fp8_works = []
+                comm_scale_works = []
+                for h in range(shard_heads):
+                    work_fp8 = dist.all_to_all_single(output_qkv_fp8[h], img_qkv_fp8[h], group=seq_p_group, async_op=True)
+                    work_scale = dist.all_to_all_single(output_qkv_scale[h], img_qkv_scale[h], group=seq_p_group, async_op=True)
+                    comm_fp8_works.append(work_fp8)
+                    comm_scale_works.append(work_scale)
+            else:
+                comm_works = []
+                for h in range(shard_heads):
+                    work = dist.all_to_all_single(output_qkv[h], img_qkv[h], group=seq_p_group, async_op=True)
+                    comm_works.append(work)
+
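+            # Each head's exchange was issued asynchronously above; waiting on one head
+            # at a time lets its attention computation overlap with the communication
+            # still in flight for the remaining heads.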
+            # Run the attention computation head by head
+            single_head = 1
+            head_attns = []
+            for h in range(shard_heads):
+                if use_fp8_comm:
+                    comm_fp8_works[h].wait()
+                    comm_scale_works[h].wait()
+                    output_qkv[h] = dequant_fp8_vllm(output_qkv_fp8[h], output_qkv_scale[h], original_dtype)
+                else:
+                    comm_works[h].wait()
+
+                qkv = output_qkv[h].reshape(global_img_seqlen, 3, single_head, hidden_dims).transpose(0, 1)
+                shard_img_q = qkv[0]  # (global_img_seqlen, single_head, hidden_dims)
+                shard_img_k = qkv[1]
+                shard_img_v = qkv[2]
+
+                # For the text queries, keys and values, select the current head of this rank
+                shard_txt_q = txt_q[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
+                shard_txt_k = txt_k[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
+                shard_txt_v = txt_v[:, (cur_rank * shard_heads + h) : (cur_rank * shard_heads + h + 1), :]
+
+                # Concatenate the image and text queries, keys and values
+                q = torch.cat((shard_img_q, shard_txt_q), dim=0)
+                k = torch.cat((shard_img_k, shard_txt_k), dim=0)
+                v = torch.cat((shard_img_v, shard_txt_v), dim=0)
+
+                # Call the attention function to compute the attention result
+                head_attn = attention_module.apply(
+                    q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls
+                ).reshape(-1, single_head, hidden_dims)
+                head_attns.append(head_attn)
+
+            # Concatenate the attention outputs of all heads on this rank
+            attn = torch.cat(head_attns, dim=1)

-        # Concatenate the image and text queries, keys and values
-        if img_first:
-            q = torch.cat((img_q, txt_q), dim=0)
-            k = torch.cat((img_k, txt_k), dim=0)
-            v = torch.cat((img_v, txt_v), dim=0)
        else:
-            q = torch.cat((txt_q, img_q), dim=0)
-            k = torch.cat((txt_k, img_k), dim=0)
-            v = torch.cat((txt_v, img_v), dim=0)
-
-        # Initialize the cumulative sequence-length tensor
-        cu_seqlens_qkv = torch.zeros([2], dtype=torch.int32, device=AI_DEVICE)
-        s = txt_qkv_len + img_q.shape[0]  # total length of text plus image
-        s1 = s  # end position of the current sample
-        cu_seqlens_qkv[1] = s1  # set the cumulative sequence length
-        if txt_mask_len:
-            s2 = txt_mask_len + img_q.shape[0]  # end position of the text mask
-            cu_seqlens_qkv = torch.cat(cu_seqlens_qkv, s2)
-        max_seqlen_qkv = img_q.shape[0] + txt_q.shape[0]  # maximum sequence length
-
-        # Call the attention function to compute the attention result
-        # attn = attention(attention_type=attention_type, q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv)
-        attn = attention_module.apply(q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls)
+            img_qkv = img_qkv.permute(2, 1, 0, 3, 4).contiguous()  # (world_size, img_qkv_len, 3, shard_heads, hidden_dims)
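+            # Without head parallelism, a single blocking all-to-all moves all of this
+            # rank's shard_heads heads at once; dim 0 is again world_size, so each rank
+            # receives every peer's sequence slice for the heads assigned to it.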
+
+            # Exchange the image queries, keys and values across ranks
+            if use_fp8_comm:
+                img_qkv_fp8, img_qkv_scale = quant_fp8_vllm(img_qkv.reshape(-1, hidden_dims))
+                img_qkv_fp8 = img_qkv_fp8.reshape(world_size, img_qkv_len, shard_heads, 3, hidden_dims)
+                img_qkv_scale = img_qkv_scale.reshape(world_size, img_qkv_len, shard_heads, 3, 1)
+                output_qkv_fp8 = torch.empty_like(img_qkv_fp8)
+                output_qkv_scale = torch.empty_like(img_qkv_scale)
+                dist.all_to_all_single(output_qkv_fp8, img_qkv_fp8, group=seq_p_group)
+                dist.all_to_all_single(output_qkv_scale, img_qkv_scale, group=seq_p_group)
+                output_qkv = dequant_fp8_vllm(output_qkv_fp8, output_qkv_scale, original_dtype)
+            else:
+                output_qkv = torch.empty_like(img_qkv)
+                dist.all_to_all_single(output_qkv, img_qkv, group=seq_p_group)
+
+            # Run the attention computation
+            qkv = output_qkv.reshape(global_img_seqlen, 3, shard_heads, hidden_dims).transpose(0, 1)
+            shard_img_q = qkv[0]  # (global_img_seqlen, shard_heads, hidden_dims)
+            shard_img_k = qkv[1]
+            shard_img_v = qkv[2]
+
+            # For the text queries, keys and values, select the heads owned by this rank
+            shard_txt_q = txt_q[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+            shard_txt_k = txt_k[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+            shard_txt_v = txt_v[:, cur_rank * shard_heads : (cur_rank + 1) * shard_heads, :]
+
+            # Concatenate the image and text queries, keys and values
+            q = torch.cat((shard_img_q, shard_txt_q), dim=0)
+            k = torch.cat((shard_img_k, shard_txt_k), dim=0)
+            v = torch.cat((shard_img_v, shard_txt_v), dim=0)
+
+            # Call the attention function to compute the attention result
+            attn = attention_module.apply(
+                q=q, k=k, v=v, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=max_seqlen_qkv, max_seqlen_kv=max_seqlen_qkv, model_cls=model_cls
+            ).reshape(-1, shard_heads, hidden_dims)

        # Split the attention result into image and text parts
+        attn = attn.reshape(attn.shape[0], -1)
        if img_first:
-            img_attn, txt_attn = attn[: img_q.shape[0], :], attn[img_q.shape[0] :,]
+            img_attn, txt_attn = attn[:global_img_seqlen, :], attn[global_img_seqlen:]
        else:
-            txt_attn, img_attn = attn[: txt_q.shape[0], :], attn[txt_q.shape[0] :,]
+            txt_attn, img_attn = attn[:txt_qkv_len, :], attn[txt_qkv_len:]
+
+        # Exchange the image attention results across ranks
+        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)
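+        # _reshape_img_attn presumably performs the reverse head -> seq exchange
+        # (cf. all2all_head2seq), returning this rank's shard of the image sequence
+        # with all heads restored.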

        # Gather the text attention results from all ranks
        gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
        dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)
-
-        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group, use_fp8_comm)
-
        txt_attn = torch.cat(gathered_txt_attn, dim=1)  # concatenate the text attention results from all ranks

        # Merge the image and text attention results
@@ -247,7 +309,7 @@ def load_balanced_all_to_all(self, shards, seq_p_group=None):

        return gathered_shards

-    def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False, img_first=True):
+    def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, seq_p_group=None, model_cls=None, use_fp8_comm=False, enable_head_parallel=False, img_first=True):
        """
        Run Ulysses attention over the combined image and text queries, keys and values.

@@ -262,6 +324,8 @@ def apply(self, q, k, v, slice_qkv_len, cu_seqlens_qkv, attention_module=None, s
        Returns:
            torch.Tensor: the computed attention result
        """
+        assert not enable_head_parallel, "Ulysses-4090 can't support head parallel mode."
+
        if len(self.rounds) == 0:
            self.generate_round_robin_pairs(seq_p_group)
