@@ -43,6 +43,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
    } else {
        throw std::runtime_error("num_attention_heads / tp_size error.");
    }
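+    // Standard scaled dot-product attention scale: scores are divided by sqrt(head_dim) before the softmax.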
+    scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));

    // Initialize projection layers
    INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_,
@@ -52,17 +53,10 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
                          dtype, device, tp_rank, tp_size, rank_info.comm);
}

-infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
-                                           const infinicore::Tensor &position_ids,
-                                           std::shared_ptr<cache::Cache> kv_cache,
-                                           std::optional<infinicore::Tensor> cache_lengths,
-                                           std::optional<infinicore::Tensor> input_lengths,
-                                           std::optional<infinicore::Tensor> input_offsets,
-                                           std::optional<infinicore::Tensor> block_tables,
-                                           std::optional<infinicore::Tensor> slot_mapping) const {
-    if (!rotary_emb_) {
-        throw std::runtime_error("LlamaAttention: rotary_emb not configured");
-    }
+infinicore::Tensor LlamaAttention::forward_static_(const infinicore::Tensor &hidden_states,
+                                                   const infinicore::Tensor &position_ids,
+                                                   std::shared_ptr<infinilm::cache::Cache> kv_cache,
+                                                   std::optional<infinicore::Tensor> cache_lengths) const {
    // Input shape: [batch, seq_len, hidden_size]
    auto hidden_states_mutable = hidden_states;
    auto shape = hidden_states->shape();
@@ -73,7 +67,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
    auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);

    // 2. Reshape for multi-head attention
-
    // Reshape Q, K, V to include batch dimension
    // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
    // The view operation requires the tensor to be contiguous in the required dimensions
@@ -114,13 +107,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
        auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_lengths.value());
        k_total = k_total_tmp;
        v_total = v_total_tmp;
-    } else if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
-        auto [k_total_tmp, v_total_tmp] = paged_kv_cache->update(layer_idx_, k_permuted, v_permuted, slot_mapping.value());
-        k_total = k_total_tmp;
-        v_total = v_total_tmp;
-
-        /// @todo Implement paged attention here.
-        throw std::runtime_error("LlamaAttention: Paged attention not implemented");
    } else {
        throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
    }
@@ -152,8 +138,136 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
    return output;
}

+infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidden_states,
+                                                  const infinicore::Tensor &position_ids,
+                                                  std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
+                                                  std::optional<infinicore::Tensor> cache_lengths,
+                                                  std::optional<infinicore::Tensor> input_lengths,
+                                                  std::optional<infinicore::Tensor> input_offsets,
+                                                  std::optional<infinicore::Tensor> block_tables,
+                                                  std::optional<infinicore::Tensor> slot_mapping) const {
+    if (!block_tables.has_value() || !input_lengths.has_value() || !slot_mapping.has_value()) {
+        throw std::runtime_error("LlamaAttention::forward_paged: block_tables, input_lengths, or slot_mapping is not set");
+    }
+
+    // Input shape: [batch, seq_len, hidden_size]
+    auto hidden_states_mutable = hidden_states;
+    auto shape = hidden_states->shape();
+    size_t batch_size = shape[0];
+    size_t seq_len = shape[1];
+
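+    // During decode each request contributes exactly one token, so batch_size * seq_len equals the
+    // number of entries in input_lengths; any mismatch means this call is a prefill step.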
+    bool is_prefill = (batch_size * seq_len != input_lengths.value()->shape()[0]);
+    assert(batch_size == 1);
+
+    // 1. Project Q, K, V
+    auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
+
+    // 2. Reshape for multi-head attention
+
+    // Reshape Q, K, V to include batch dimension
+    // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
+    // The view operation requires the tensor to be contiguous in the required dimensions
+    auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_});
+    auto k_reshaped = k->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
+    auto v_reshaped = v->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
+
+    // 3. Prepare position_ids for RoPE - align with Python pattern
+
+    auto pos_shape = position_ids->shape();
+    infinicore::Tensor pos_ids_for_rope = position_ids;
+    if (pos_shape.size() == 2) {
+        auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
+        pos_ids_for_rope = pos_narrowed->contiguous()->view({pos_shape[1]});
+    } else if (pos_shape.size() == 1) {
+        pos_ids_for_rope = position_ids->contiguous();
+    } else {
+        throw std::runtime_error("Unexpected position_ids shape");
+    }
+
+    // 4. Apply RoPE to Q and K
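+    // The RoPE output buffers are allocated contiguously in [bs, n_head, seq_len, head_dim] order and then
+    // viewed as [bs, seq_len, n_head, head_dim]; permuting them back to head-major later needs no copy.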
+    auto q_rope = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3});
+    auto k_rope = infinicore::Tensor::empty({batch_size, num_key_value_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3});
+    rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim]
+    rotary_emb_->forward(k_rope, k_reshaped, pos_ids_for_rope); // [bs, seq_len, n_kv_head, head_dim]
+
+    // 5. Prepare KV caches
+    // Ensure contiguous after permute for F16 compatibility with cache operations
+    auto [k_total, v_total] = paged_kv_cache->update(layer_idx_,
+                                                     k_rope->contiguous(), // without contiguous() this fails with "Incompatible shape for view operation."
+                                                     v_reshaped,
+                                                     slot_mapping.value());
+
+    // 6. Compute attention
+    infinicore::Tensor attn_output;
+    if (is_prefill) {
+        q_reshaped = q_rope->permute({0, 2, 1, 3});           // [bs, n_q_head, seq_len, head_dim]
+        auto k_permuted = k_rope->permute({0, 2, 1, 3});      // [bs, n_kv_head, seq_len, head_dim]
+        auto v_permuted = v_reshaped->permute({0, 2, 1, 3});  // [bs, n_kv_head, seq_len, head_dim]
+
+        auto total_seq_len = k_permuted->shape()[2];
+        size_t ngroup = num_attention_heads_ / num_key_value_heads_;
+
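+        // Grouped-query attention: each KV head serves ngroup query heads, so the grouped query heads are
+        // folded into the row dimension and a single batched matmul per KV head covers all of them.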
+        auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
+        auto K = k_permuted->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
+        auto V = v_permuted->contiguous()->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
+
+        auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]
+
+        auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len]
+
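+        // View the scores as one row block per query head and apply the causal softmax in place,
+        // masking out key positions beyond each query's own position.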
+        auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
+        infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);
+
+        auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]
+
+        attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
+                          ->permute({0, 2, 1, 3})
+                          ->contiguous()
+                          ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
+
+    } else {
+        q_reshaped = q_rope->contiguous()->view({1 * seq_len, num_attention_heads_, head_dim_}); // q_reshaped must be contiguous for the view
+        auto out = infinicore::Tensor::empty({1 * seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
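+        // Decode path: the fused paged-attention kernel gathers each sequence's cached K/V blocks via
+        // block_tables and attends over them for the single new query token.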
+        infinicore::op::paged_attention_(out,
+                                         q_reshaped,
+                                         k_total,
+                                         v_total,
+                                         block_tables.value(),
+                                         input_lengths.value(),
+                                         std::nullopt,
+                                         scaling_);
+
+        attn_output = out->view({1, seq_len, num_attention_heads_, head_dim_})->view({1, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
+    }
+
+    // 7. Project output
+    return o_proj_->forward(attn_output); // e.g. [1, 13, 3584] => [1, 13, 4096]
+}
+
+infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
+                                           const infinicore::Tensor &position_ids,
+                                           std::shared_ptr<cache::Cache> kv_cache,
+                                           std::optional<infinicore::Tensor> cache_lengths,
+                                           std::optional<infinicore::Tensor> input_lengths,
+                                           std::optional<infinicore::Tensor> input_offsets,
+                                           std::optional<infinicore::Tensor> block_tables,
+                                           std::optional<infinicore::Tensor> slot_mapping) const {
+    if (!rotary_emb_) {
+        throw std::runtime_error("LlamaAttention: rotary_emb not configured");
+    }
+
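+    // Dispatch on the concrete cache type: a paged KV cache takes the paged-attention path,
+    // anything else falls back to the static-cache path.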
+    infinicore::Tensor output;
+    if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
+        output = forward_paged_(hidden_states, position_ids, paged_kv_cache, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
+    } else {
+        output = forward_static_(hidden_states, position_ids, kv_cache, cache_lengths);
+    }
+    return output;
+}
+
void LlamaAttention::set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb) {
    rotary_emb_ = rotary_emb;
}

-} // namespace infinilm::models::llama
+} // namespace infinilm::models::llama