Hi, I would like to ask why the attention mask is not used in the prefill stage.
I want to output the attention scores matrix in the prefill stage. Is the code below correct?
```python
if spec:  # spec decoding
    key_states, value_states = graph_cache.update(
        new_k_cache=key_states, new_v_cache=value_states, layer_idx=self.layer_idx
    )
else:
    # update kv cache first
    key_states, value_states = kv_cache.update(key_states, value_states, layer_idx=self.layer_idx)
    if query_states.shape[1] == 1 and isinstance(graph_cache, RetrievalCache):
        if not graph_cache.init_graph:
            # init graph cache
            graph_cache.init_graph_cache(kv_cache, query_states, self.layer_idx)
        else:
            # update graph cache (customized)
            graph_cache.update_graph_cache_retrieval(kv_cache, query_states, self.layer_idx)

# compute the attention scores matrix: (bsz, n_heads, q_len, kv_len)
attention_scores = torch.einsum("bqhd,bkhd->bhqk", query_states, key_states)
attention_scores /= self.head_dim ** 0.5
if attention_mask is not None:
    attention_mask = attention_mask.to(attention_scores.device)
    attention_scores += attention_mask

# softmax_scale expects a plain Python float, not a tensor
attn_output = flash_attn_with_kvcache(
    q=query_states,
    k_cache=key_states,
    v_cache=value_states,
    softmax_scale=1.0 / (self.head_dim ** 0.5),
    causal=True,
)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attention_scores
```
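To make the question concrete, here is a minimal, self-contained sketch of what I believe `causal=True` is doing during prefill. It uses toy sizes and plain PyTorch (`scaled_dot_product_attention` as the reference instead of `flash_attn_with_kvcache`, so it runs on CPU); the values of `bsz`, `q_len`, `n_heads`, and `head_dim` are made up for illustration, not taken from the model. If adding an explicit upper-triangular `-inf` mask matches the built-in causal handling, that would explain why no separate `attention_mask` is needed in prefill:

```python
# Toy check (assumed shapes, not the repo's actual tensors): manual causal
# masking vs. built-in causal attention.
import torch

bsz, q_len, n_heads, head_dim = 1, 4, 2, 8
q = torch.randn(bsz, q_len, n_heads, head_dim)
k = torch.randn(bsz, q_len, n_heads, head_dim)
v = torch.randn(bsz, q_len, n_heads, head_dim)

# Manual scores, same einsum layout as the snippet above: (b, h, q, k)
scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / head_dim ** 0.5

# Explicit causal mask: position i may only attend to positions j <= i,
# so entries strictly above the diagonal are set to -inf.
causal_mask = torch.triu(torch.full((q_len, q_len), float("-inf")), diagonal=1)
masked_scores = scores + causal_mask  # broadcasts over (b, h)

attn = torch.softmax(masked_scores, dim=-1)
out_manual = torch.einsum("bhqk,bkhd->bqhd", attn, v)

# Reference: SDPA with is_causal=True applies the same masking internally.
out_ref = torch.nn.functional.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=True
).transpose(1, 2)

print(torch.allclose(out_manual, out_ref, atol=1e-5))  # True
```

My understanding from this is that causality is handled implicitly by the kernel, and an explicit `attention_mask` would only matter for things like padded batched inputs, but please correct me if that is wrong.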