Attention Scores Matrix Visualization #10

@bulaikexiansheng

Description

Hi, I would like to ask why the attention mask is not used in the prefill stage.
I want to output the attention score matrix during prefill. Is the code below correct?

        if spec: # spec decoding
            key_states, value_states = graph_cache.update(new_k_cache=key_states, new_v_cache=value_states, layer_idx=self.layer_idx)
        else:
            # update kv cache first
            key_states, value_states = kv_cache.update(key_states, value_states, layer_idx=self.layer_idx)
            if query_states.shape[1] == 1 and (isinstance(graph_cache, RetrievalCache)): 
                if not graph_cache.init_graph:
                    # init graph cache
                    graph_cache.init_graph_cache(kv_cache, query_states, self.layer_idx)
                else:
                    # update graph cache (customized)
                    graph_cache.update_graph_cache_retrieval(kv_cache, query_states, self.layer_idx)

        # compute the attention score matrix
        attention_scores = torch.einsum("bqhd,bkhd->bhqk", query_states, key_states)
        attention_scores /= self.head_dim ** 0.5
        if attention_mask is not None:
            attention_mask = attention_mask.to(attention_scores.device)
            attention_scores += attention_mask

        attn_output = flash_attn_with_kvcache(
            q=query_states, k_cache=key_states, v_cache=value_states,
            # flash-attn expects a Python float here, not a tensor
            softmax_scale=self.head_dim ** -0.5,
            causal=True,
        )
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, attention_scores
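
One thing to note about the question itself: `flash_attn_with_kvcache(..., causal=True)` applies the causal mask internally, which is presumably why no explicit `attention_mask` is passed during prefill. So if the goal is a score matrix that matches what flash-attn actually softmaxes, the causal mask has to be re-applied by hand before visualization. Below is a minimal self-contained sketch (the function name and shapes are hypothetical, assuming the `[batch, seq_len, num_heads, head_dim]` layout that flash-attn uses):

```python
import torch

def attention_scores_prefill(query_states, key_states, causal=True):
    """Pre-softmax attention scores for tensors shaped
    [batch, seq_len, num_heads, head_dim], with the same causal
    masking flash-attn applies internally."""
    head_dim = query_states.shape[-1]
    # [b, q, h, d] x [b, k, h, d] -> [b, h, q, k]
    scores = torch.einsum("bqhd,bkhd->bhqk", query_states, key_states)
    scores = scores / head_dim ** 0.5
    if causal:
        q_len, k_len = scores.shape[-2], scores.shape[-1]
        # Mask future positions; align the query block to the end of the
        # key cache so this also works when k_len > q_len.
        mask = torch.ones(q_len, k_len, dtype=torch.bool, device=scores.device)
        mask = mask.triu(diagonal=k_len - q_len + 1)
        scores = scores.masked_fill(mask, float("-inf"))
    return scores

# hypothetical prefill shapes: batch 1, 4 tokens, 2 heads, head_dim 8
q = torch.randn(1, 4, 2, 8)
k = torch.randn(1, 4, 2, 8)
scores = attention_scores_prefill(q, k)
probs = scores.softmax(dim=-1)  # rows sum to 1; future positions get weight 0
```

Visualizing `probs` (e.g. with `plt.imshow(probs[0, h])`) should then agree with the causal attention flash-attn computes, whereas the unmasked `attention_scores` in the snippet above will show nonzero values above the diagonal.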
