- 
                Notifications
    You must be signed in to change notification settings 
- Fork 13.5k
llama: store mrope data in KV cell #16825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
bf7f924
              90353ea
              c3e1393
              ebac831
              9102a7c
              f706358
              18842b6
              5ec41a1
              bed0f57
              45d60e1
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -900,6 +900,13 @@ void llama_kv_cache::apply_ubatch(const slot_info & sinfo, const llama_ubatch & | |
|  | ||
| cells.pos_set(idx, ubatch.pos[i]); | ||
|  | ||
| if (ubatch.has_mrope()) { | ||
| cells.pos_mrope_set(idx, { | ||
| ubatch.pos[i + ubatch.n_tokens], // y | ||
| ubatch.pos[i + ubatch.n_tokens*2], // x | ||
| }); | ||
| } | ||
|  | ||
| for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { | ||
| cells.seq_add(idx, ubatch.seq_id[i][s]); | ||
| } | ||
|  | @@ -1243,6 +1250,13 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u | |
|  | ||
| const llama_pos p1 = ubatch->pos[i]; | ||
|  | ||
| // for M-RoPE | ||
| llama_kv_pos_mrope p1_mrope; | ||
| if (ubatch->has_mrope()) { | ||
| p1_mrope.y = ubatch->pos[i + ubatch->n_tokens]; | ||
| p1_mrope.x = ubatch->pos[i + ubatch->n_tokens*2]; | ||
| } | ||
|  | ||
|          | ||
| const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii); | ||
|  | ||
| for (uint32_t j = 0; j < n_kv; ++j) { | ||
|  | @@ -1262,6 +1276,14 @@ void llama_kv_cache::set_input_kq_mask(ggml_tensor * dst, const llama_ubatch * u | |
| continue; | ||
| } | ||
|  | ||
| // M-RoPE causal mask | ||
| if (causal_attn && ubatch->has_mrope() && p0 == p1) { | ||
| const auto & p0_mrope = cells.pos_mrope_get(j); | ||
| if (p0_mrope.is_gt(p1_mrope)) { | ||
| continue; | ||
| } | ||
| } | ||
|  | ||
| // apply SWA if any | ||
| if (is_masked_swa(p0, p1)) { | ||
| continue; | ||
|  | ||
| Original file line number | Diff line number | Diff line change | ||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|
|  | @@ -9,6 +9,15 @@ | |||||||||||
| #include <set> | ||||||||||||
| #include <map> | ||||||||||||
|  | ||||||||||||
| struct llama_kv_pos_mrope { | ||||||||||||
| llama_pos y = 0; | ||||||||||||
| llama_pos x = 0; | ||||||||||||
| // return true if this position is greater than the other position | ||||||||||||
| bool is_gt(const llama_kv_pos_mrope & other) const { | ||||||||||||
| return (y > other.y) || (y == other.y && x > other.x); | ||||||||||||
| } | ||||||||||||
| }; | ||||||||||||
|  | ||||||||||||
|          | ||||||||||||
| // meta information about KV cells that can be part of multiple sequences at the same time | ||||||||||||
| // TODO: add unit tests | ||||||||||||
| class llama_kv_cells { | ||||||||||||
|  | @@ -43,6 +52,7 @@ class llama_kv_cells { | |||||||||||
|  | ||||||||||||
| void resize(uint32_t n) { | ||||||||||||
| pos.resize(n); | ||||||||||||
| pos_mrope.resize(n); | ||||||||||||
| shift.resize(n); | ||||||||||||
| seq.resize(n); | ||||||||||||
|  | ||||||||||||
|  | @@ -107,8 +117,9 @@ class llama_kv_cells { | |||||||||||
| for (uint32_t j = 0; j < n; ++j) { | ||||||||||||
| const auto idx = i + j; | ||||||||||||
|  | ||||||||||||
| res.pos[j] = pos[idx]; | ||||||||||||
| res.seq[j] = seq[idx]; | ||||||||||||
| res.pos [j] = pos[idx]; | ||||||||||||
| res.pos_mrope[j] = pos_mrope[idx]; | ||||||||||||
| res.seq [j] = seq[idx]; | ||||||||||||
|  | ||||||||||||
| assert(shift[idx] == 0); | ||||||||||||
| } | ||||||||||||
|  | @@ -125,8 +136,9 @@ class llama_kv_cells { | |||||||||||
| for (uint32_t j = 0; j < idxs.size(); ++j) { | ||||||||||||
| const auto idx = idxs[j]; | ||||||||||||
|  | ||||||||||||
| res.pos[j] = pos[idx]; | ||||||||||||
| res.seq[j] = seq[idx]; | ||||||||||||
| res.pos [j] = pos[idx]; | ||||||||||||
| res.pos_mrope[j] = pos_mrope[idx]; | ||||||||||||
| res.seq [j] = seq[idx]; | ||||||||||||
|  | ||||||||||||
| assert(shift[idx] == 0); | ||||||||||||
| } | ||||||||||||
|  | @@ -340,6 +352,13 @@ class llama_kv_cells { | |||||||||||
| return pos[i]; | ||||||||||||
| } | ||||||||||||
|  | ||||||||||||
| const llama_kv_pos_mrope & pos_mrope_get(uint32_t i) const { | ||||||||||||
| assert(i < pos.size()); | ||||||||||||
| assert(pos[i] != -1); | ||||||||||||
|  | ||||||||||||
| return pos_mrope[i]; | ||||||||||||
| } | ||||||||||||
|  | ||||||||||||
| // note: call only if the cell is not empty | ||||||||||||
| llama_pos get_shift(uint32_t i) const { | ||||||||||||
| assert(i < pos.size()); | ||||||||||||
|  | @@ -368,6 +387,11 @@ class llama_kv_cells { | |||||||||||
| used.insert(i); | ||||||||||||
| } | ||||||||||||
|  | ||||||||||||
| void pos_mrope_set(uint32_t i, llama_kv_pos_mrope p) { | ||||||||||||
| assert(i < pos_mrope.size()); | ||||||||||||
| pos_mrope[i] = p; | ||||||||||||
| } | ||||||||||||
|  | ||||||||||||
| // pos[i] = pos[i] + d | ||||||||||||
| // sets "has_shift" to true | ||||||||||||
| // note: call only if the cell is not empty | ||||||||||||
|  | @@ -424,6 +448,9 @@ class llama_kv_cells { | |||||||||||
|  | ||||||||||||
| std::vector<llama_pos> pos; | ||||||||||||
|  | ||||||||||||
| // stores addition info for M-RoPE positions | ||||||||||||
| std::vector<llama_kv_pos_mrope> pos_mrope; | ||||||||||||
|  | ||||||||||||
|          | ||||||||||||
| // stores addition info for M-RoPE positions | |
| std::vector<llama_kv_pos_mrope> pos_mrope; | |
| // stores extra optional cell info | |
| std::vector<llama_kv_cell_ext> ext; | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can make this multi-dimensional positional information more decoupled from the concept of rope: