@@ -487,9 +487,9 @@ struct llm_graph_context {
487487
488488 ggml_tensor * build_attn_mha (
489489 ggml_cgraph * gf,
490- ggml_tensor * q,
491- ggml_tensor * k,
492- ggml_tensor * v,
490+ ggml_tensor * q, // [n_embd_head_q, n_tokens, n_head_q]
491+ ggml_tensor * k, // [n_embd_head_k, n_tokens, n_head_k]
492+ ggml_tensor * v, // [n_embd_head_v, n_tokens, n_head_v] (v_trans == false)
493493 ggml_tensor * kq_b,
494494 ggml_tensor * kq_mask,
495495 bool v_trans,
@@ -502,9 +502,9 @@ struct llm_graph_context {
502502 ggml_cgraph * gf,
503503 ggml_tensor * wo,
504504 ggml_tensor * wo_b,
505- ggml_tensor * q_cur,
506- ggml_tensor * k_cur,
507- ggml_tensor * v_cur,
505+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
506+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
507+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
508508 ggml_tensor * kq_b,
509509 float kq_scale,
510510 int il) const ;
@@ -516,9 +516,9 @@ struct llm_graph_context {
516516 ggml_cgraph * gf,
517517 ggml_tensor * wo,
518518 ggml_tensor * wo_b,
519- ggml_tensor * q_cur,
520- ggml_tensor * k_cur,
521- ggml_tensor * v_cur,
519+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
520+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
521+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
522522 ggml_tensor * kq_b,
523523 float kq_scale,
524524 int il) const ;
@@ -530,9 +530,9 @@ struct llm_graph_context {
530530 ggml_cgraph * gf,
531531 ggml_tensor * wo,
532532 ggml_tensor * wo_b,
533- ggml_tensor * q_cur,
534- ggml_tensor * k_cur,
535- ggml_tensor * v_cur,
533+ ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens]
534+ ggml_tensor * k_cur, // [n_embd_head_k, n_head_k, n_tokens]
535+ ggml_tensor * v_cur, // [n_embd_head_v, n_head_v, n_tokens]
536536 ggml_tensor * kq_b,
537537 float kq_scale,
538538 int il) const ;
0 commit comments