
Commit a88fb2a

Author: wangpengcheng
issue/168 - integrate the two paged attention operators
1 parent 96e53db commit a88fb2a

File tree: 12 files changed, +318 −43 lines


csrc/cache/kv_cache.cpp

Lines changed: 13 additions & 2 deletions

@@ -1,7 +1,7 @@
 #include "kv_cache.hpp"
 
 #include "../utils.hpp"
-
+#include "infinicore/ops.hpp"
 #include <stdexcept>
 
 namespace infinilm::cache {
@@ -155,6 +155,7 @@ PagedKVCache::PagedKVCache(
     num_blocks_per_layer_ = config.max_kv_memory_bytes()
                             / (k_dim * num_rank_k_heads_ + v_dim * num_rank_v_heads_)
                             / block_size_
+                            / rank_num_layers_
                             / infinicore::dsize(dtype_);
     if (num_blocks_per_layer_ == 0) {
         throw std::runtime_error("Not enough memory for KV cache");
@@ -191,7 +192,17 @@ std::tuple<infinicore::Tensor, infinicore::Tensor> PagedKVCache::update(
     auto v_cache_layer = v_caches_->narrow({{0, layer_idx, 1}})->squeeze(0);
 
     /// @todo: implement paged cache update here
-
+    auto k_shape = k->shape();
+    auto b = k_shape[0];
+    auto s = k_shape[1];
+    auto n = k_shape[2];
+    auto d = k_shape[3];
+
+    infinicore::op::paged_caching_(k->view({b * s, n, d}),
+                                   v->view({b * s, n, d}),
+                                   k_cache_layer,
+                                   v_cache_layer,
+                                   slot_mapping);
     return {k_cache_layer, v_cache_layer};
 }
 } // namespace infinilm::cache
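For orientation, the new paged_caching_ call scatters the flattened [b*s] keys and values into the paged cache: slot_mapping holds one destination slot per token (conceptually block_id * block_size + in-block offset), and that token's K/V vectors are copied into the corresponding slot of the per-layer pool. Below is a minimal CPU sketch of such a scatter; the flat [num_slots, num_kv_heads, head_dim] layout, the function name, and the plain-float buffers are illustrative assumptions, not the actual infinicore::op::paged_caching_ implementation.

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical row-major layouts (assumptions, not the real infinicore layout):
//   k_src / v_src     : [num_tokens, num_kv_heads, head_dim]
//   k_cache / v_cache : [num_slots,  num_kv_heads, head_dim], where
//                       slot = block_id * block_size + in-block offset
//   slot_mapping      : [num_tokens], destination slot per token
void paged_caching_reference(const std::vector<float> &k_src,
                             const std::vector<float> &v_src,
                             std::vector<float> &k_cache,
                             std::vector<float> &v_cache,
                             const std::vector<int64_t> &slot_mapping,
                             std::size_t num_kv_heads,
                             std::size_t head_dim) {
    const std::size_t token_stride = num_kv_heads * head_dim;
    for (std::size_t t = 0; t < slot_mapping.size(); ++t) {
        const std::size_t dst = static_cast<std::size_t>(slot_mapping[t]) * token_stride;
        const std::size_t src = t * token_stride;
        for (std::size_t i = 0; i < token_stride; ++i) {
            k_cache[dst + i] = k_src[src + i]; // scatter K of token t into its slot
            v_cache[dst + i] = v_src[src + i]; // scatter V likewise
        }
    }
}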

csrc/engine/infer_engine.cpp

Lines changed: 18 additions & 2 deletions

@@ -56,8 +56,24 @@ std::vector<std::unordered_map<std::string, infinicore::nn::Parameter>> InferEng
 //------------------------------------------------------
 // forward
 //------------------------------------------------------
-infinilm::InfinilmModel::Input InferEngine::Input::to_model_input() const {
-    return {input_ids, position_ids, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping};
+infinilm::InfinilmModel::Input InferEngine::Input::to_model_input(infinicore::Device device) const {
+
+    std::optional<infinicore::Tensor> input_lengths_on_device;
+    if (input_lengths.has_value()) {
+        input_lengths_on_device = input_lengths.value()->to(device);
+    }
+
+    std::optional<infinicore::Tensor> block_tables_on_device;
+    if (block_tables.has_value()) {
+        block_tables_on_device = block_tables.value()->to(device);
+    }
+
+    std::optional<infinicore::Tensor> slot_mapping_on_device;
+    if (slot_mapping.has_value()) {
+        slot_mapping_on_device = slot_mapping.value()->to(device);
+    }
+
+    return {input_ids, position_ids, cache_lengths, input_lengths_on_device, input_offsets, block_tables_on_device, slot_mapping_on_device};
 }
 
 InferEngine::Output InferEngine::forward(const InferEngine::Input &input) {
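A side note on to_model_input: the three if-blocks above repeat one pattern, "move the optional tensor to the worker's device, otherwise keep nullopt". A generic helper along the following lines could fold them into one call each. This is only a sketch of a possible refactor, not code from the repository; it assumes the Tensor::to(Device) call already used in the diff.

#include <optional>

// Map an optional through a transfer function, forwarding std::nullopt.
// With T = infinicore::Tensor and transfer = [&](const auto &t) { return t->to(device); },
// each of the three if-blocks above becomes a single map_optional(...) call, e.g.
//   input_lengths_on_device = map_optional(input_lengths, transfer);
template <typename T, typename F>
std::optional<T> map_optional(const std::optional<T> &value, F &&transfer) {
    if (!value.has_value()) {
        return std::nullopt;
    }
    return transfer(value.value());
}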

csrc/engine/rank_worker.cpp

Lines changed: 1 addition & 1 deletion

@@ -206,7 +206,7 @@ void RankWorker::thread_loop() {
             local_param_name = pending_param_name_;
             local_param = pending_param_;
         } else if (local_cmd == Command::RUN) {
-            local_args = pending_args_.to_model_input();
+            local_args = pending_args_.to_model_input(rank_info_.device);
         } else if (local_cmd == Command::RESET_CACHE) {
             if (pending_cache_config_ != nullptr) {
                 local_cache_config = pending_cache_config_->unique_copy();

csrc/engine/rank_worker.hpp

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ class RankWorker {
 
     float random_val{0.1};
 
-    infinilm::InfinilmModel::Input to_model_input() const;
+    infinilm::InfinilmModel::Input to_model_input(infinicore::Device device) const;
 };
 
 struct Output {

csrc/models/llama/llama_attention.cpp

Lines changed: 134 additions & 20 deletions

@@ -43,6 +43,7 @@ LlamaAttention::LlamaAttention(const LlamaConfig &config,
     } else {
         throw std::runtime_error("num_attention_heads / tp_size error.");
     }
+    scaling_ = 1.0f / std::sqrt(static_cast<float>(head_dim_));
 
     // Initialize projection layers
     INFINILM_QKV_LINEAR_INIT(qkv_proj, "q_proj", "k_proj", "v_proj", hidden_size_, head_dim_, config.num_attention_heads, config.num_key_value_heads, use_bias_,
@@ -52,17 +53,10 @@
                              dtype, device, tp_rank, tp_size, rank_info.comm);
 }
 
-infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
-                                           const infinicore::Tensor &position_ids,
-                                           std::shared_ptr<cache::Cache> kv_cache,
-                                           std::optional<infinicore::Tensor> cache_lengths,
-                                           std::optional<infinicore::Tensor> input_lengths,
-                                           std::optional<infinicore::Tensor> input_offsets,
-                                           std::optional<infinicore::Tensor> block_tables,
-                                           std::optional<infinicore::Tensor> slot_mapping) const {
-    if (!rotary_emb_) {
-        throw std::runtime_error("LlamaAttention: rotary_emb not configured");
-    }
+infinicore::Tensor LlamaAttention::forward_static_(const infinicore::Tensor &hidden_states,
+                                                   const infinicore::Tensor &position_ids,
+                                                   std::shared_ptr<infinilm::cache::Cache> kv_cache,
+                                                   std::optional<infinicore::Tensor> cache_lengths) const {
     // Input shape: [batch, seq_len, hidden_size]
     auto hidden_states_mutable = hidden_states;
     auto shape = hidden_states->shape();
@@ -73,7 +67,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
     auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
 
     // 2. Reshape for multi-head attention
-
     // Reshape Q, K, V to include batch dimension
     // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
     // The view operation requires the tensor to be contiguous in the required dimensions
@@ -114,13 +107,6 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
         auto [k_total_tmp, v_total_tmp] = static_kv_cache->update(layer_idx_, k_permuted, v_permuted, cache_lengths.value());
         k_total = k_total_tmp;
         v_total = v_total_tmp;
-    } else if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
-        auto [k_total_tmp, v_total_tmp] = paged_kv_cache->update(layer_idx_, k_permuted, v_permuted, slot_mapping.value());
-        k_total = k_total_tmp;
-        v_total = v_total_tmp;
-
-        /// @todo Implement paged attention here.
-        throw std::runtime_error("LlamaAttention: Paged attention not implemented");
     } else {
         throw std::runtime_error("LlamaAttention: Unsupported kvcache type");
     }
@@ -152,8 +138,136 @@ infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_stat
     return output;
 }
 
+infinicore::Tensor LlamaAttention::forward_paged_(const infinicore::Tensor &hidden_states,
+                                                  const infinicore::Tensor &position_ids,
+                                                  std::shared_ptr<infinilm::cache::PagedKVCache> paged_kv_cache,
+                                                  std::optional<infinicore::Tensor> cache_lengths,
+                                                  std::optional<infinicore::Tensor> input_lengths,
+                                                  std::optional<infinicore::Tensor> input_offsets,
+                                                  std::optional<infinicore::Tensor> block_tables,
+                                                  std::optional<infinicore::Tensor> slot_mapping) const {
+    if (!block_tables.has_value() or !input_lengths.has_value() or !slot_mapping.has_value()) {
+        throw std::runtime_error("LlamaAttention::forward_paged: block_tables or input_lengths or slot_mapping is not set");
+    }
+
+    // Input shape: [batch, seq_len, hidden_size]
+    auto hidden_states_mutable = hidden_states;
+    auto shape = hidden_states->shape();
+    size_t batch_size = shape[0];
+    size_t seq_len = shape[1];
+
+    bool is_prefill = (batch_size * seq_len != input_lengths.value()->shape()[0]);
+    assert(batch_size == 1);
+
+    // 1. Project Q, K, V
+    auto [q, k, v] = qkv_proj_->forward_split(hidden_states_mutable);
+
+    // 2. Reshape for multi-head attention
+
+    // Reshape Q, K, V to include batch dimension
+    // Python: query_states = self.q_proj(hidden_states).view(querys_shape)
+    // The view operation requires the tensor to be contiguous in the required dimensions
+    auto q_reshaped = q->view({batch_size, seq_len, num_attention_heads_, head_dim_});
+    auto k_reshaped = k->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
+    auto v_reshaped = v->view({batch_size, seq_len, num_key_value_heads_, head_dim_});
+
+    // 3. Prepare position_ids for RoPE - align with Python pattern
+
+    auto pos_shape = position_ids->shape();
+    infinicore::Tensor pos_ids_for_rope = position_ids;
+    if (pos_shape.size() == 2) {
+        auto pos_narrowed = position_ids->narrow({{0, 0, 1}});
+        pos_ids_for_rope = pos_narrowed->contiguous()->view({pos_shape[1]});
+    } else if (pos_shape.size() == 1) {
+        pos_ids_for_rope = position_ids->contiguous();
+    } else {
+        throw std::runtime_error("Unexpected position_ids shape");
+    }
+
+    // 4. Apply RoPE to Q and K
+    auto q_rope = infinicore::Tensor::empty({batch_size, num_attention_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3});
+    auto k_rope = infinicore::Tensor::empty({batch_size, num_key_value_heads_, seq_len, head_dim_}, q_reshaped->dtype(), q_reshaped->device())->permute({0, 2, 1, 3});
+    rotary_emb_->forward(q_rope, q_reshaped, pos_ids_for_rope); // [bs, seq_len, n_q_head, head_dim]
+    rotary_emb_->forward(k_rope, k_reshaped, pos_ids_for_rope); // [bs, seq_len, n_kv_head, head_dim]
+
+    // 5. Prepare KV caches
+    // Ensure contiguous after permute for F16 compatibility with cache operations
+    auto [k_total, v_total] = paged_kv_cache->update(layer_idx_,
+                                                     k_rope->contiguous(), // without contiguous(), view fails with "Incompatible shape for view operation."
+                                                     v_reshaped,
+                                                     slot_mapping.value());
+
+    // 6. Compute attention
+    infinicore::Tensor attn_output;
+    if (is_prefill) {
+        q_reshaped = q_rope->permute({0, 2, 1, 3});          // [bs, n_q_head, seq_len, head_dim]
+        auto k_permuted = k_rope->permute({0, 2, 1, 3});     // [bs, n_kv_head, seq_len, head_dim]
+        auto v_permuted = v_reshaped->permute({0, 2, 1, 3}); // [bs, n_kv_head, seq_len, head_dim]
+
+        auto total_seq_len = k_permuted->shape()[2];
+        size_t ngroup = num_attention_heads_ / num_key_value_heads_;
+
+        auto Q = q_reshaped->view({batch_size * num_key_value_heads_, ngroup * seq_len, head_dim_});
+        auto K = k_permuted->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
+        auto V = v_permuted->contiguous()->view({batch_size * num_key_value_heads_, total_seq_len, head_dim_});
+
+        auto K_transposed = K->permute({0, 2, 1}); // [bs * n_kv_head, head_dim, total_seq_len]
+
+        auto attn_weight = infinicore::op::matmul(Q, K_transposed, scaling_); // [bs * n_kv_head, ng * seq_len, total_seq_len]
+
+        auto attn_weight_softmax = attn_weight->view({batch_size * num_attention_heads_, seq_len, total_seq_len});
+        infinicore::op::causal_softmax_(attn_weight_softmax, attn_weight_softmax);
+
+        auto out = infinicore::op::matmul(attn_weight, V); // [bs * n_kv_head, ng * seq_len, head_dim]
+
+        attn_output = out->view({batch_size, num_attention_heads_, seq_len, head_dim_})
+                          ->permute({0, 2, 1, 3})
+                          ->contiguous()
+                          ->view({batch_size, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
+
+    } else {
+        q_reshaped = q_rope->contiguous()->view({1 * seq_len, num_attention_heads_, head_dim_}); // q_reshaped must be contiguous here
+        auto out = infinicore::Tensor::empty({1 * seq_len, num_attention_heads_, head_dim_}, q_reshaped->dtype(), q_reshaped->device());
+        infinicore::op::paged_attention_(out,
+                                         q_reshaped,
+                                         k_total,
+                                         v_total,
+                                         block_tables.value(),
+                                         input_lengths.value(),
+                                         std::nullopt,
+                                         scaling_);
+
+        attn_output = out->view({1, seq_len, num_attention_heads_, head_dim_})->view({1, seq_len, num_attention_heads_ * head_dim_}); // [bs, seq_len, n_q_head * head_dim]
+    }
+
+    // 7. Project output
+    return o_proj_->forward(attn_output); // e.g. [1, 13, 3584] => [1, 13, 4096]
+}
+
+infinicore::Tensor LlamaAttention::forward(const infinicore::Tensor &hidden_states,
+                                           const infinicore::Tensor &position_ids,
+                                           std::shared_ptr<cache::Cache> kv_cache,
+                                           std::optional<infinicore::Tensor> cache_lengths,
+                                           std::optional<infinicore::Tensor> input_lengths,
+                                           std::optional<infinicore::Tensor> input_offsets,
+                                           std::optional<infinicore::Tensor> block_tables,
+                                           std::optional<infinicore::Tensor> slot_mapping) const {
+    if (!rotary_emb_) {
+        throw std::runtime_error("LlamaAttention: rotary_emb not configured");
+    }
+
+    infinicore::Tensor output;
+    if (auto paged_kv_cache = std::dynamic_pointer_cast<cache::PagedKVCache>(kv_cache)) {
+        output = forward_paged_(hidden_states, position_ids, paged_kv_cache, cache_lengths, input_lengths, input_offsets, block_tables, slot_mapping);
+    } else {
+        output = forward_static_(hidden_states, position_ids, kv_cache, cache_lengths);
+    }
+    return output;
+}
+
 void LlamaAttention::set_rotary_emb(const std::shared_ptr<infinicore::nn::RoPE> &rotary_emb) {
     rotary_emb_ = rotary_emb;
 }
 
-} // namespace infinilm::models::llama
+} // namespace infinilm::models::llama
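Two details of forward_paged_ are worth spelling out. First, the step is treated as prefill when the flattened token count (batch_size * seq_len) differs from the number of entries in input_lengths (one per sequence): during decode each sequence contributes exactly one token, so the counts match. Second, the prefill branch expresses grouped-query attention purely with views: Q is folded to [bs * n_kv_head, ngroup * seq_len, head_dim] so one batched matmul against each KV head serves all of its ngroup query heads, and the scores are re-viewed as [bs * n_q_head, seq_len, total_seq_len] before causal_softmax_. Below is a scalar reference of that same math; the flat [seq_len, heads, head_dim] layouts, the function name, and the plain-float buffers are illustrative assumptions, not the infinicore kernels.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Scalar grouped-query causal attention over one prefill chunk (total length == seq_len).
// Assumed row-major layouts: q/out: [seq_len, num_q_heads, head_dim],
//                            k/v  : [seq_len, num_kv_heads, head_dim].
// Query head h reads KV head h / (num_q_heads / num_kv_heads).
void gqa_prefill_reference(const std::vector<float> &q, const std::vector<float> &k,
                           const std::vector<float> &v, std::vector<float> &out,
                           std::size_t seq_len, std::size_t num_q_heads,
                           std::size_t num_kv_heads, std::size_t head_dim) {
    const std::size_t group = num_q_heads / num_kv_heads;
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    for (std::size_t i = 0; i < seq_len; ++i) {         // query position
        for (std::size_t h = 0; h < num_q_heads; ++h) { // query head
            const std::size_t kvh = h / group;          // shared KV head (GQA)
            std::vector<float> scores(i + 1);
            float max_score = -1e30f;
            for (std::size_t j = 0; j <= i; ++j) {      // causal mask: keys 0..i
                float dot = 0.0f;
                for (std::size_t d = 0; d < head_dim; ++d) {
                    dot += q[(i * num_q_heads + h) * head_dim + d]
                         * k[(j * num_kv_heads + kvh) * head_dim + d];
                }
                scores[j] = dot * scale;
                max_score = std::max(max_score, scores[j]);
            }
            float denom = 0.0f;
            for (std::size_t j = 0; j <= i; ++j) {      // numerically stable softmax
                scores[j] = std::exp(scores[j] - max_score);
                denom += scores[j];
            }
            for (std::size_t d = 0; d < head_dim; ++d) { // weighted sum of V
                float acc = 0.0f;
                for (std::size_t j = 0; j <= i; ++j) {
                    acc += scores[j] * v[(j * num_kv_heads + kvh) * head_dim + d];
                }
                out[(i * num_q_heads + h) * head_dim + d] = acc / denom;
            }
        }
    }
}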

csrc/models/llama/llama_attention.hpp

Lines changed: 18 additions & 1 deletion

@@ -55,7 +55,7 @@ class LlamaAttention : public infinicore::nn::Module {
                               std::optional<infinicore::Tensor> input_lengths,
                               std::optional<infinicore::Tensor> input_offsets,
                               std::optional<infinicore::Tensor> block_tables,
-                              std::optional<infinicore::Tensor> slot_mappin) const;
+                              std::optional<infinicore::Tensor> slot_mapping) const;
 
     /**
      * @brief Get the layer index
@@ -73,6 +73,21 @@ class LlamaAttention : public infinicore::nn::Module {
     size_t head_dim() const { return head_dim_; }
     size_t hidden_size() const { return hidden_size_; }
 
+private:
+    infinicore::Tensor forward_static_(const infinicore::Tensor &hidden_states,
+                                       const infinicore::Tensor &position_ids,
+                                       std::shared_ptr<infinilm::cache::Cache> kv_cache,
+                                       std::optional<infinicore::Tensor> cache_lengths) const;
+
+    infinicore::Tensor forward_paged_(const infinicore::Tensor &hidden_states,
+                                      const infinicore::Tensor &position_ids,
+                                      std::shared_ptr<infinilm::cache::PagedKVCache> kv_cache,
+                                      std::optional<infinicore::Tensor> cache_lengths,
+                                      std::optional<infinicore::Tensor> input_lengths,
+                                      std::optional<infinicore::Tensor> input_offsets,
+                                      std::optional<infinicore::Tensor> block_tables,
+                                      std::optional<infinicore::Tensor> slot_mapping) const;
+
 protected:
     // Projection layers
     INFINICORE_NN_MODULE(infinilm::layers::QKVParallelLinear, qkv_proj);
@@ -93,6 +108,8 @@ class LlamaAttention : public infinicore::nn::Module {
     bool use_bias_;                  // Bias for Q/K/V projections
     bool use_output_bias_;           // Bias for output projection (o_proj)
     size_t max_position_embeddings_; // For cache initialization (deprecated, kept for compatibility)
+
+    float scaling_;
 };
 
 } // namespace infinilm::models::llama
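The forward_paged_ declaration carries block_tables and input_lengths because, at decode time, a sequence's K/V are no longer contiguous: each sequence owns a list of fixed-size physical blocks, and logical position p lives in block block_tables[seq][p / block_size] at offset p % block_size. Below is a minimal single-head scalar sketch of that lookup in the spirit of vLLM-style paged attention; all names and layouts are assumptions for illustration, not the infinicore::op::paged_attention_ kernel.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Decode-time paged attention, one head, one new query token per sequence.
// Assumed layouts: q/out: [num_seqs, head_dim];
//                  k_cache/v_cache: [num_blocks * block_size, head_dim] (slot-major);
//                  block_tables: [num_seqs, max_blocks_per_seq]; seq_lens: [num_seqs].
void paged_attention_decode_reference(const std::vector<float> &q,
                                      const std::vector<float> &k_cache,
                                      const std::vector<float> &v_cache,
                                      const std::vector<int32_t> &block_tables,
                                      const std::vector<int32_t> &seq_lens,
                                      std::vector<float> &out,
                                      std::size_t num_seqs, std::size_t max_blocks_per_seq,
                                      std::size_t block_size, std::size_t head_dim) {
    const float scale = 1.0f / std::sqrt(static_cast<float>(head_dim));
    for (std::size_t s = 0; s < num_seqs; ++s) {
        const std::size_t len = static_cast<std::size_t>(seq_lens[s]);
        std::vector<float> scores(len);
        float max_score = -1e30f;
        for (std::size_t p = 0; p < len; ++p) {
            // Translate logical position p into a physical cache slot via the block table.
            const std::size_t block = static_cast<std::size_t>(
                block_tables[s * max_blocks_per_seq + p / block_size]);
            const std::size_t slot = block * block_size + p % block_size;
            float dot = 0.0f;
            for (std::size_t d = 0; d < head_dim; ++d) {
                dot += q[s * head_dim + d] * k_cache[slot * head_dim + d];
            }
            scores[p] = dot * scale;
            max_score = std::max(max_score, scores[p]);
        }
        float denom = 0.0f;
        for (std::size_t p = 0; p < len; ++p) { // numerically stable softmax
            scores[p] = std::exp(scores[p] - max_score);
            denom += scores[p];
        }
        for (std::size_t d = 0; d < head_dim; ++d) { // weighted sum of V over the context
            float acc = 0.0f;
            for (std::size_t p = 0; p < len; ++p) {
                const std::size_t block = static_cast<std::size_t>(
                    block_tables[s * max_blocks_per_seq + p / block_size]);
                const std::size_t slot = block * block_size + p % block_size;
                acc += scores[p] * v_cache[slot * head_dim + d];
            }
            out[s * head_dim + d] = acc / denom;
        }
    }
}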

csrc/pybind11/engine/engine.hpp

Lines changed: 2 additions & 0 deletions

@@ -90,6 +90,8 @@ inline void bind_infer_engine(py::module &m) {
                 std::move(input_ids),
                 std::move(position_ids),
                 std::move(cache_lengths),
+                std::move(input_lengths),
+                std::move(input_offsets),
                 std::move(block_tables),
                 std::move(slot_mapping)}};
