@@ -112,27 +112,32 @@ void releaseDeviceResource(DeviceResource &res) {
112112
113113void inferDeviceBatch (const JiugeMeta &meta, DeviceResource &rsrc,
114114 uint32_t idev, uint32_t ndev,
115- const uint32_t *tokens, uint32_t ntok, // 所有req的ntokens之和
115+ const uint32_t *tokens, uint32_t ntok,
116116 const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
117117 struct KVCache **kv_caches,
118118 const float *temperature, const uint32_t *topk, const float *topp,
119119 uint32_t *output) {
120+
121+ // Sparse attention: on the first prefill pass, only the most recent
121+ // (seq_len * ratio) tokens are written into the KV cache.
121+ // NOTE(review): this drops the earliest prompt tokens (including any
121+ // attention-sink tokens) — confirm that is intended.
122+ auto ratio = 0.125 ; // fraction of each prompt's KV entries retained
123+ bool sparseOn = true ; // master switch for sparse-attention pruning
124+
120125 auto nlayer = meta.nlayer ;
121126 auto nkvh = meta.nkvh / ndev;
122127 auto nh = meta.nh / ndev;
123128 auto ngroup = nh / nkvh;
124129 // auto dctx = meta.dctx;
125130 auto dh = meta.dh ;
126- auto d = meta.d ; // hidden size
127- auto dt_logits = meta.dt_logits ; // data type
131+ auto d = meta.d ; // hidden size
132+ auto dt_logits = meta.dt_logits ; // data type
128133 auto di = meta.di / ndev;
129134 auto dvoc = meta.dvoc ;
130135 auto stream = rsrc.stream ;
131136 bool has_qkv_bias = rsrc.b_attn_qkv .size () > 0 ;
132137
133138 // Allocate buffers
134- auto logits_in = Tensor::buffer (dt_logits, {ntok, d}, rsrc.memory_pool ); // hidden_stat
135- auto logits_out = Tensor::buffer (dt_logits, {ntok, d}, rsrc.memory_pool ); // hidden_stat (rms)
139+ auto logits_in = Tensor::buffer (dt_logits, {ntok, d}, rsrc.memory_pool );
140+ auto logits_out = Tensor::buffer (dt_logits, {ntok, d}, rsrc.memory_pool );
136141 auto qkv_buf = Tensor::buffer (dt_logits, {ntok, (nh + nkvh * 2 ) * dh}, rsrc.memory_pool );
137142 auto gate_up_buf = Tensor::buffer (dt_logits, {ntok, 2 * di}, rsrc.memory_pool );
138143 auto o_buf = Tensor::buffer (dt_logits, {ntok, nh * dh}, rsrc.memory_pool );
@@ -220,7 +225,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
220225 size_t max_seq_len = 0 ;
221226 o_buf->dimSplit (1 , {nh, dh});
222227
223- size_t recentWindow = 16 ; // sparse attention
228+
224229 for (uint32_t req = 0 ; req < nreq; req++) {
225230 auto past_len = req_pos[req];
226231 auto seq_len = req_lens[req];
@@ -234,7 +239,9 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
234239 auto full_kv = kv_caches[req]->k [idev][0 ]->slice (0 , 0 , total_len)->permute ({1 , 2 , 0 });
235240 auto cache_kv = kv_caches[req]->k [idev][0 ]->slice (0 , past_len, seq_len);
236241
237- bool prune = (past_len == 0 ) && (seq_len > recentWindow);
242+
243+ uint32_t recentWindow = (uint32_t ) (seq_len * ratio);
244+ bool prune = sparseOn && (past_len == 0 ) && (recentWindow > 0 );
238245
239246 if (prune) {
240247 auto k_compressed = k->slice ({{0 , seq_len - recentWindow, recentWindow}});
@@ -382,19 +389,21 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
382389
383390 size_t token_offset = 0 ;
384391 for (uint32_t req = 0 ; req < nreq; req++) {
385- auto past_len = req_pos[req]; // for kv cache
392+ auto past_len = req_pos[req];
386393 auto seq_len = req_lens[req];
387394 auto o = o_buf->slice ({{0 , token_offset, seq_len}});
388395 auto q = qkv_buf->slice ({{0 , token_offset, seq_len}, {1 , 0 , nh}});
389396 auto k = qkv_buf->slice ({{0 , token_offset, seq_len}, {1 , nh, nkvh}});
390- auto v = qkv_buf->slice ({{0 , token_offset, seq_len}, {1 , nh + nkvh, nkvh}}); // 同一个req的qkv本身也储存在一起,arg2=起始位置,arg3=大小
391- // 不同req的qkv存在一起,用token_offset来维护
392- bool prune = (past_len == 0 ) && (seq_len > recentWindow);
397+ auto v = qkv_buf->slice ({{0 , token_offset, seq_len}, {1 , nh + nkvh, nkvh}});
398+
399+ uint32_t recentWindow = (uint32_t ) (seq_len * ratio);
400+ bool prune = sparseOn && (past_len == 0 ) && (recentWindow > 0 );
401+
393402 // self attention
394403 if (prune) { // first prefill phase
395404 auto k_compressed = k->slice ({{0 , seq_len - recentWindow, recentWindow}});
396405 auto v_compressed = v->slice ({{0 , seq_len - recentWindow, recentWindow}});
397- // 存入kv cache
406+
398407 RUN_INFINI (infiniopRearrange ( // concat
399408 desc_kv_rearranges[req],
400409 kv_caches[req]->k [idev][layer]->data (past_len * nkvh * dh),
@@ -421,7 +430,7 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
421430 RUN_INFINI (infiniopRearrange (
422431 desc_kv_rearranges[req],
423432 kv_caches[req]->k [idev][layer]->data (past_len * nkvh * dh),
424- k->data (), stream)); // 加进kv cache
433+ k->data (), stream));
425434
426435 RUN_INFINI (infiniopRearrange (
427436 desc_kv_rearranges[req],
@@ -480,7 +489,6 @@ void inferDeviceBatch(const JiugeMeta &meta, DeviceResource &rsrc,
480489 desc_ffn_down, workspace, workspace_size,
481490 logits_in->data (), gate_buf->data (),
482491 rsrc.w_ffn_down [layer]->data (), 1.0 , idev == 0 ? 1.0 : 0.0 , stream)); // only rank 0 adds residual
483- // logits_in->data()即是下一层的输入
484492 // All_reduce if distributed
485493 if (rsrc.comm != nullptr ) {
486494 RUN_INFINI (infinicclAllReduce (
@@ -573,13 +581,13 @@ inferBatch(struct JiugeModel *model,
573581 model->req .topk = topk;
574582 model->req .topp = topp;
575583
576- for (size_t idev = 0 ; idev < model->dev_ids .size (); idev++) { // 启动多设备推理
584+ for (size_t idev = 0 ; idev < model->dev_ids .size (); idev++) {
577585 std::unique_lock<std::mutex> lock (model->states [idev].mtx );
578586 model->states [idev].proceed = true ;
579587 lock.unlock ();
580- model->states [idev].cv_start .notify_one (); // 唤醒一个线程去执行推理任务
588+ model->states [idev].cv_start .notify_one ();
581589 }
582- for (size_t i = model->dev_ids .size (); i > 0 ; i--) { // 等待推理完成
590+ for (size_t i = model->dev_ids .size (); i > 0 ; i--) {
583591 auto idev = i - 1 ;
584592 std::unique_lock<std::mutex> lock (model->states [idev].mtx );
585593 model->states [idev].cv_done .wait (lock, [&] { return !(model->states [idev].proceed ); });
0 commit comments