diff --git a/examples/talk-llama/llama-arch.cpp b/examples/talk-llama/llama-arch.cpp
index e63ab284bc3..062a9977678 100644
--- a/examples/talk-llama/llama-arch.cpp
+++ b/examples/talk-llama/llama-arch.cpp
@@ -34,6 +34,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_PHI3, "phi3" },
     { LLM_ARCH_PHIMOE, "phimoe" },
     { LLM_ARCH_PLAMO, "plamo" },
+    { LLM_ARCH_PLAMO2, "plamo2" },
     { LLM_ARCH_CODESHELL, "codeshell" },
     { LLM_ARCH_ORION, "orion" },
     { LLM_ARCH_INTERNLM2, "internlm2" },
@@ -67,6 +68,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_JAIS, "jais" },
     { LLM_ARCH_NEMOTRON, "nemotron" },
     { LLM_ARCH_EXAONE, "exaone" },
+    { LLM_ARCH_EXAONE4, "exaone4" },
     { LLM_ARCH_RWKV6, "rwkv6" },
     { LLM_ARCH_RWKV6QWEN2, "rwkv6qwen2" },
     { LLM_ARCH_RWKV7, "rwkv7" },
@@ -81,9 +83,11 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_DOTS1, "dots1" },
     { LLM_ARCH_ARCEE, "arcee" },
     { LLM_ARCH_ERNIE4_5, "ernie4_5" },
+    { LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
     { LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
     { LLM_ARCH_SMOLLM3, "smollm3" },
     { LLM_ARCH_LFM2, "lfm2" },
+    { LLM_ARCH_DREAM, "dream" },
     { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
@@ -784,6 +788,36 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_PLAMO2,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
+            { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
+            { LLM_TENSOR_SSM_X, "blk.%d.ssm_x" },
+            { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
+            { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
+            { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
+            { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
+            { LLM_TENSOR_SSM_DT_NORM, "blk.%d.ssm_dt_norm" },
+            { LLM_TENSOR_SSM_B_NORM, "blk.%d.ssm_b_norm" },
+            { LLM_TENSOR_SSM_C_NORM, "blk.%d.ssm_c_norm" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        },
+    },
     {
         LLM_ARCH_CODESHELL,
         {
@@ -1477,6 +1511,26 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_EXAONE4,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+            { LLM_TENSOR_OUTPUT, "output" },
+            { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+            { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
+            { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+            { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
+        }
+    },
     {
         LLM_ARCH_RWKV6,
         {
@@ -1793,6 +1847,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_ERNIE4_5_MOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
"output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" }, + { LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" }, + { LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + }, + }, { LLM_ARCH_HUNYUAN_MOE, { @@ -1854,6 +1933,23 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, } }, + { + LLM_ARCH_DREAM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_UNKNOWN, { @@ -2094,6 +2190,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { switch (arch) { case LLM_ARCH_JAMBA: case LLM_ARCH_FALCON_H1: + case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: return true; @@ -2101,3 +2198,12 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { return false; } } + +bool llm_arch_is_diffusion(const llm_arch & arch) { + switch (arch) { + case LLM_ARCH_DREAM: + return true; + default: + return false; + } +} diff --git a/examples/talk-llama/llama-arch.h b/examples/talk-llama/llama-arch.h index 1f973259524..d09b7d7810b 100644 --- a/examples/talk-llama/llama-arch.h +++ b/examples/talk-llama/llama-arch.h @@ -38,6 +38,7 @@ enum llm_arch { LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, LLM_ARCH_PLAMO, + LLM_ARCH_PLAMO2, LLM_ARCH_CODESHELL, LLM_ARCH_ORION, LLM_ARCH_INTERNLM2, @@ -71,6 +72,7 @@ enum llm_arch { LLM_ARCH_JAIS, LLM_ARCH_NEMOTRON, LLM_ARCH_EXAONE, + LLM_ARCH_EXAONE4, LLM_ARCH_RWKV6, LLM_ARCH_RWKV6QWEN2, LLM_ARCH_RWKV7, @@ -85,9 +87,11 @@ enum llm_arch { LLM_ARCH_DOTS1, LLM_ARCH_ARCEE, LLM_ARCH_ERNIE4_5, + LLM_ARCH_ERNIE4_5_MOE, LLM_ARCH_HUNYUAN_MOE, LLM_ARCH_SMOLLM3, LLM_ARCH_LFM2, + LLM_ARCH_DREAM, LLM_ARCH_UNKNOWN, }; @@ -478,3 +482,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor); bool llm_arch_is_recurrent(const llm_arch & arch); bool llm_arch_is_hybrid (const llm_arch & arch); +bool llm_arch_is_diffusion(const llm_arch & arch); diff --git a/examples/talk-llama/llama-batch.cpp b/examples/talk-llama/llama-batch.cpp index 3bc8554e51c..a546063c0a7 100644 --- a/examples/talk-llama/llama-batch.cpp +++ b/examples/talk-llama/llama-batch.cpp @@ -27,6 +27,7 @@ bool llama_batch_allocr::init( const llama_vocab & vocab, const llama_memory_i * memory, uint32_t n_embd, + uint32_t n_seq_max, bool output_all) { clear(); @@ -40,6 +41,11 @@ bool llama_batch_allocr::init( // validate input batch // + if (n_seq_max > LLAMA_MAX_SEQ) { + 
LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ); + return false; + } + if (batch.token) { for (int32_t i = 0; i < batch.n_tokens; ++i) { if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) { @@ -52,8 +58,8 @@ bool llama_batch_allocr::init( if (batch.seq_id) { for (int32_t i = 0; i < batch.n_tokens; ++i) { for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) { - if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) { - LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ); + if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) { + LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max); return false; } } @@ -86,7 +92,7 @@ bool llama_batch_allocr::init( // initialize the starting position for each sequence based on the positions in the memory llama_pos p0[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (!memory) { // if no memory -> start from 0 p0[s] = 0; @@ -143,13 +149,16 @@ bool llama_batch_allocr::init( // compute stats // - this->n_embd = n_embd; + this->n_embd = n_embd; + this->n_seq_max = n_seq_max; // count the outputs in this batch for (int32_t i = 0; i < batch.n_tokens; ++i) { n_outputs += batch.logits[i] != 0; } + has_cpl = false; + // determine coupled sequences // these are pairs of sequences that have at least one token in the input batch that is assigned to both of them for (int32_t i = 0; i < batch.n_tokens; ++i) { @@ -189,7 +198,7 @@ bool llama_batch_allocr::init( seq_set_map[cur].push_back(i); } - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_set_unq.test(s)) { seq_idx[s] = seq_id_unq.size(); seq_id_unq.push_back(s); @@ -201,7 +210,7 @@ bool llama_batch_allocr::init( LLAMA_LOG_DEBUG("%s: input batch info:\n", __func__); llama_ubatch ubatch { - /*.equal_seqs =*/ false, + /*.b_equal_seqs =*/ false, /*.n_tokens =*/ (uint32_t) batch.n_tokens, /*.n_seq_tokens =*/ (uint32_t) 1, /*.n_seqs =*/ (uint32_t) batch.n_tokens, @@ -214,6 +223,7 @@ bool llama_batch_allocr::init( /*.seq_id_unq =*/ this->seq_id_unq.data(), /*.seq_idx =*/ this->seq_idx.data(), /*.output =*/ batch.logits, + /*.data =*/ {}, }; ubatch_print(ubatch, debug); @@ -241,7 +251,7 @@ bool llama_batch_allocr::init( // consistency checks // - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_pos[s].empty()) { continue; } @@ -284,8 +294,8 @@ bool llama_batch_allocr::init( } if (memory) { - for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) { - for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) { + for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) { + for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) { if (seq_cpl[s0][s1]) { if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) || memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) { @@ -316,12 +326,12 @@ bool llama_batch_allocr::init( // { seq_set_t cur_seq_set[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { cur_seq_set[s].set(); } llama_pos cur_seq_pos[LLAMA_MAX_SEQ]; - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { cur_seq_pos[s] = -1; } @@ -357,39 +367,38 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t clear(); split_reset(); - 
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
-    auto & ubatch = ubatches.back();
-
-    ubatch.token .resize(n_tokens);
-    ubatch.embd .clear();
-    ubatch.pos .resize(n_tokens);
-    ubatch.n_seq_id .resize(n_tokens);
-    ubatch.seq_id .resize(n_tokens);
-    ubatch.seq_id_unq.resize(0);
-    ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1);
-    ubatch.output .resize(n_tokens);
+    udata->token .resize(n_tokens);
+    udata->embd .clear();
+    udata->pos .resize(n_tokens);
+    udata->n_seq_id .resize(n_tokens);
+    udata->seq_id .resize(n_tokens);
+    udata->seq_id_unq.resize(0);
+    udata->seq_idx .resize(LLAMA_MAX_SEQ, -1);
+    udata->output .resize(n_tokens);
 
     for (uint32_t s = 0; s < n_seqs; ++s) {
-        ubatch.seq_idx[s] = s;
-        ubatch.seq_id_unq.push_back(s);
+        udata->seq_idx[s] = s;
+        udata->seq_id_unq.push_back(s);
     }
 
     llama_ubatch res {
-        /*.equal_seqs =*/ true,
+        /*.b_equal_seqs =*/ true,
         /*.n_tokens =*/ n_tokens,
         /*.n_seq_tokens =*/ n_seq_tokens,
         /*.n_seqs =*/ n_seqs,
         /*.n_seqs_unq =*/ n_seqs,
-        /*.token =*/ ubatch.token.data(),
+        /*.token =*/ udata->token.data(),
         /*.embd =*/ nullptr,
-        /*.pos =*/ ubatch.pos.data(),
-        /*.n_seq_id =*/ ubatch.n_seq_id.data(),
-        /*.seq_id =*/ ubatch.seq_id.data(),
-        /*.seq_id_unq =*/ ubatch.seq_id_unq.data(),
-        /*.seq_idx =*/ ubatch.seq_idx.data(),
-        /*.output =*/ ubatch.output.data(),
+        /*.pos =*/ udata->pos.data(),
+        /*.n_seq_id =*/ udata->n_seq_id.data(),
+        /*.seq_id =*/ udata->seq_id.data(),
+        /*.seq_id_unq =*/ udata->seq_id_unq.data(),
+        /*.seq_idx =*/ udata->seq_idx.data(),
+        /*.output =*/ udata->output.data(),
+        /*.data =*/ std::move(udata),
     };
 
     return res;
@@ -430,8 +439,6 @@ void llama_batch_allocr::split_reset() {
     used.clear();
     used.resize(get_n_tokens(), false);
-
-    ubatches.clear();
 }
@@ -646,78 +653,77 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
     assert(n_tokens%n_seqs == 0);
 
-    ubatches.emplace_back();
-
-    auto & ubatch = ubatches.back();
+    auto udata = std::make_shared<llama_ubatch::data_t>();
 
     const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
 
     const int64_t n_embd_all = batch.embd ?
(int64_t) n_tokens*n_embd : 0; const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur; - ubatch.token .resize(n_tokens); - ubatch.embd .resize(n_embd_all); - ubatch.pos .resize(n_pos_all); - ubatch.n_seq_id .resize(n_tokens); - ubatch.seq_id .resize(n_tokens); - ubatch.seq_id_unq.resize(0); - ubatch.seq_idx .resize(LLAMA_MAX_SEQ, -1); - ubatch.output .resize(n_tokens); + udata->token .resize(n_tokens); + udata->embd .resize(n_embd_all); + udata->pos .resize(n_pos_all); + udata->n_seq_id .resize(n_tokens); + udata->seq_id .resize(n_tokens); + udata->seq_id_unq.resize(0); + udata->seq_idx .resize(LLAMA_MAX_SEQ, -1); + udata->output .resize(n_tokens); seq_set_t seq_set_unq; for (size_t i = 0; i < idxs.size(); ++i) { if (batch.token) { - ubatch.token[i] = batch.token[idxs[i]]; + udata->token[i] = batch.token[idxs[i]]; } if (batch.embd) { - memcpy(ubatch.embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float)); + memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float)); } for (int j = 0; j < n_pos_cur; ++j) { - ubatch.pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]]; + udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]]; } - ubatch.n_seq_id[i] = batch.n_seq_id[idxs[i]]; - ubatch.seq_id[i] = batch.seq_id[idxs[i]]; - ubatch.output[i] = batch.logits[idxs[i]]; + udata->n_seq_id[i] = batch.n_seq_id[idxs[i]]; + udata->seq_id[i] = batch.seq_id[idxs[i]]; + udata->output[i] = batch.logits[idxs[i]]; - for (int s = 0; s < ubatch.n_seq_id[i]; ++s) { - seq_set_unq.set(ubatch.seq_id[i][s]); + for (int s = 0; s < udata->n_seq_id[i]; ++s) { + seq_set_unq.set(udata->seq_id[i][s]); } - if (ubatch.output[i]) { + if (udata->output[i]) { out_ids.push_back(idxs[i]); } } - for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < n_seq_max; ++s) { if (seq_set_unq.test(s)) { - ubatch.seq_idx[s] = ubatch.seq_id_unq.size(); - ubatch.seq_id_unq.push_back(s); + udata->seq_idx[s] = udata->seq_id_unq.size(); + udata->seq_id_unq.push_back(s); } } llama_ubatch res { - /*.equal_seqs =*/ equal_seqs, + /*.b_equal_seqs =*/ equal_seqs, /*.n_tokens =*/ n_tokens, /*.n_seq_tokens =*/ n_tokens/n_seqs, /*.n_seqs =*/ n_seqs, - /*.n_seqs_unq =*/ (uint32_t) ubatch.seq_id_unq.size(), - - /*.token =*/ batch.token ? ubatch.token.data() : nullptr, - /*.embd =*/ batch.embd ? ubatch.embd.data() : nullptr, - /*.pos =*/ ubatch.pos.data(), - /*.n_seq_id =*/ ubatch.n_seq_id.data(), - /*.seq_id =*/ ubatch.seq_id.data(), - /*.seq_id_unq =*/ ubatch.seq_id_unq.data(), - /*.seq_idx =*/ ubatch.seq_idx.data(), - /*.output =*/ ubatch.output.data(), + /*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(), + + /*.token =*/ batch.token ? udata->token.data() : nullptr, + /*.embd =*/ batch.embd ? 
udata->embd.data() : nullptr, + /*.pos =*/ udata->pos.data(), + /*.n_seq_id =*/ udata->n_seq_id.data(), + /*.seq_id =*/ udata->seq_id.data(), + /*.seq_id_unq =*/ udata->seq_id_unq.data(), + /*.seq_idx =*/ udata->seq_idx.data(), + /*.output =*/ udata->output.data(), + /*.data =*/ std::move(udata), }; if (debug > 0) { - LLAMA_LOG_DEBUG("%s: added ubatch %d to split:\n", __func__, (int) ubatches.size() - 1); + LLAMA_LOG_DEBUG("%s: added ubatch to split:\n", __func__); ubatch_print(res, debug); } @@ -727,7 +733,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector & idxs, u void llama_batch_allocr::ubatch_print(const llama_ubatch & ubatch, int debug) { if (debug > 0) { - LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs); + LLAMA_LOG_DEBUG("%s: equal_seqs = %d\n", __func__, ubatch.equal_seqs()); LLAMA_LOG_DEBUG("%s: n_tokens = %d\n", __func__, ubatch.n_tokens); LLAMA_LOG_DEBUG("%s: n_seq_tokens = %d\n", __func__, ubatch.n_seq_tokens); LLAMA_LOG_DEBUG("%s: n_seqs = %d\n", __func__, ubatch.n_seqs); diff --git a/examples/talk-llama/llama-batch.h b/examples/talk-llama/llama-batch.h index 3420803ff94..d563adc66aa 100644 --- a/examples/talk-llama/llama-batch.h +++ b/examples/talk-llama/llama-batch.h @@ -8,12 +8,17 @@ #include #include #include +#include #include // keep this struct lightweight -// it points to data in `llama_batch_allocr` struct llama_ubatch { - bool equal_seqs; + bool equal_seqs() const { + return b_equal_seqs != 0; + } + + uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment + // otherwise address sanitizer complains // TODO: whole_seqs for embeddings? uint32_t n_tokens; // total tokens (n_seq_tokens * n_seqs) @@ -34,6 +39,20 @@ struct llama_ubatch { llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id int32_t * seq_idx; // [LLAMA_MAX_SEQ] | - | seq_idx int8_t * output; // [n_tokens] | i | - + + struct data_t { + std::vector token; + std::vector embd; + std::vector pos; + std::vector n_seq_id; + std::vector seq_id; + std::vector seq_id_unq; + std::vector seq_idx; + std::vector output; + }; + + // the llama_ubatch pointers above point to this data if set. 
otherwise - points to non-owning data + std::shared_ptr data; }; // a helper for sanitizing, fulfilling and splitting a batch @@ -48,6 +67,7 @@ class llama_batch_allocr { const llama_vocab & vocab, const llama_memory_i * memory, uint32_t n_embd, + uint32_t n_seq_max, bool output_all); const llama_batch & get_batch() const; @@ -100,6 +120,7 @@ class llama_batch_allocr { const uint32_t n_pos_per_embd; uint32_t n_embd; + uint32_t n_seq_max; uint32_t n_outputs; std::array seq_id_0 = { 0 }; // default sequence id @@ -115,7 +136,7 @@ class llama_batch_allocr { using seq_cpl_t = std::vector; // helper flag to quickly determine if there are any coupled sequences in the batch - bool has_cpl; + bool has_cpl = false; std::vector seq_pos; // seq_pos[s]: the set of positions in sequence s std::vector seq_cpl; // seq_cpl[s0][s1]: if sequence s0 is coupled to sequence s1 @@ -135,20 +156,5 @@ class llama_batch_allocr { // used[i] indicates if token i has already been used in a previous ubatch std::vector used; - // llama_ubatch points to this data: - struct ubatch { - std::vector token; - std::vector embd; - std::vector pos; - std::vector n_seq_id; - std::vector seq_id; - std::vector seq_id_unq; - std::vector seq_idx; - std::vector output; - }; - - // current splitting state: - std::vector ubatches; - int debug; }; diff --git a/examples/talk-llama/llama-chat.cpp b/examples/talk-llama/llama-chat.cpp index cbc19d3c40c..d34bb26878c 100644 --- a/examples/talk-llama/llama-chat.cpp +++ b/examples/talk-llama/llama-chat.cpp @@ -56,6 +56,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "glmedge", LLM_CHAT_TEMPLATE_GLMEDGE }, { "minicpm", LLM_CHAT_TEMPLATE_MINICPM }, { "exaone3", LLM_CHAT_TEMPLATE_EXAONE_3 }, + { "exaone4", LLM_CHAT_TEMPLATE_EXAONE_4 }, { "rwkv-world", LLM_CHAT_TEMPLATE_RWKV_WORLD }, { "granite", LLM_CHAT_TEMPLATE_GRANITE }, { "gigachat", LLM_CHAT_TEMPLATE_GIGACHAT }, @@ -65,6 +66,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "llama4", LLM_CHAT_TEMPLATE_LLAMA4 }, { "smolvlm", LLM_CHAT_TEMPLATE_SMOLVLM }, { "hunyuan-moe", LLM_CHAT_TEMPLATE_HUNYUAN_MOE }, + { "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -167,10 +169,13 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { } else if (tmpl_contains(LU8("<|Assistant|>")) && tmpl_contains(LU8("<|User|>")) && tmpl_contains(LU8("<|end▁of▁sentence|>"))) { return LLM_CHAT_TEMPLATE_DEEPSEEK_3; } else if (tmpl_contains("[|system|]") && tmpl_contains("[|assistant|]") && tmpl_contains("[|endofturn|]")) { + if (tmpl_contains("[|tool|]")) { + return LLM_CHAT_TEMPLATE_EXAONE_4; + } // ref: https://huggingface.co/LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct/discussions/8#66bae61b1893d14ee8ed85bb // EXAONE-3.0-7.8B-Instruct return LLM_CHAT_TEMPLATE_EXAONE_3; - } else if (tmpl_contains("rwkv-world")) { + } else if (tmpl_contains("rwkv-world") || tmpl_contains("{{- 'User: ' + message['content']|trim + '\\n\\n' -}}")) { return LLM_CHAT_TEMPLATE_RWKV_WORLD; } else if (tmpl_contains("<|start_of_role|>")) { return LLM_CHAT_TEMPLATE_GRANITE; @@ -188,6 +193,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_DOTS1; } else if (tmpl_contains("<|startoftext|>") && tmpl_contains("<|extra_4|>")) { return LLM_CHAT_TEMPLATE_HUNYUAN_MOE; + } else if (tmpl_contains("<|im_assistant|>assistant<|im_middle|>")) { + return LLM_CHAT_TEMPLATE_KIMI_K2; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -529,6 +536,22 @@ int32_t llm_chat_apply_template( if 
(add_ass) { ss << "[|assistant|]"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_EXAONE_4) { + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "[|system|]" << trim(message->content) << "[|endofturn|]\n"; + } else if (role == "user") { + ss << "[|user|]" << trim(message->content) << "\n"; + } else if (role == "assistant") { + ss << "[|assistant|]" << trim(message->content) << "[|endofturn|]\n"; + } else if (role == "tool") { + ss << "[|tool|]" << trim(message->content) << "[|endofturn|]\n"; + } + } + if (add_ass) { + ss << "[|assistant|]"; + } } else if (tmpl == LLM_CHAT_TEMPLATE_RWKV_WORLD) { // this template requires the model to have "\n\n" as EOT token for (size_t i = 0; i < chat.size(); i++) { @@ -680,6 +703,25 @@ int32_t llm_chat_apply_template( ss << "<|startoftext|>" << message->content << "<|extra_0|>"; } } + } else if (tmpl == LLM_CHAT_TEMPLATE_KIMI_K2) { + // moonshotai/Kimi-K2-Instruct + for (auto message : chat) { + std::string role(message->role); + if (role == "system") { + ss << "<|im_system|>system<|im_middle|>"; + } else if (role == "user") { + ss << "<|im_user|>user<|im_middle|>"; + } else if (role == "assistant") { + ss << "<|im_assistant|>assistant<|im_middle|>"; + } else if (role == "tool") { + ss << "<|im_system|>tool<|im_middle|>"; + } + + ss << message->content << "<|im_end|>"; + } + if (add_ass) { + ss << "<|im_assistant|>assistant<|im_middle|>"; + } } else { // template not supported return -1; diff --git a/examples/talk-llama/llama-chat.h b/examples/talk-llama/llama-chat.h index b621fda2816..6968a19fbe1 100644 --- a/examples/talk-llama/llama-chat.h +++ b/examples/talk-llama/llama-chat.h @@ -35,6 +35,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_GLMEDGE, LLM_CHAT_TEMPLATE_MINICPM, LLM_CHAT_TEMPLATE_EXAONE_3, + LLM_CHAT_TEMPLATE_EXAONE_4, LLM_CHAT_TEMPLATE_RWKV_WORLD, LLM_CHAT_TEMPLATE_GRANITE, LLM_CHAT_TEMPLATE_GIGACHAT, @@ -45,6 +46,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SMOLVLM, LLM_CHAT_TEMPLATE_DOTS1, LLM_CHAT_TEMPLATE_HUNYUAN_MOE, + LLM_CHAT_TEMPLATE_KIMI_K2, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/examples/talk-llama/llama-context.cpp b/examples/talk-llama/llama-context.cpp index 06e93b19cbf..9e77fe6d869 100644 --- a/examples/talk-llama/llama-context.cpp +++ b/examples/talk-llama/llama-context.cpp @@ -98,10 +98,20 @@ llama_context::llama_context( LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD); cparams.n_batch = GGML_KQ_MASK_PAD; } - cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); cparams.op_offload = params.op_offload; + cparams.kv_unified = params.kv_unified; + + { + const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); + supports_set_rows = LLAMA_SET_ROWS ? (atoi(LLAMA_SET_ROWS) != 0) : false; + + if (!supports_set_rows && !cparams.kv_unified) { + LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__); + cparams.kv_unified = true; + } + } const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; @@ -112,6 +122,7 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch); LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn); LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn); + LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? 
"true" : "false"); LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base); LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale); @@ -227,8 +238,8 @@ llama_context::llama_context( LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes); - // buffer used to store the computation graph and the tensor meta data - buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); + gf_res_prev.reset(new llm_graph_result(max_nodes)); + gf_res_reserve.reset(new llm_graph_result(max_nodes)); // TODO: move these checks to ggml_backend_sched // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary @@ -267,7 +278,7 @@ llama_context::llama_context( // reserve worst-case graph if (!hparams.vocab_only && memory) { - const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs); @@ -287,7 +298,7 @@ llama_context::llama_context( cross.v_embd.clear(); - // reserve pp graph first so that buffers are only allocated once + // reserve pp (prompt processing) graph first so that buffers are only allocated once { auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { @@ -298,9 +309,9 @@ llama_context::llama_context( n_nodes_pp = ggml_graph_n_nodes(gf); } - // reserve with tg graph to get the number of splits and nodes + // reserve with tg (token generation) graph to get the number of splits and nodes { - auto * gf = graph_reserve(1, 1, 1, mctx.get()); + auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute tg buffers"); } @@ -311,6 +322,10 @@ llama_context::llama_context( // reserve again with pp graph to avoid ggml-alloc reallocations during inference { + // TODO: not sure if the following graph would be worster case for multi-stream KV caches: + // + // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get()); + // auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); if (!gf) { throw std::runtime_error("failed to allocate compute pp buffers"); @@ -388,10 +403,6 @@ ggml_backend_sched_t llama_context::get_sched() const { return sched.get(); } -ggml_context * llama_context::get_ctx_compute() const { - return ctx_compute.get(); -} - uint32_t llama_context::n_ctx() const { return cparams.n_ctx; } @@ -463,6 +474,11 @@ bool llama_context::kv_self_update(bool optimize) { } } + // reset the previous graph result to make sure that it won't be reused + // TODO: change the mctx->apply() to return information if a graph reserve is needed + // reset the graph result only if the memory module did reset the scheduler + gf_res_prev->reset(); + if (!mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory update\n", __func__); } @@ -475,7 +491,7 @@ bool llama_context::kv_self_update(bool optimize) { throw std::runtime_error("failed to initialize memory context"); } - const uint32_t n_seqs = cparams.n_seq_max; + const uint32_t n_seqs = cparams.kv_unified ? 
1 : cparams.n_seq_max; const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); @@ -492,12 +508,16 @@ enum llama_pooling_type llama_context::pooling_type() const { } float * llama_context::get_logits() { + output_reorder(); + return logits; } float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; + output_reorder(); + try { if (logits == nullptr) { throw std::runtime_error("no logits"); @@ -534,12 +554,16 @@ float * llama_context::get_logits_ith(int32_t i) { } float * llama_context::get_embeddings() { + output_reorder(); + return embd; } float * llama_context::get_embeddings_ith(int32_t i) { int64_t j = -1; + output_reorder(); + try { if (embd == nullptr) { throw std::runtime_error("no embeddings"); @@ -678,38 +702,59 @@ bool llama_context::apply_adapter_cvec( return cvec.apply(model, data, len, n_embd, il_start, il_end); } -llm_graph_result_ptr llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { +llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { if (mctx && !mctx->apply()) { LLAMA_LOG_ERROR("%s: failed to apply memory context\n", __func__); ret = GGML_STATUS_FAILED; return nullptr; } - auto * gf = graph_init(); - if (!gf) { - LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); - ret = GGML_STATUS_FAILED; - return nullptr; - } + auto * res = gf_res_prev.get(); + auto * gf = res->get_gf(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, gtype, mctx); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build graph\n", __func__); - ret = GGML_STATUS_FAILED; - return nullptr; - } + // the new graph parameters + // in order to correctly reuse a graph, it's full topology has to be uniquely determined by these parameters + const auto gparams = graph_params(res, ubatch, mctx, gtype); - // LLAMA_LOG_INFO("graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf->n_nodes, gf->n_leafs); + if (res->can_reuse(gparams)) { + //LLAMA_LOG_DEBUG("%s: reusing previous graph\n", __func__); - if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { - LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); - ret = GGML_STATUS_ALLOC_FAILED; - return nullptr; + n_reused++; + } else { + res->reset(); + + ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); + + //const auto t_start_us = ggml_time_us(); + + gf = model.build_graph(gparams); + + //LLAMA_LOG_INFO("graph build time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + + if (!gf) { + LLAMA_LOG_ERROR("%s: failed to initialize graph\n", __func__); + ret = GGML_STATUS_FAILED; + return nullptr; + } + + if (!ggml_backend_sched_alloc_graph(sched.get(), gf)) { + LLAMA_LOG_ERROR("%s: failed to allocate graph\n", __func__); + ret = GGML_STATUS_ALLOC_FAILED; + return nullptr; + } } - res->set_inputs(&ubatch); + // set the input data for the input tensors + { + //const auto t_start_us = ggml_time_us(); + + res->set_inputs(&ubatch); + + //LLAMA_LOG_INFO("graph set inputs time: %.3f ms\n", (ggml_time_us() - t_start_us)/1000.0); + } - const auto status = graph_compute(gf, ubatch.n_tokens > 1); + const auto status = graph_compute(res->get_gf(), ubatch.n_tokens > 1); if (status != GGML_STATUS_SUCCESS) { LLAMA_LOG_ERROR("%s: failed to compute graph, compute status: 
%d\n", __func__, status); ret = status; @@ -731,16 +776,19 @@ int llama_context::encode(const llama_batch & batch_inp) { const auto & hparams = model.hparams; - const int64_t n_embd = hparams.n_embd; + const int64_t n_embd = hparams.n_embd; + const int32_t n_vocab = model.vocab.n_tokens(); // note: during encode, we always pass the full sequence starting from pos = 0 - if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) { + if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } const uint32_t n_tokens = balloc->get_n_tokens(); + // [TAG_NO_CACHE_PAD] + // TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true const llama_ubatch ubatch = balloc->split_simple(n_tokens); // micro-batching is not possible for non-causal encoding, so we process the batch in a single shot @@ -767,9 +815,6 @@ int llama_context::encode(const llama_batch & batch_inp) { n_outputs = n_tokens; - ggml_backend_sched_reset(sched.get()); - ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - const auto causal_attn_org = cparams.causal_attn; // always use non-causal attention for encoder graphs @@ -778,7 +823,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cparams.causal_attn = false; ggml_status status; - const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_ENCODER, nullptr, status); cparams.causal_attn = causal_attn_org; @@ -791,10 +836,20 @@ int llama_context::encode(const llama_batch & batch_inp) { } } + auto * t_logits = res->get_logits(); auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); + // extract logits + if (logits && t_logits) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); + } + // extract embeddings - if (t_embd) { + if (embd && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -844,9 +899,11 @@ int llama_context::encode(const llama_batch & batch_inp) { } } - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to - // overlap with device computation. - ggml_backend_sched_reset(sched.get()); + if (!supports_set_rows) { + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. + ggml_backend_sched_reset(sched.get()); + } // TODO: hacky solution if (model.arch == LLM_ARCH_T5 && t_embd) { @@ -899,7 +956,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // when computing embeddings, all tokens are output const bool output_all = cparams.embeddings; - if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) { + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? 
LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } @@ -927,6 +984,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // TODO: this clear of the buffer can easily be forgotten - need something better embd_seq.clear(); + output_swaps.clear(); bool did_optimize = false; @@ -1005,11 +1063,8 @@ int llama_context::decode(const llama_batch & batch_inp) { n_outputs = n_outputs_new; } - ggml_backend_sched_reset(sched.get()); - ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); - ggml_status status; - const auto res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); + const auto * res = process_ubatch(ubatch, LLM_GRAPH_TYPE_DECODER, mctx.get(), status); if (!res) { // the last ubatch failed or was aborted -> remove all positions of that ubatch from the KV cache @@ -1149,9 +1204,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // make the outputs have the same order they had in the user-provided batch // note: this is mostly relevant for recurrent models atm if (!sorted_output) { - const uint32_t n_vocab = model.vocab.n_tokens(); - const uint64_t n_embd = model.hparams.n_embd; - GGML_ASSERT((size_t) n_outputs == out_ids.size()); // TODO: is there something more efficient which also minimizes swaps? @@ -1167,16 +1219,9 @@ int llama_context::decode(const llama_batch & batch_inp) { continue; } std::swap(out_ids[i], out_ids[j_min]); - if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { - std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); - } - } - if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { - std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); - } - } + + // remember the swaps and apply them lazily upon logits/embeddings access + output_swaps.push_back({ i, j_min }); } std::fill(output_ids.begin(), output_ids.end(), -1); @@ -1190,9 +1235,11 @@ int llama_context::decode(const llama_batch & batch_inp) { // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); - // Reset state for the next token before backend sync, to allow the CPU activities in the reset to - // overlap with device computation. - ggml_backend_sched_reset(sched.get()); + if (!supports_set_rows) { + // Reset state for the next token before backend sync, to allow the CPU activities in the reset to + // overlap with device computation. 
+ ggml_backend_sched_reset(sched.get()); + } return 0; } @@ -1271,24 +1318,40 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { return n_outputs_max; } +void llama_context::output_reorder() { + const uint32_t n_vocab = model.vocab.n_tokens(); + const uint64_t n_embd = model.hparams.n_embd; + + for (uint32_t s = 0; s < output_swaps.size(); ++s) { + const uint32_t i0 = output_swaps[s].i0; + const uint32_t i1 = output_swaps[s].i1; + + if (logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); + } + } + + if (embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); + } + } + } + + output_swaps.clear(); +} + // // graph // -int32_t llama_context::graph_max_nodes() const { - return std::max(65536, 5*model.n_tensors()); +uint32_t llama_context::graph_max_nodes() const { + return std::max(1024u, 8u*model.n_tensors()); } -ggml_cgraph * llama_context::graph_init() { - ggml_init_params params = { - /*.mem_size =*/ buf_compute_meta.size(), - /*.mem_buffer =*/ buf_compute_meta.data(), - /*.no_alloc =*/ true, - }; - - ctx_compute.reset(ggml_init(params)); - - return ggml_new_graph_custom(ctx_compute.get(), graph_max_nodes(), false); +llm_graph_result * llama_context::get_gf_res_reserve() const { + return static_cast(gf_res_reserve.get()); } ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx) { @@ -1301,6 +1364,11 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); } + ggml_backend_sched_reset(sched.get()); + + // when the scheduler is reset, we cannnot reuse the old graph, so we reset the previous graph result to prevent that + gf_res_prev->reset(); + // store the n_outputs as it is, and restore it afterwards // TODO: not sure if needed, might simplify in the future by removing this const auto save_n_outputs = this->n_outputs; @@ -1310,17 +1378,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx); + auto * res = gf_res_reserve.get(); - this->n_outputs = save_n_outputs; + const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build worst-case graph\n", __func__); - return nullptr; - } + res->reset(); - ggml_backend_sched_reset(sched.get()); + auto * gf = model.build_graph(gparams); + + this->n_outputs = save_n_outputs; // initialize scheduler with the specified graph if (!ggml_backend_sched_reserve(sched.get(), gf)) { @@ -1331,28 +1397,27 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u return gf; } -llm_graph_result_ptr llama_context::graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype, - const llama_memory_context_i * mctx) { - return model.build_graph( - { - /*.ctx =*/ ctx, - /*.arch =*/ model.arch, - /*.hparams =*/ model.hparams, - /*.cparams =*/ cparams, - /*.ubatch =*/ ubatch, - /*.sched =*/ sched.get(), - /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ &cvec, - /*.loras =*/ 
&loras, - /*.mctx =*/ mctx, - /*.cross =*/ &cross, - /*.n_outputs =*/ n_outputs, - /*.cb =*/ graph_get_cb(), - }, gf, gtype); +llm_graph_params llama_context::graph_params( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_memory_context_i * mctx, + llm_graph_type gtype) const { + return { + /*.arch =*/ model.arch, + /*.hparams =*/ model.hparams, + /*.cparams =*/ cparams, + /*.ubatch =*/ ubatch, + /*.gtype =*/ gtype, + /*.sched =*/ sched.get(), + /*.backend_cpu =*/ backend_cpu, + /*.cvec =*/ &cvec, + /*.loras =*/ &loras, + /*.mctx =*/ mctx, + /*.cross =*/ &cross, + /*.n_outputs =*/ n_outputs, + /*.cb =*/ graph_get_cb(), + /*.res =*/ res, + }; } ggml_status llama_context::graph_compute( @@ -1930,6 +1995,7 @@ llama_perf_context_data llama_context::perf_get_data() const { data.t_eval_ms = 1e-3 * t_eval_us; data.n_p_eval = std::max(1, n_p_eval); data.n_eval = std::max(1, n_eval); + data.n_reused = std::max(0, n_reused); return data; } @@ -1938,6 +2004,7 @@ void llama_context::perf_reset() { t_start_us = ggml_time_us(); t_eval_us = n_eval = 0; t_p_eval_us = n_p_eval = 0; + n_reused = 0; } // @@ -2028,7 +2095,7 @@ void llama_context::opt_epoch_iter( batch.logits [pos_batch] = true; } - if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) { + if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return; } @@ -2064,8 +2131,13 @@ void llama_context::opt_epoch_iter( break; } - auto * gf = graph_init(); - auto res = graph_build(ctx_compute.get(), gf, ubatch, LLM_GRAPH_TYPE_DEFAULT, mctx.get()); + auto * res = gf_res_prev.get(); + + const auto gparams = graph_params(res, ubatch, mctx.get(), LLM_GRAPH_TYPE_DEFAULT); + + res->reset(); + + auto * gf = model.build_graph(gparams); struct ggml_context * ctx_compute_opt; { @@ -2187,6 +2259,7 @@ llama_context_params llama_context_default_params() { /*.no_perf =*/ true, /*.op_offload =*/ true, /*.swa_full =*/ true, + /*.kv_unified =*/ false, }; return result; @@ -2807,6 +2880,7 @@ void llama_perf_context_print(const llama_context * ctx) { LLAMA_LOG_INFO("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n", __func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval); LLAMA_LOG_INFO("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval)); + LLAMA_LOG_INFO("%s: graphs reused = %10d\n", __func__, data.n_reused); } void llama_perf_context_reset(llama_context * ctx) { diff --git a/examples/talk-llama/llama-context.h b/examples/talk-llama/llama-context.h index 9ce05715a8c..5c3a1c09886 100644 --- a/examples/talk-llama/llama-context.h +++ b/examples/talk-llama/llama-context.h @@ -35,8 +35,6 @@ struct llama_context { ggml_backend_sched_t get_sched() const; - ggml_context * get_ctx_compute() const; - uint32_t n_ctx() const; uint32_t n_ctx_per_seq() const; uint32_t n_batch() const; @@ -96,7 +94,7 @@ struct llama_context { // if memory_context is provided, it will be applied first to the context's memory // ret contains the status of the graph computation // returns nullptr only if ret != GGML_STATUS_SUCCESS - llm_graph_result_ptr process_ubatch( + llm_graph_result * process_ubatch( const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, @@ -183,15 +181,17 @@ struct llama_context { // Returns max number of outputs for which 
space was reserved. uint32_t output_reserve(int32_t n_outputs); + void output_reorder(); + // // graph // public: - int32_t graph_max_nodes() const; + uint32_t graph_max_nodes() const; - // zero-out inputs and create the ctx_compute for the compute graph - ggml_cgraph * graph_init(); + // can reuse the llm_graph_result instance of the context (for example to update a memory module) + llm_graph_result * get_gf_res_reserve() const; // returns the result of ggml_backend_sched_graph_compute_async execution ggml_status graph_compute(ggml_cgraph * gf, bool batched); @@ -200,12 +200,11 @@ struct llama_context { ggml_cgraph * graph_reserve(uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx); private: - llm_graph_result_ptr graph_build( - ggml_context * ctx, - ggml_cgraph * gf, - const llama_ubatch & ubatch, - llm_graph_type gtype, - const llama_memory_context_i * mctx); + llm_graph_params graph_params( + llm_graph_result * res, + const llama_ubatch & ubatch, + const llama_memory_context_i * mctx, + llm_graph_type gtype) const; llm_graph_cb graph_get_cb() const; @@ -253,13 +252,18 @@ struct llama_context { std::vector output_ids; // map batch token positions to ids of the logits and embd buffers + struct swap_info { + uint32_t i0; + uint32_t i1; + }; + + std::vector output_swaps; + ggml_backend_sched_ptr sched; ggml_backend_t backend_cpu = nullptr; std::vector backends; - ggml_context_ptr ctx_compute; - // training ggml_opt_context_t opt_ctx = nullptr; @@ -275,14 +279,18 @@ struct llama_context { std::vector backend_ptrs; std::vector backend_buft; - // memory buffers used to evaluate the model - std::vector buf_compute_meta; + llm_graph_result_ptr gf_res_prev; + llm_graph_result_ptr gf_res_reserve; // host buffer for the model output (logits and embeddings) ggml_backend_buffer_ptr buf_output; bool has_evaluated_once = false; + // env: LLAMA_SET_ROWS (temporary) + // ref: https://github.com/ggml-org/llama.cpp/pull/14285 + bool supports_set_rows = false; + // perf mutable int64_t t_start_us = 0; mutable int64_t t_load_us = 0; @@ -294,4 +302,6 @@ struct llama_context { mutable int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) mutable int32_t n_eval = 0; // number of eval calls + + mutable int32_t n_reused = 0; // number of times the previous graph was reused }; diff --git a/examples/talk-llama/llama-cparams.h b/examples/talk-llama/llama-cparams.h index 118615d5bd2..38750affc50 100644 --- a/examples/talk-llama/llama-cparams.h +++ b/examples/talk-llama/llama-cparams.h @@ -11,8 +11,8 @@ struct llama_cparams { uint32_t n_batch; uint32_t n_ubatch; uint32_t n_seq_max; - int n_threads; // number of threads to use for generation - int n_threads_batch; // number of threads to use for batch processing + int32_t n_threads; // number of threads to use for generation + int32_t n_threads_batch; // number of threads to use for batch processing float rope_freq_base; float rope_freq_scale; @@ -33,6 +33,7 @@ struct llama_cparams { bool no_perf; bool warmup; bool op_offload; + bool kv_unified; enum llama_pooling_type pooling_type; diff --git a/examples/talk-llama/llama-graph.cpp b/examples/talk-llama/llama-graph.cpp index a248a7ec223..b63a41053b4 100644 --- a/examples/talk-llama/llama-graph.cpp +++ b/examples/talk-llama/llama-graph.cpp @@ -28,6 +28,15 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { + bool res = true; + + res &= (!tokens && 
!params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); + res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens); + + return res; +} + void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && pos) { const int64_t n_tokens = ubatch->n_tokens; @@ -50,6 +59,14 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) { + bool res = true; + + res &= pos->ne[0] == params.ubatch.n_tokens; + + return res; +} + void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) { if (ubatch->pos && attn_scale) { const int64_t n_tokens = ubatch->n_tokens; @@ -71,7 +88,7 @@ void llm_graph_input_pos_bucket::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(pos_bucket->buffer)); - GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing int32_t * data = (int32_t *) pos_bucket->data; @@ -118,6 +135,14 @@ void llm_graph_input_out_ids::set_input(const llama_ubatch * ubatch) { } } +bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) { + bool res = true; + + res &= n_outputs == params.n_outputs; + + return res; +} + void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { const int64_t n_tokens = ubatch->n_tokens; @@ -287,6 +312,24 @@ void llm_graph_input_attn_kv_unified::set_input(const llama_ubatch * ubatch) { mctx->set_input_kq_mask(self_kq_mask, ubatch, cparams.causal_attn); } +bool llm_graph_input_attn_kv_unified::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= self_kq_mask->ne[0] == mctx->get_n_kv(); + res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + + res &= mctx->get_supports_set_rows(); // TODO: tmp + + return res; +} + void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch) { mctx->get_base()->set_input_k_idxs(self_k_idxs, ubatch); mctx->get_base()->set_input_v_idxs(self_v_idxs, ubatch); @@ -299,6 +342,30 @@ void llm_graph_input_attn_kv_unified_iswa::set_input(const llama_ubatch * ubatch mctx->get_swa()->set_input_kq_mask(self_kq_mask_swa, ubatch, cparams.causal_attn); } +bool llm_graph_input_attn_kv_unified_iswa::can_reuse(const llm_graph_params & params) { + const auto * mctx = static_cast(params.mctx); + + this->mctx = mctx; + + bool res = true; + + res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; + //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there + + res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv(); + res &= self_kq_mask->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + + res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); + res &= self_kq_mask_swa->ne[1] == GGML_PAD(params.ubatch.n_tokens, GGML_KQ_MASK_PAD); + + 
res &= mctx->get_base()->get_supports_set_rows(); // TODO: tmp + + return res; +} + void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { GGML_ASSERT(cross_kq_mask); @@ -306,7 +373,7 @@ void llm_graph_input_attn_cross::set_input(const llama_ubatch * ubatch) { const int64_t n_tokens = ubatch->n_tokens; GGML_ASSERT(ggml_backend_buffer_is_host(cross_kq_mask->buffer)); - GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing float * data = (float *) cross_kq_mask->data; @@ -340,6 +407,91 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { inp_rs->set_input(ubatch); } +// +// llm_graph_result +// + +llm_graph_result::llm_graph_result(int64_t max_nodes) : max_nodes(max_nodes) { + reset(); + + const char * LLAMA_GRAPH_RESULT_DEBUG = getenv("LLAMA_GRAPH_RESULT_DEBUG"); + debug = LLAMA_GRAPH_RESULT_DEBUG ? atoi(LLAMA_GRAPH_RESULT_DEBUG) : 0; +} + +int64_t llm_graph_result::get_max_nodes() const { + return max_nodes; +} + +void llm_graph_result::reset() { + t_tokens = nullptr; + t_logits = nullptr; + t_embd = nullptr; + t_embd_pooled = nullptr; + + params = {}; + + inputs.clear(); + + buf_compute_meta.resize(ggml_tensor_overhead()*max_nodes + ggml_graph_overhead_custom(max_nodes, false)); + + ggml_init_params params = { + /*.mem_size =*/ buf_compute_meta.size(), + /*.mem_buffer =*/ buf_compute_meta.data(), + /*.no_alloc =*/ true, + }; + + ctx_compute.reset(ggml_init(params)); + + gf = ggml_new_graph_custom(ctx_compute.get(), max_nodes, false); +} + +void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { + for (auto & input : inputs) { + input->set_input(ubatch); + } +} + +bool llm_graph_result::can_reuse(const llm_graph_params & params) { + if (!this->params.allow_reuse(params)) { + if (debug > 1) { + LLAMA_LOG_DEBUG("%s: cannot reuse graph due to incompatible graph parameters\n", __func__); + } + + return false; + } + + if (debug > 1) { + LLAMA_LOG_DEBUG("%s: checking compatibility of %d inputs:\n", __func__, (int) inputs.size()); + } + + bool res = true; + + for (auto & input : inputs) { + const bool cur = input->can_reuse(params); + + if (debug > 1) { + LLAMA_LOG_DEBUG("%s: can_reuse = %d\n", "placeholder", cur); + } + + res = res && cur; + } + + if (debug > 0) { + LLAMA_LOG_DEBUG("%s: can reuse graph = %d\n", __func__, res); + } + + return res; +} + +llm_graph_input_i * llm_graph_result::add_input(llm_graph_input_ptr input) { + inputs.emplace_back(std::move(input)); + return inputs.back().get(); +} + +void llm_graph_result::set_params(const llm_graph_params & params) { + this->params = params; +} + // // llm_graph_context // @@ -374,7 +526,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : n_ctx_orig (cparams.n_ctx_orig_yarn), pooling_type (cparams.pooling_type), rope_type (hparams.rope_type), - ctx0 (params.ctx), sched (params.sched), backend_cpu (params.backend_cpu), cvec (params.cvec), @@ -382,7 +533,10 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : mctx (params.mctx), cross (params.cross), cb_func (params.cb), - res (std::make_unique()) { + res (params.res), + ctx0 (res->get_ctx()), + gf (res->get_gf()) { + res->set_params(params); } void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { @@ -753,20 +907,28 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_weighted", il); } + ggml_tensor * cur_experts[LLAMA_MAX_EXPERTS] = { nullptr 
}; + + assert(n_expert_used > 0); + + // order the views before the adds + for (uint32_t i = 0; i < hparams.n_expert_used; ++i) { + cur_experts[i] = ggml_view_2d(ctx0, experts, n_embd, n_tokens, experts->nb[2], i*experts->nb[1]); + + ggml_build_forward_expand(gf, cur_experts[i]); + } + // aggregate experts - ggml_tensor * moe_out = nullptr; - for (int i = 0; i < n_expert_used; ++i) { - ggml_tensor * cur_expert = ggml_view_2d(ctx0, experts, n_embd, n_tokens, - experts->nb[2], i*experts->nb[1]); + // note: here we explicitly use hparams.n_expert_used instead of n_expert_used + // to avoid potentially a large number of add nodes during warmup + // ref: https://github.com/ggml-org/llama.cpp/pull/14753 + ggml_tensor * moe_out = cur_experts[0]; - if (i == 0) { - moe_out = cur_expert; - } else { - moe_out = ggml_add(ctx0, moe_out, cur_expert); - } + for (uint32_t i = 1; i < hparams.n_expert_used; ++i) { + moe_out = ggml_add(ctx0, moe_out, cur_experts[i]); } - if (n_expert_used == 1) { + if (hparams.n_expert_used == 1) { // avoid returning a non-contiguous tensor moe_out = ggml_cont(ctx0, moe_out); } @@ -972,7 +1134,6 @@ ggml_tensor * llm_graph_context::build_pos_bias(ggml_tensor * pos_bucket, ggml_t } ggml_tensor * llm_graph_context::build_attn_mha( - ggml_cgraph * gf, ggml_tensor * q, ggml_tensor * k, ggml_tensor * v, @@ -982,13 +1143,16 @@ ggml_tensor * llm_graph_context::build_attn_mha( float kq_scale) const { const bool v_trans = v->nb[1] > v->nb[2]; + // split the batch into streams if needed + const auto n_stream = k->ne[3]; + + q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream); + q = ggml_permute(ctx0, q, 0, 2, 1, 3); k = ggml_permute(ctx0, k, 0, 2, 1, 3); v = ggml_permute(ctx0, v, 0, 2, 1, 3); - const auto n_tokens = q->ne[1]; - const auto n_head = q->ne[2]; - const auto n_kv = k->ne[1]; + const auto n_kv = k->ne[1]; ggml_tensor * cur; @@ -1030,7 +1194,7 @@ ggml_tensor * llm_graph_context::build_attn_mha( #endif } - cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); } else { ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); @@ -1075,7 +1239,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + // recombine streams + cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]); if (!cparams.offload_kqv) { // all nodes between the KV store and the attention output are run on the CPU @@ -1102,7 +1267,6 @@ llm_graph_input_attn_no_cache * llm_graph_context::build_attn_inp_no_cache() con ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_no_cache * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, @@ -1122,11 +1286,15 @@ ggml_tensor * llm_graph_context::build_attn( const auto & kq_mask = inp->get_kq_mask(); + // [TAG_NO_CACHE_PAD] + // TODO: if ubatch.equal_seqs() == true, we can split the three tensors below into ubatch.n_seqs_unq streams + assert(!ubatch.equal_seqs()); + ggml_tensor * q = q_cur; ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1156,13 +1324,14 @@ static std::unique_ptr build_attn_inp_kv_unifie { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA"); - 
const auto n_kv = mctx_cur->get_n_kv(); + const auto n_kv = mctx_cur->get_n_kv(); const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1181,7 +1350,6 @@ llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_unified * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, @@ -1214,7 +1382,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1234,7 +1402,6 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_kv_unified_iswa * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, @@ -1281,7 +1448,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = mctx_cur->get_k(ctx0, il); ggml_tensor * v = mctx_cur->get_v(ctx0, il); - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1314,7 +1481,6 @@ llm_graph_input_attn_cross * llm_graph_context::build_attn_inp_cross() const { ggml_tensor * llm_graph_context::build_attn( llm_graph_input_attn_cross * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, @@ -1336,7 +1502,7 @@ ggml_tensor * llm_graph_context::build_attn( ggml_tensor * k = k_cur; ggml_tensor * v = v_cur; - ggml_tensor * cur = build_attn_mha(gf, q, k, v, kq_b, kq_mask, v_mla, kq_scale); + ggml_tensor * cur = build_attn_mha(q, k, v, kq_b, kq_mask, v_mla, kq_scale); cb(cur, "kqv_out", il); if (wo) { @@ -1362,13 +1528,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif auto inp = std::make_unique(hparams, cparams, mctx_cur); + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + { const auto n_kv = mctx_cur->get_base()->get_n_kv(); inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1382,7 +1550,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1); + inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask_swa); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; @@ -1392,7 +1560,6 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif } ggml_tensor * llm_graph_context::build_rs( - ggml_cgraph * gf, ggml_tensor * s, ggml_tensor * state_copy, int32_t state_size, @@ -1450,21 +1617,19 @@ llm_graph_input_rs * llm_graph_context::build_rs_inp() const { ggml_tensor * llm_graph_context::build_rs( llm_graph_input_rs * inp, - ggml_cgraph * gf, ggml_tensor * s, int32_t state_size, int32_t n_seqs, const llm_graph_get_rows_fn & get_state_rows) const { const auto * kv_state = inp->mctx; - return build_rs(gf, s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); + return build_rs(s, inp->s_copy, state_size, n_seqs, kv_state->get_n_rs(), kv_state->get_head(), kv_state->get_size(), kv_state->get_rs_z(), get_state_rows); } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( llm_graph_input_rs * inp, - ggml_cgraph * gf, const llama_ubatch & ubatch, - int il) const { + int il) const { const auto * mctx_cur = static_cast(mctx); const auto token_shift_count = hparams.token_shift_count; @@ -1474,7 +1639,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * token_shift_all = mctx_cur->get_r_l(il); ggml_tensor * token_shift = build_rs( - inp, gf, token_shift_all, + inp, token_shift_all, hparams.n_embd_r(), n_seqs); token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs); @@ -1514,7 +1679,6 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { } void llm_graph_context::build_pooling( - ggml_cgraph * gf, ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, diff --git a/examples/talk-llama/llama-graph.h b/examples/talk-llama/llama-graph.h index fbf8e288956..a28a8c4bdda 100644 --- a/examples/talk-llama/llama-graph.h +++ b/examples/talk-llama/llama-graph.h @@ -1,6 +1,7 @@ #pragma once #include "llama-arch.h" +#include "llama-batch.h" #include "llama-hparams.h" #include "llama-adapter.h" @@ -14,7 +15,6 @@ struct ggml_cgraph; struct ggml_context; struct ggml_tensor; -struct llama_ubatch; struct llama_cparams; struct llama_memory_context_i; @@ -69,6 +69,8 @@ struct llama_cross { std::vector> seq_ids_enc; }; +struct llm_graph_params; + // // llm_graph_input // @@ -78,11 +80,19 @@ class llm_graph_input_i { virtual ~llm_graph_input_i() = default; virtual void set_input(const llama_ubatch * ubatch) = 0; + + // return true if the resulting input tensors using the provided graph parameters would be + // the same as the previous input tensors that we have currently stored in the object + virtual bool can_reuse(const llm_graph_params & params) { + // returning false here by default will prevent from reusing the graph if the check + 
// for the input type has not been implemented yet + GGML_UNUSED(params); + return false; + } }; using llm_graph_input_ptr = std::unique_ptr; - class llm_graph_input_embd : public llm_graph_input_i { public: llm_graph_input_embd() = default; @@ -90,6 +100,8 @@ class llm_graph_input_embd : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * tokens = nullptr; // I32 [n_batch] ggml_tensor * embd = nullptr; // F32 [n_embd, n_batch] }; @@ -101,6 +113,8 @@ class llm_graph_input_pos : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * pos = nullptr; // I32 [n_batch] const uint32_t n_pos_per_embd = 1; @@ -154,17 +168,19 @@ class llm_graph_input_out_ids : public llm_graph_input_i { llm_graph_input_out_ids( const llama_hparams & hparams, const llama_cparams & cparams, - int32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {} + uint32_t n_outputs) : hparams(hparams), cparams(cparams), n_outputs(n_outputs) {} virtual ~llm_graph_input_out_ids() = default; void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * out_ids; // I32 [n_outputs] const llama_hparams & hparams; const llama_cparams & cparams; - const int32_t n_outputs; + const uint32_t n_outputs; }; class llm_graph_input_mean : public llm_graph_input_i { @@ -249,16 +265,18 @@ class llm_graph_input_attn_kv_unified : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * get_k_idxs() const { return self_k_idxs; } ggml_tensor * get_v_idxs() const { return self_v_idxs; } ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; } ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] const llama_hparams & hparams; const llama_cparams & cparams; @@ -280,6 +298,8 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + ggml_tensor * get_k_idxs() const { return self_k_idxs; } ggml_tensor * get_v_idxs() const { return self_v_idxs; } ggml_tensor * get_k_idxs_swa() const { return self_k_idxs_swa; } @@ -289,14 +309,14 @@ class llm_graph_input_attn_kv_unified_iswa : public llm_graph_input_i { ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; } ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch] - ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] + ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa] - ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 
1] - ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1] - ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1] + ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] const llama_hparams & hparams; const llama_cparams & cparams; @@ -351,40 +371,108 @@ class llm_graph_input_mem_hybrid : public llm_graph_input_i { // along with the input tensors, the object also provides commonly used outputs tensors, such as logits, embeddings, etc. // these are used by the llama_context to extact the relevant data, based on the compute parameters -class llm_graph_result_i { -public: - virtual ~llm_graph_result_i() = default; +// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) +using llm_graph_cb = std::function; - virtual ggml_tensor * get_tokens() = 0; - virtual ggml_tensor * get_logits() = 0; - virtual ggml_tensor * get_embd() = 0; - virtual ggml_tensor * get_embd_pooled() = 0; +class llm_graph_result; - virtual void set_inputs(const llama_ubatch * ubatch) = 0; -}; +struct llm_graph_params { + llm_arch arch = LLM_ARCH_UNKNOWN; -using llm_graph_result_ptr = std::unique_ptr; + llama_hparams hparams; + llama_cparams cparams; + llama_ubatch ubatch; // note: intentionally make a copy -class llm_graph_result : public llm_graph_result_i { -public: - virtual ~llm_graph_result() = default; + llm_graph_type gtype; - ggml_tensor * get_tokens() override { return t_tokens; } - ggml_tensor * get_logits() override { return t_logits; } - ggml_tensor * get_embd() override { return t_embd; } - ggml_tensor * get_embd_pooled() override { return t_embd_pooled; } + ggml_backend_sched_t sched; + ggml_backend_t backend_cpu; - void set_inputs(const llama_ubatch * ubatch) override { - for (auto & input : inputs) { - input->set_input(ubatch); + const llama_adapter_cvec * cvec; + const llama_adapter_loras * loras; + const llama_memory_context_i * mctx; + const llama_cross * cross; + + uint32_t n_outputs; + + llm_graph_cb cb; + + llm_graph_result * res; + + // return true if the "other" params would result in a graph with the same topology as with the current params + // having the same topology allows us to reuse the graph in some cases + bool allow_reuse(const llm_graph_params & other) const { + // first check the ubatch + bool can_reuse_ubatch = + ubatch.equal_seqs() == other.ubatch.equal_seqs() && + ubatch.n_tokens == other.ubatch.n_tokens && + ubatch.n_seq_tokens == other.ubatch.n_seq_tokens && + ubatch.n_seqs == other.ubatch.n_seqs && + ubatch.n_seqs_unq == other.ubatch.n_seqs_unq && + ( + (!ubatch.token && !other.ubatch.token) || + (!ubatch.embd && !other.ubatch.embd) + ); + + if (can_reuse_ubatch && !ubatch.equal_seqs()) { + if (!ubatch.data) { + // if the old ubatch does not own it's data, then we cannot guarantee that it is still alive, and + // therefore we cannot perform the sequence id check. 
normally should never happen + can_reuse_ubatch = false; + } else { + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + can_reuse_ubatch &= ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s]; + } + } } - } - llm_graph_input_i * add_input(llm_graph_input_ptr input) { - inputs.emplace_back(std::move(input)); - return inputs.back().get(); + if (!can_reuse_ubatch) { + return false; + } + + return + cparams.embeddings == other.cparams.embeddings && + cparams.causal_attn == other.cparams.causal_attn && + arch == other.arch && + gtype == other.gtype && + cvec == other.cvec && + loras == other.loras && + cross == other.cross && + n_outputs == other.n_outputs; } +}; + +class llm_graph_result { +public: + llm_graph_result(int64_t max_nodes); + + virtual ~llm_graph_result() = default; + + ggml_tensor * get_tokens() const { return t_tokens; } + ggml_tensor * get_logits() const { return t_logits; } + ggml_tensor * get_embd() const { return t_embd; } + ggml_tensor * get_embd_pooled() const { return t_embd_pooled; } + + ggml_cgraph * get_gf() const { return gf; } + ggml_context * get_ctx() const { return ctx_compute.get(); } + + int64_t get_max_nodes() const; + + void reset(); + + void set_inputs(const llama_ubatch * ubatch); + + // try to update the existing graph result using the new graph parameters in order to reuse it + // this can only be done if we determine that the resulting graph using the new graph parameters + // would be identical to the existing graph. in that case, we simply have to update the memory + // contexts of the input tensors of the graph and we can reuse it for another computation + // return true if the graph was updated and can be reused + bool can_reuse(const llm_graph_params & params); + + llm_graph_input_i * add_input(llm_graph_input_ptr input); + + void set_params(const llm_graph_params & params); // important graph nodes ggml_tensor * t_tokens = nullptr; @@ -393,36 +481,31 @@ class llm_graph_result : public llm_graph_result_i { ggml_tensor * t_embd_pooled = nullptr; std::vector inputs; -}; -// -// llm_graph_context -// + ggml_context_ptr ctx_compute; -// callback that allows us to apply custom logic to each tensor (e.g. ggml-alloc, offloading, etc.) 
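The new llm_graph_params::allow_reuse() above is the core of the graph-reuse path: a previously built graph is only reused when every field that influences the graph topology matches the incoming parameters. Below is a minimal standalone sketch of that idea using hypothetical, simplified stand-in types (toy_ubatch and toy_graph_params are illustrations only, not the real llama.cpp structs), assuming the same topology criteria as the comments in the diff describe.

#include <cstdint>
#include <cstdio>
#include <vector>

// hypothetical, simplified stand-ins for the real llama.cpp types
struct toy_ubatch {
    bool     equal_seqs;
    uint32_t n_tokens;
    uint32_t n_seqs_unq;
    std::vector<int32_t> seq_id_unq;
};

struct toy_graph_params {
    toy_ubatch ubatch;
    bool       causal_attn;
    uint32_t   n_outputs;

    // same spirit as llm_graph_params::allow_reuse(): reuse is only allowed
    // when the quantities that determine the graph topology are identical
    bool allow_reuse(const toy_graph_params & other) const {
        bool ok =
            ubatch.equal_seqs == other.ubatch.equal_seqs &&
            ubatch.n_tokens   == other.ubatch.n_tokens   &&
            ubatch.n_seqs_unq == other.ubatch.n_seqs_unq;

        // without equal splits, the per-sequence ids also have to match,
        // otherwise the KV-related input tensors would be laid out differently
        if (ok && !ubatch.equal_seqs) {
            for (uint32_t s = 0; s < ubatch.n_seqs_unq && ok; ++s) {
                ok = ubatch.seq_id_unq[s] == other.ubatch.seq_id_unq[s];
            }
        }

        return ok &&
            causal_attn == other.causal_attn &&
            n_outputs   == other.n_outputs;
    }
};

int main() {
    toy_graph_params a = { { false, 8, 2, {0, 1} }, true, 8 };
    toy_graph_params b = a;
    b.ubatch.seq_id_unq[1] = 3; // different sequence id -> cannot reuse

    std::printf("reuse(a,a) = %d, reuse(a,b) = %d\n", a.allow_reuse(a), a.allow_reuse(b));
    return 0;
}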
-using llm_graph_cb = std::function; + // memory buffers used to evaluate the model + std::vector buf_compute_meta; -struct llm_graph_params { - ggml_context * ctx; + ggml_cgraph * gf; - const llm_arch arch; + int64_t max_nodes; - const llama_hparams & hparams; - const llama_cparams & cparams; - const llama_ubatch & ubatch; +private: + // keep a copy of the previous graph parameters + // we will use this to determine whether the graph can be reused by comparing them with the new parameters + // note: these are updated after constructing the new graph + llm_graph_params params; - ggml_backend_sched_t sched; - ggml_backend_t backend_cpu; - - const llama_adapter_cvec * cvec; - const llama_adapter_loras * loras; - const llama_memory_context_i * mctx; - const llama_cross * cross; + // env: LLAMA_GRAPH_RESULT_DEBUG + int debug = 0; +}; - uint32_t n_outputs; +using llm_graph_result_ptr = std::unique_ptr; - const llm_graph_cb & cb; -}; +// +// llm_graph_context +// // used in build_rs to properly order writes and avoid unnecessary copies using llm_graph_get_rows_fn = std::function; @@ -463,8 +546,6 @@ struct llm_graph_context { const enum llama_pooling_type pooling_type; const enum llama_rope_type rope_type; - ggml_context * ctx0 = nullptr; - ggml_backend_sched_t sched; ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? @@ -476,7 +557,10 @@ struct llm_graph_context { const llm_graph_cb & cb_func; - std::unique_ptr res; + llm_graph_result * res; + + ggml_context * ctx0 = nullptr; + ggml_cgraph * gf = nullptr; llm_graph_context(const llm_graph_params & params); virtual ~llm_graph_context() = default; @@ -562,7 +646,6 @@ struct llm_graph_context { // ggml_tensor * build_attn_mha( - ggml_cgraph * gf, ggml_tensor * q, // [n_embd_head_q, n_head_q, n_tokens] ggml_tensor * k, // [n_embd_head_k, n_head_k, n_tokens] ggml_tensor * v, // [n_embd_head_v, n_head_v, n_tokens] (v_trans == false) @@ -575,7 +658,6 @@ struct llm_graph_context { ggml_tensor * build_attn( llm_graph_input_attn_no_cache * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] @@ -590,7 +672,6 @@ struct llm_graph_context { ggml_tensor * build_attn( llm_graph_input_attn_kv_unified * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] @@ -606,7 +687,6 @@ struct llm_graph_context { // note: if k_cur or v_cur are not provided, they will not be stored in the memory ggml_tensor * build_attn( llm_graph_input_attn_kv_unified_iswa * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] @@ -621,7 +701,6 @@ struct llm_graph_context { ggml_tensor * build_attn( llm_graph_input_attn_cross * inp, - ggml_cgraph * gf, ggml_tensor * wo, ggml_tensor * wo_b, ggml_tensor * q_cur, // [n_embd_head_q, n_head_q, n_tokens] @@ -643,7 +722,6 @@ struct llm_graph_context { // implementation in 2 separate methods. 
the goal is to avoid calling `ggml_build_forward_expand` in // `llama_memory_recurrent` ggml_tensor * build_rs( - ggml_cgraph * gf, ggml_tensor * s, ggml_tensor * state_copy, int32_t state_size, @@ -658,7 +736,6 @@ struct llm_graph_context { ggml_tensor * build_rs( llm_graph_input_rs * inp, - ggml_cgraph * gf, ggml_tensor * s, int32_t state_size, int32_t n_seqs, @@ -666,9 +743,8 @@ struct llm_graph_context { ggml_tensor * build_rwkv_token_shift_load( llm_graph_input_rs * inp, - ggml_cgraph * gf, const llama_ubatch & ubatch, - int il) const; + int il) const; ggml_tensor * build_rwkv_token_shift_store( ggml_tensor * token_shift, @@ -685,7 +761,6 @@ struct llm_graph_context { // void build_pooling( - ggml_cgraph * gf, ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, diff --git a/examples/talk-llama/llama-hparams.cpp b/examples/talk-llama/llama-hparams.cpp index 7aa736e2f39..c6c67d26f93 100644 --- a/examples/talk-llama/llama-hparams.cpp +++ b/examples/talk-llama/llama-hparams.cpp @@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const { return n_embd_head_v * n_head_kv; } +bool llama_hparams::is_n_embd_k_gqa_variable() const { + const uint32_t val = n_embd_k_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + if (val != n_embd_k_gqa(il)) { + return true; + } + } + + return false; +} + +bool llama_hparams::is_n_embd_v_gqa_variable() const { + const uint32_t val = n_embd_v_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + if (val != n_embd_v_gqa(il)) { + return true; + } + } + + return false; +} + +uint32_t llama_hparams::n_embd_k_gqa_max() const { + uint32_t val = n_embd_k_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + val = std::max(val, n_embd_k_gqa(il)); + } + + return val; +} + +uint32_t llama_hparams::n_embd_v_gqa_max() const { + uint32_t val = n_embd_v_gqa(); + for (uint32_t il = 0; il < n_layer; ++il) { + val = std::max(val, n_embd_v_gqa(il)); + } + + return val; +} + uint32_t llama_hparams::n_embd_r() const { if (wkv_head_size != 0) { // for RWKV models diff --git a/examples/talk-llama/llama-hparams.h b/examples/talk-llama/llama-hparams.h index d0500e4d0fd..ec7fd6a42bf 100644 --- a/examples/talk-llama/llama-hparams.h +++ b/examples/talk-llama/llama-hparams.h @@ -6,7 +6,7 @@ // bump if necessary #define LLAMA_MAX_LAYERS 512 -#define LLAMA_MAX_EXPERTS 256 // DeepSeekV3 +#define LLAMA_MAX_EXPERTS 384 // Kimi-K2 enum llama_expert_gating_func_type { LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0, @@ -98,7 +98,7 @@ struct llama_hparams { float rope_freq_scale_train; float rope_freq_scale_train_swa; uint32_t n_ctx_orig_yarn; - float rope_yarn_log_mul; + float rope_yarn_log_mul = 0.0f; std::array rope_sections; @@ -191,6 +191,14 @@ struct llama_hparams { // dimension of value embeddings across all k-v heads uint32_t n_embd_v_gqa(uint32_t il = 0) const; + // true if any layer has a different n_embd_k_gqa/n_embd_v_gqa + bool is_n_embd_k_gqa_variable() const; + bool is_n_embd_v_gqa_variable() const; + + // return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers + uint32_t n_embd_k_gqa_max() const; + uint32_t n_embd_v_gqa_max() const; + // dimension of the rolling state embeddings // corresponds to Mamba's conv_states size or RWKV's token_shift states size uint32_t n_embd_r() const; diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.cpp b/examples/talk-llama/llama-kv-cache-unified-iswa.cpp index fe207ad5360..01d27fb4db9 100644 --- a/examples/talk-llama/llama-kv-cache-unified-iswa.cpp +++ b/examples/talk-llama/llama-kv-cache-unified-iswa.cpp 
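The llama_hparams helpers added above (is_n_embd_k_gqa_variable(), n_embd_v_gqa_max(), etc.) exist because some models use a different number of KV heads per layer; with a transposed V cache, a single shared buffer then has to be padded to the widest layer. The following is a small self-contained sketch of the same per-layer scan, using a hypothetical toy_hparams struct purely for illustration rather than the real llama_hparams.

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

// hypothetical per-layer head counts, only for illustration
struct toy_hparams {
    uint32_t n_embd_head_v = 128;
    std::vector<uint32_t> n_head_kv_per_layer;

    uint32_t n_embd_v_gqa(uint32_t il) const {
        return n_embd_head_v * n_head_kv_per_layer[il];
    }

    // same idea as llama_hparams::is_n_embd_v_gqa_variable():
    // true if any layer deviates from layer 0
    bool is_n_embd_v_gqa_variable() const {
        for (uint32_t il = 0; il < n_head_kv_per_layer.size(); ++il) {
            if (n_embd_v_gqa(il) != n_embd_v_gqa(0)) {
                return true;
            }
        }
        return false;
    }

    // same idea as llama_hparams::n_embd_v_gqa_max(): the width the padded
    // V cache has to be allocated with when layers differ
    uint32_t n_embd_v_gqa_max() const {
        uint32_t val = n_embd_v_gqa(0);
        for (uint32_t il = 0; il < n_head_kv_per_layer.size(); ++il) {
            val = std::max(val, n_embd_v_gqa(il));
        }
        return val;
    }
};

int main() {
    toy_hparams hp;
    hp.n_head_kv_per_layer = {8, 8, 4, 8}; // one layer with fewer KV heads

    std::printf("variable = %d, max = %u\n", hp.is_n_embd_v_gqa_variable(), hp.n_embd_v_gqa_max());
    return 0;
}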
@@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( bool v_trans, bool offload, bool swa_full, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, - uint32_t n_pad) : hparams(model.hparams) { + uint32_t n_pad) : hparams(model.hparams), unified(unified) { llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); }; llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); }; const uint32_t size_base = kv_size; - uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad)); + uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad)); // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size if (swa_full) { @@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa( kv_base = std::make_unique( model, std::move(filter_base), type_k, type_v, - v_trans, offload, size_base, n_seq_max, n_pad, + v_trans, offload, unified, size_base, n_seq_max, n_pad, 0, LLAMA_SWA_TYPE_NONE); LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa); kv_swa = std::make_unique( model, std::move(filter_swa), type_k, type_v, - v_trans, offload, size_swa, n_seq_max, n_pad, + v_trans, offload, unified, size_swa, n_seq_max, n_pad, hparams.n_swa, hparams.swa_type); } @@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all // first try simple split do { + if (!unified) { + // requires equal splits, so we skip the simple split + break; + } + balloc.split_reset(); std::vector ubatches; @@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all std::vector ubatches; while (true) { - auto ubatch = balloc.split_equal(n_ubatch, false); + auto ubatch = balloc.split_equal(n_ubatch, !unified); if (ubatch.n_tokens == 0) { break; diff --git a/examples/talk-llama/llama-kv-cache-unified-iswa.h b/examples/talk-llama/llama-kv-cache-unified-iswa.h index 23205d826b2..d2650dadd35 100644 --- a/examples/talk-llama/llama-kv-cache-unified-iswa.h +++ b/examples/talk-llama/llama-kv-cache-unified-iswa.h @@ -20,6 +20,7 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { bool v_trans, bool offload, bool swa_full, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_ubatch, @@ -68,6 +69,8 @@ class llama_kv_cache_unified_iswa : public llama_memory_i { private: const llama_hparams & hparams; + const bool unified; + std::unique_ptr kv_base; std::unique_ptr kv_swa; }; diff --git a/examples/talk-llama/llama-kv-cache-unified.cpp b/examples/talk-llama/llama-kv-cache-unified.cpp index d3129cc5328..321dc79fc36 100644 --- a/examples/talk-llama/llama-kv-cache-unified.cpp +++ b/examples/talk-llama/llama-kv-cache-unified.cpp @@ -23,13 +23,14 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_type type_v, bool v_trans, bool offload, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_pad, uint32_t n_swa, llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), - n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { + n_seq_max(n_seq_max), n_stream(unified ? 
1 : n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) { GGML_ASSERT(kv_size % n_pad == 0); @@ -45,7 +46,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( auto it = ctx_map.find(buft); if (it == ctx_map.end()) { ggml_init_params params = { - /*.mem_size =*/ size_t(2u*n_layer_cache*ggml_tensor_overhead()), + /*.mem_size =*/ size_t(2u*(1 + n_stream)*n_layer_cache*ggml_tensor_overhead()), /*.mem_buffer =*/ NULL, /*.no_alloc =*/ true, }; @@ -64,9 +65,33 @@ llama_kv_cache_unified::llama_kv_cache_unified( return it->second; }; - head = 0; + GGML_ASSERT(n_stream == 1 || n_stream == n_seq_max); - cells.resize(kv_size); + v_heads.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + v_heads[s] = 0; + } + + v_cells.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + v_cells[s].resize(kv_size); + } + + // by default, all sequence ids are mapped to the 0th stream + seq_to_stream.resize(LLAMA_MAX_SEQ, 0); + + if (n_stream > 1) { + seq_to_stream.resize(n_stream, 0); + for (uint32_t s = 0; s < n_stream; ++s) { + seq_to_stream[s] = s; + } + } + + // [TAG_V_CACHE_VARIABLE] + if (v_trans && hparams.is_n_embd_v_gqa_variable()) { + LLAMA_LOG_WARN("%s: the V embeddings have different sizes across layers and FA is not enabled - padding V cache to %d\n", + __func__, hparams.n_embd_v_gqa_max()); + } for (uint32_t il = 0; il < n_layer_cache; il++) { if (filter && !filter(il)) { @@ -74,8 +99,9 @@ llama_kv_cache_unified::llama_kv_cache_unified( continue; } - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + // [TAG_V_CACHE_VARIABLE] + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const uint32_t n_embd_v_gqa = !v_trans ? hparams.n_embd_v_gqa(il) : hparams.n_embd_v_gqa_max(); const char * dev_name = "CPU"; @@ -98,14 +124,23 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_tensor * k; ggml_tensor * v; - k = ggml_new_tensor_2d(ctx, type_k, n_embd_k_gqa, kv_size); - v = ggml_new_tensor_2d(ctx, type_v, n_embd_v_gqa, kv_size); + k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); + std::vector k_stream; + std::vector v_stream; + + for (uint32_t s = 0; s < n_stream; ++s) { + k_stream.push_back(ggml_view_2d(ctx, k, n_embd_k_gqa, kv_size, k->nb[1], s*k->nb[2])); + v_stream.push_back(ggml_view_2d(ctx, v, n_embd_v_gqa, kv_size, v->nb[1], s*v->nb[2])); + } + map_layer_ids[il] = layers.size(); - layers.push_back({ il, k, v }); + + layers.push_back({ il, k, v, k_stream, v_stream, }); } // TODO: this is temporary until we support passing reuse layer filters [KV_REUSE] @@ -148,8 +183,8 @@ llama_kv_cache_unified::llama_kv_cache_unified( const size_t memory_size_k = size_k_bytes(); const size_t memory_size_v = size_v_bytes(); - LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, + LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u/%2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max, n_stream, ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } @@ -158,7 +193,12 @@ 
llama_kv_cache_unified::llama_kv_cache_unified( debug = LLAMA_KV_CACHE_DEBUG ? atoi(LLAMA_KV_CACHE_DEBUG) : 0; const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS"); - supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0; + supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) != 0 : 0; + + if (!supports_set_rows) { + // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + GGML_ASSERT(unified && "cannot use non-unified KV cache without ggml_set_rows() support"); + } if (!supports_set_rows) { LLAMA_LOG_WARN("%s: LLAMA_SET_ROWS=0, using old ggml_cpy() method for backwards compatibility\n", __func__); @@ -166,9 +206,10 @@ llama_kv_cache_unified::llama_kv_cache_unified( } void llama_kv_cache_unified::clear(bool data) { - cells.reset(); - - head = 0; + for (uint32_t s = 0; s < n_stream; ++s) { + v_cells[s].reset(); + v_heads[s] = 0; + } if (data) { for (auto & buf : bufs) { @@ -178,6 +219,11 @@ void llama_kv_cache_unified::clear(bool data) { } bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + uint32_t new_head = cells.size(); if (p0 < 0) { @@ -224,30 +270,94 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - if (seq_id_src == seq_id_dst) { + GGML_ASSERT(seq_id_src >= 0 && (size_t) seq_id_src < seq_to_stream.size()); + GGML_ASSERT(seq_id_dst >= 0 && (size_t) seq_id_dst < seq_to_stream.size()); + + const auto s0 = seq_to_stream[seq_id_src]; + const auto s1 = seq_to_stream[seq_id_dst]; + + if (s0 == s1) { + // since both sequences are in the same stream, no data copy is necessary + // we just have to update the cells meta data + + auto & cells = v_cells[s0]; + + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.pos_in(i, p0, p1)) { + continue; + } + + if (cells.seq_has(i, seq_id_src)) { + cells.seq_add(i, seq_id_dst); + } + } + return; } - if (p0 < 0) { - p0 = 0; + // cross-stream sequence copies require to copy the actual buffer data + + bool is_full = true; + + if (p0 > 0 && p0 + 1 < (int) get_size()) { + is_full = false; } - if (p1 < 0) { - p1 = std::numeric_limits::max(); + if (p1 > 0 && p1 + 1 < (int) get_size()) { + is_full = false; } - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.pos_in(i, p0, p1)) { - continue; - } + GGML_ASSERT(is_full && "seq_cp() is only supported for full KV buffers"); + + // enqueue the copy operation - the buffer copy will be performed during the next update + sc_info.ssrc.push_back(s0); + sc_info.sdst.push_back(s1); - if (cells.seq_has(i, seq_id_src)) { - cells.seq_add(i, seq_id_dst); + v_cells[s1].reset(); + for (uint32_t i = 0; i < v_cells[s0].size(); ++i) { + if (v_cells[s0].seq_has(i, seq_id_src)) { + llama_pos pos = v_cells[s0].pos_get(i); + llama_pos shift = v_cells[s0].get_shift(i); + + if (shift != 0) { + pos -= shift; + assert(pos >= 0); + } + + v_cells[s1].pos_set(i, pos); + v_cells[s1].seq_add(i, seq_id_dst); + + if (shift != 0) { + v_cells[s1].pos_add(i, shift); + } } } + + v_heads[s1] = v_heads[s0]; + + //for (uint32_t s = 0; s < n_stream; ++s) { + // LLAMA_LOG_WARN("%s: seq %d: min = %d, max = %d\n", __func__, s, 
v_cells[s].seq_pos_min(s), v_cells[s].seq_pos_max(s)); + //} } void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + uint32_t new_head = cells.size(); for (uint32_t i = 0; i < cells.size(); ++i) { @@ -265,6 +375,11 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { } void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos shift) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + auto & head = v_heads[seq_to_stream[seq_id]]; + if (shift == 0) { return; } @@ -304,6 +419,10 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po } void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[seq_id]]; + if (d == 1) { return; } @@ -333,10 +452,18 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po } llama_pos llama_kv_cache_unified::seq_pos_min(llama_seq_id seq_id) const { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + const auto & cells = v_cells[seq_to_stream[seq_id]]; + return cells.seq_pos_min(seq_id); } llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { + GGML_ASSERT(seq_id >= 0 && (size_t) seq_id < seq_to_stream.size()); + + const auto & cells = v_cells[seq_to_stream[seq_id]]; + return cells.seq_pos_max(seq_id); } @@ -351,7 +478,7 @@ llama_memory_context_ptr llama_kv_cache_unified::init_batch( std::vector ubatches; while (true) { - auto ubatch = balloc.split_simple(n_ubatch); + auto ubatch = n_stream == 1 ? 
balloc.split_simple(n_ubatch) : balloc.split_equal(n_ubatch, true); if (ubatch.n_tokens == 0) { break; @@ -387,7 +514,10 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct defrag_info dinfo; // see if we need to defrag - { + if (n_stream == 1) { + // note : for now do not consider defrag for n_stream > 1 + const auto & cells = v_cells[seq_to_stream[0]]; + bool do_defrag = optimize; const auto thold = lctx->get_cparams().defrag_thold; @@ -411,22 +541,22 @@ llama_memory_context_ptr llama_kv_cache_unified::init_update(llama_context * lct } } - return std::make_unique(this, lctx, do_shift, std::move(dinfo)); + return std::make_unique(this, lctx, do_shift, std::move(dinfo), std::move(sc_info)); } llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const std::vector & ubatches) { llama_kv_cache_unified::slot_info_vec_t res; - struct state { - uint32_t head_old; // old position of the head, before placing the ubatch - + struct state_t { slot_info sinfo; // slot info for the ubatch - llama_kv_cells_unified cells; // copy of the old cells, before placing the ubatch + std::vector v_heads_old; // old positions of the heads, before placing the ubatch + + std::vector v_cells; // copy of the old cells, before placing the ubatch }; // remember the old state of the cells so we can restore it in the end - std::vector states; + std::vector states; bool success = true; @@ -445,16 +575,35 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st res.push_back(sinfo_new); // store the old state of the cells in the recovery stack - states.push_back({head, sinfo_new, cells.cp(sinfo_new.idxs)}); + { + state_t state = { sinfo_new, v_heads, {} }; + + for (uint32_t s = 0; s < sinfo_new.n_stream(); ++s) { + auto & cells = v_cells[sinfo_new.strm[s]]; + + state.v_cells.push_back(cells.cp(sinfo_new.idxs[s])); + } + + states.push_back(std::move(state)); + } // now emplace the ubatch apply_ubatch(sinfo_new, ubatch); } + GGML_ASSERT(!states.empty() || !success); + // iterate backwards and restore the cells to their original state for (auto it = states.rbegin(); it != states.rend(); ++it) { - cells.set(it->sinfo.idxs, it->cells); - head = it->head_old; + const auto & sinfo = it->sinfo; + + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + auto & cells = v_cells[sinfo.strm[s]]; + auto & head = v_heads[sinfo.strm[s]]; + + cells.set(sinfo.idxs[s], it->v_cells[s]); + head = it->v_heads_old[s]; + } } if (!success) { @@ -464,11 +613,38 @@ llama_kv_cache_unified::slot_info_vec_t llama_kv_cache_unified::prepare(const st return res; } -bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo) { +bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info) { bool updated = false; auto * sched = lctx->get_sched(); + if (!sc_info.empty()) { + assert(n_stream > 1 && "stream copy should never happen with a single stream"); + + llama_synchronize(lctx); + + const size_t n_copy = sc_info.ssrc.size(); + + for (size_t i = 0; i < n_copy; ++i) { + const auto ssrc = sc_info.ssrc[i]; + const auto sdst = sc_info.sdst[i]; + + assert(ssrc < n_stream); + assert(sdst < n_stream); + + LLAMA_LOG_DEBUG("%s: copying KV buffer: stream %d to stream %d\n", __func__, ssrc, sdst); + + assert(ssrc != sdst); + + for (uint32_t il = 0; il < layers.size(); ++il) { + const auto & layer = layers[il]; + + ggml_backend_tensor_copy(layer.k_stream[ssrc], layer.k_stream[sdst]); + 
ggml_backend_tensor_copy(layer.v_stream[ssrc], layer.v_stream[sdst]); + } + } + } + if (do_shift) { if (!get_can_shift()) { GGML_ABORT("The current KV cache / model configuration does not support K-shift"); @@ -480,14 +656,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { ggml_backend_sched_reset(sched); - auto * gf = lctx->graph_init(); + auto * res = lctx->get_gf_res_reserve(); - auto res = build_graph_shift(lctx->get_cparams(), lctx->get_ctx_compute(), gf); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build graph for K-shift\n", __func__); - return updated; - } + res->reset(); + auto * gf = build_graph_shift(res, lctx); if (!ggml_backend_sched_alloc_graph(sched, gf)) { LLAMA_LOG_ERROR("%s: failed to allocate compute graph for K-shift\n", __func__); return updated; @@ -503,12 +676,20 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d updated = true; } - cells.reset_shift(); + for (uint32_t s = 0; s < n_stream; ++s) { + auto & cells = v_cells[s]; + + cells.reset_shift(); + } } if (!dinfo.empty()) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); + // note: for now do not consider defrag for n_stream > 1 + auto & cells = v_cells[seq_to_stream[0]]; + auto & head = v_heads[seq_to_stream[0]]; + // apply moves: { const auto n_kv = dinfo.ids.size(); @@ -529,14 +710,11 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d ggml_backend_sched_reset(sched); - auto * gf = lctx->graph_init(); + auto * res = lctx->get_gf_res_reserve(); - auto res = build_graph_defrag(lctx->get_cparams(), lctx->get_ctx_compute(), gf, dinfo); - if (!res) { - LLAMA_LOG_ERROR("%s: failed to build graph for defrag\n", __func__); - return updated; - } + res->reset(); + auto * gf = build_graph_defrag(res, lctx, dinfo); if (!ggml_backend_sched_alloc_graph(sched, gf)) { LLAMA_LOG_ERROR("%s: failed to allocate compute graph for defrag\n", __func__); return updated; @@ -556,23 +734,13 @@ bool llama_kv_cache_unified::update(llama_context * lctx, bool do_shift, const d } llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch, bool cont) const { - const uint32_t n_tokens = ubatch.n_tokens; + if (debug > 0) { + const auto & cells = v_cells[seq_to_stream[1]]; - uint32_t head_cur = this->head; + const uint32_t head_cur = v_heads[1]; - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head_cur > cells.get_used() + 2*ubatch.n_tokens) { - head_cur = 0; - } - - if (n_tokens > cells.size()) { - LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); - return { }; - } - - if (debug > 0) { - LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", __func__, cells.used_max_p1(), cells.get_used(), head, get_size(), n_swa); + LLAMA_LOG_DEBUG("%s: n = %5d, used = %5d, head = %5d, size = %5d, n_swa = %5d\n", + __func__, cells.used_max_p1(), cells.get_used(), head_cur, get_size(), n_swa); if ((debug == 2 && n_swa > 0) || debug > 2) { std::string ss; @@ -629,86 +797,133 @@ llama_kv_cache_unified::slot_info llama_kv_cache_unified::find_slot(const llama_ } } - uint32_t n_tested = 0; + uint32_t n_tokens = ubatch.n_tokens; + uint32_t n_seqs = 1; + + if (n_stream > 1) { + GGML_ASSERT(n_tokens % ubatch.n_seqs_unq == 0); - // for continuous slots, we test that all tokens in the ubatch fit, starting from 
the current head - // for non-continuous slots, we test the tokens one by one - const uint32_t n_test = cont ? n_tokens : 1; + n_seqs = ubatch.n_seqs_unq; + n_tokens = n_tokens / n_seqs; + } - slot_info res; + slot_info res = { + /*.s0 =*/ LLAMA_MAX_SEQ, + /*.s1 =*/ 0, + /*.strm =*/ { }, + /*.idxs =*/ { }, + }; - auto & idxs = res.idxs; + res.resize(n_seqs); - idxs.reserve(n_tokens); + for (uint32_t s = 0; s < n_seqs; ++s) { + const auto seq_id = ubatch.seq_id_unq[s]; - while (true) { - if (head_cur + n_test > cells.size()) { - n_tested += cells.size() - head_cur; + if (n_stream > 1) { + GGML_ASSERT(ubatch.n_seq_id[s*n_tokens] == 1); + GGML_ASSERT(ubatch.seq_id [s*n_tokens][0] == seq_id); + } + + res.s0 = std::min(res.s0, seq_to_stream[seq_id]); + res.s1 = std::max(res.s1, seq_to_stream[seq_id]); + + res.strm[s] = seq_to_stream[seq_id]; + res.idxs[s].reserve(n_tokens); + + const auto & cells = v_cells[seq_to_stream[seq_id]]; + + uint32_t head_cur = v_heads[seq_to_stream[seq_id]]; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head_cur > cells.get_used() + 2*n_tokens) { head_cur = 0; - continue; } - for (uint32_t i = 0; i < n_test; i++) { - const auto idx = head_cur; + if (n_tokens > cells.size()) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %u\n", __func__, n_tokens, cells.size()); + return { }; + } + + uint32_t n_tested = 0; + + // for continuous slots, we test that all tokens in the ubatch fit, starting from the current head + // for non-continuous slots, we test the tokens one by one + const uint32_t n_test = cont ? n_tokens : 1; - //const llama_pos pos = ubatch.pos[i]; - //const llama_seq_id seq_id = ubatch.seq_id[i][0]; + while (true) { + if (head_cur + n_test > cells.size()) { + n_tested += cells.size() - head_cur; + head_cur = 0; + continue; + } - // can we use this cell? either: - // - the cell is empty - // - the cell is occupied only by one sequence: - // - (disabled) mask causally, if the sequence is the same as the one we are inserting - // - mask SWA, using current max pos for that sequence in the cache - // always insert in the cell with minimum pos - bool can_use = cells.is_empty(idx); + for (uint32_t i = 0; i < n_test; i++) { + const auto idx = head_cur; - if (!can_use && cells.seq_count(idx) == 1) { - const llama_pos pos_cell = cells.pos_get(idx); + head_cur++; + n_tested++; - // (disabled) causal mask - // note: it's better to purge any "future" tokens beforehand - //if (cells.seq_has(idx, seq_id)) { - // can_use = pos_cell >= pos; - //} + //const llama_pos pos = ubatch.pos[i]; + //const llama_seq_id seq_id = ubatch.seq_id[i][0]; - if (!can_use) { - const llama_seq_id seq_id_cell = cells.seq_get(idx); + // can we use this cell? 
either: + // - the cell is empty + // - the cell is occupied only by one sequence: + // - (disabled) mask causally, if the sequence is the same as the one we are inserting + // - mask SWA, using current max pos for that sequence in the cache + // always insert in the cell with minimum pos + bool can_use = cells.is_empty(idx); - // SWA mask - if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { - can_use = true; + if (!can_use && cells.seq_count(idx) == 1) { + const llama_pos pos_cell = cells.pos_get(idx); + + // (disabled) causal mask + // note: it's better to purge any "future" tokens beforehand + //if (cells.seq_has(idx, seq_id)) { + // can_use = pos_cell >= pos; + //} + + if (!can_use) { + const llama_seq_id seq_id_cell = cells.seq_get(idx); + + // SWA mask + if (is_masked_swa(pos_cell, cells.seq_pos_max(seq_id_cell) + 1)) { + can_use = true; + } } } - } - head_cur++; - n_tested++; + if (can_use) { + res.idxs[s].push_back(idx); + } else { + if (cont) { + break; + } + } + } - if (can_use) { - idxs.push_back(idx); - } else { + if (res.idxs[s].size() == n_tokens) { break; } - } - if (idxs.size() == n_tokens) { - break; - } + if (cont) { + res.idxs[s].clear(); + } - if (cont) { - idxs.clear(); + if (n_tested >= cells.size()) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return { }; + } } - if (n_tested >= cells.size()) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + // we didn't find a suitable slot - return empty result + if (res.idxs[s].size() < n_tokens) { return { }; } } - // we didn't find a suitable slot - return empty result - if (idxs.size() < n_tokens) { - res.clear(); - } + assert(res.s1 >= res.s0); return res; } @@ -717,41 +932,51 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u // keep track of the max sequence position that we would overwrite with this ubatch // for non-SWA cache, this would be always empty llama_seq_id seq_pos_max_rm[LLAMA_MAX_SEQ]; - for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { seq_pos_max_rm[s] = -1; } - assert(ubatch.n_tokens == sinfo.idxs.size()); + assert(ubatch.n_tokens == sinfo.n_stream()*sinfo.size()); - for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { - const auto idx = sinfo.idxs.at(i); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + for (uint32_t ii = 0; ii < sinfo.size(); ++ii) { + const uint32_t i = s*sinfo.size() + ii; - if (!cells.is_empty(idx)) { - assert(cells.seq_count(idx) == 1); + auto & cells = v_cells[sinfo.strm[s]]; - const llama_seq_id seq_id = cells.seq_get(idx); - const llama_pos pos = cells.pos_get(idx); + const auto idx = sinfo.idxs[s][ii]; - seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + if (!cells.is_empty(idx)) { + assert(cells.seq_count(idx) == 1); - cells.rm(idx); - } + const llama_seq_id seq_id = cells.seq_get(idx); + const llama_pos pos = cells.pos_get(idx); - cells.pos_set(idx, ubatch.pos[i]); + seq_pos_max_rm[seq_id] = std::max(seq_pos_max_rm[seq_id], pos); + + cells.rm(idx); + } - for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { - cells.seq_add(idx, ubatch.seq_id[i][s]); + cells.pos_set(idx, ubatch.pos[i]); + + for (int32_t s = 0; s < ubatch.n_seq_id[i]; s++) { + cells.seq_add(idx, ubatch.seq_id[i][s]); + } } } // note: we want to preserve the invariant that all positions between [pos_min, pos_max] for each sequence // will be present in the cache. 
so we have to purge any position which is less than those we would overwrite // ref: https://github.com/ggml-org/llama.cpp/pull/13746#issuecomment-2916057092 - for (int s = 0; s < LLAMA_MAX_SEQ; ++s) { + for (uint32_t s = 0; s < LLAMA_MAX_SEQ; ++s) { if (seq_pos_max_rm[s] == -1) { continue; } + GGML_ASSERT(s < seq_to_stream.size()); + + auto & cells = v_cells[seq_to_stream[s]]; + if (cells.seq_pos_min(s) <= seq_pos_max_rm[s]) { LLAMA_LOG_DEBUG("%s: purging positions [%d, %d] of sequence %d from KV cache\n", __func__, cells.seq_pos_min(s), seq_pos_max_rm[s], s); @@ -761,7 +986,11 @@ void llama_kv_cache_unified::apply_ubatch(const slot_info & sinfo, const llama_u } // move the head at the end of the slot - head = sinfo.idxs.back() + 1; + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + auto & head = v_heads[sinfo.strm[s]]; + + head = sinfo.idxs[s].back() + 1; + } } bool llama_kv_cache_unified::get_can_shift() const { @@ -769,49 +998,91 @@ bool llama_kv_cache_unified::get_can_shift() const { } uint32_t llama_kv_cache_unified::get_size() const { + const auto & cells = v_cells[seq_to_stream[0]]; + return cells.size(); } +uint32_t llama_kv_cache_unified::get_n_stream() const { + return n_stream; +} + bool llama_kv_cache_unified::get_has_shift() const { - return cells.get_has_shift(); + bool result = false; + + for (uint32_t s = 0; s < n_stream; ++s) { + result |= v_cells[s].get_has_shift(); + } + + return result; } uint32_t llama_kv_cache_unified::get_n_kv() const { - return std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))); + uint32_t result = 0; + + for (uint32_t s = 0; s < n_stream; ++s) { + const auto & cells = v_cells[s]; + + result = std::max(std::min(cells.size(), std::max(n_pad, GGML_PAD(cells.used_max_p1(), n_pad))), result); + } + + return result; } -ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const { +bool llama_kv_cache_unified::get_supports_set_rows() const { + return supports_set_rows; +} + +ggml_tensor * llama_kv_cache_unified::get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * k = layers[ikv].k; - return ggml_view_3d(ctx, k, - hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, + const uint64_t kv_size = get_size(); + const uint64_t n_embd_k_gqa = k->ne[0]; + + assert(n_embd_k_gqa == hparams.n_embd_k_gqa(il)); + + const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + + return ggml_view_4d(ctx, k, + hparams.n_embd_head_k, hparams.n_head_kv(il), n_kv, ns, ggml_row_size(k->type, hparams.n_embd_head_k), - ggml_row_size(k->type, hparams.n_embd_k_gqa(il)), - 0); + ggml_row_size(k->type, n_embd_k_gqa), + ggml_row_size(k->type, n_embd_k_gqa*kv_size), + ggml_row_size(k->type, n_embd_k_gqa*kv_size)*sinfo.s0); } -ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const { +ggml_tensor * llama_kv_cache_unified::get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const { const int32_t ikv = map_layer_ids.at(il); auto * v = layers[ikv].v; + const uint64_t kv_size = get_size(); + const uint64_t n_embd_v_gqa = v->ne[0]; + + // [TAG_V_CACHE_VARIABLE] + assert(n_embd_v_gqa >= hparams.n_embd_v_gqa(il)); + + const uint32_t ns = sinfo.s1 - sinfo.s0 + 1; + if (!v_trans) { // note: v->nb[1] <= v->nb[2] - return ggml_view_3d(ctx, v, - hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, - ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, 
hparams.n_embd_v_gqa(il)), // v->nb[2] - 0); + return ggml_view_4d(ctx, v, + hparams.n_embd_head_v, hparams.n_head_kv(il), n_kv, ns, + ggml_row_size(v->type, hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, n_embd_v_gqa), // v->nb[2] + ggml_row_size(v->type, n_embd_v_gqa*kv_size), // v->nb[3] + ggml_row_size(v->type, n_embd_v_gqa*kv_size)*sinfo.s0); } // note: v->nb[1] > v->nb[2] - return ggml_view_3d(ctx, v, - n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, - ggml_row_size(v->type, v->ne[1]*hparams.n_embd_head_v), // v->nb[1] - ggml_row_size(v->type, v->ne[1]), // v->nb[2] - 0); + return ggml_view_4d(ctx, v, + n_kv, hparams.n_head_kv(il), hparams.n_embd_head_v, ns, + ggml_row_size(v->type, kv_size*hparams.n_embd_head_v), // v->nb[1] + ggml_row_size(v->type, kv_size), // v->nb[2] + ggml_row_size(v->type, kv_size*n_embd_v_gqa), // v->nb[3] + ggml_row_size(v->type, kv_size*n_embd_v_gqa)*sinfo.s0); } ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const { @@ -825,12 +1096,18 @@ ggml_tensor * llama_kv_cache_unified::cpy_k(ggml_context * ctx, ggml_tensor * k_ k_cur = ggml_reshape_2d(ctx, k_cur, k->ne[0], n_tokens); if (k_idxs && supports_set_rows) { + if (k->ne[2] > 1) { + k = ggml_reshape_2d(ctx, k, k->ne[0], k->ne[1]*k->ne[2]); + } + return ggml_set_rows(ctx, k, k_cur, k_idxs); } // TODO: fallback to old ggml_cpy() method for backwards compatibility // will be removed when ggml_set_rows() is adopted by all backends + GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); + ggml_tensor * k_view = ggml_view_1d(ctx, k, n_tokens*n_embd_k_gqa, ggml_row_size(k->type, n_embd_k_gqa)*sinfo.head()); @@ -843,37 +1120,38 @@ ggml_tensor * llama_kv_cache_unified::cpy_v(ggml_context * ctx, ggml_tensor * v_ auto * v = layers[ikv].v; - const int64_t n_embd_v_gqa = v->ne[0]; - const int64_t n_tokens = v_cur->ne[2]; + const int64_t n_embd_v_gqa = v_cur->ne[0]*v_cur->ne[1]; + const int64_t n_tokens = v_cur->ne[2]; v_cur = ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens); if (v_idxs && supports_set_rows) { if (!v_trans) { + if (v->ne[2] > 1) { + v = ggml_reshape_2d(ctx, v, v->ne[0], v->ne[1]*v->ne[2]); + } + return ggml_set_rows(ctx, v, v_cur, v_idxs); } - // the row becomes a single element - ggml_tensor * v_view = ggml_reshape_3d(ctx, v, 1, v->ne[1], v->ne[0]); + // [TAG_V_CACHE_VARIABLE] + if (n_embd_v_gqa < v->ne[0]) { + v_cur = ggml_pad(ctx, v_cur, v->ne[0] - n_embd_v_gqa, 0, 0, 0); + } - // note: the V cache is transposed when not using flash attention - v_cur = ggml_permute(ctx, ggml_reshape_3d(ctx, v_cur, v_cur->ne[0], 1, v_cur->ne[1]), 2, 0, 1, 3); + // the row becomes a single element + ggml_tensor * v_view = ggml_reshape_2d(ctx, v, 1, v->ne[0]*v->ne[1]*v->ne[2]); - // note: we can be more explicit here at the cost of extra cont - // however, above we take advantage that a row of single element is always continuous regardless of the row stride - //v_cur = ggml_transpose(ctx, v_cur); - //v_cur = ggml_cont_3d(ctx, v_cur, 1, v_cur->ne[0], v_cur->ne[1]); + v_cur = ggml_reshape_2d(ctx, v_cur, 1, v_cur->ne[0]*v_cur->ne[1]); - // we broadcast the KV indices n_embd_v_gqa times - // v [1, n_kv, n_embd_v_gqa] - // v_cur [1, n_tokens, n_embd_v_gqa] - // v_idxs [n_tokens, 1, 1] return ggml_set_rows(ctx, v_view, v_cur, v_idxs); } // TODO: fallback to old ggml_cpy() method for backwards compatibility // will be removed when ggml_set_rows() is adopted by all backends + 
GGML_ASSERT(n_stream == 1 && "n_stream > 1 not supported without LLAMA_SET_ROWS"); + ggml_tensor * v_view = nullptr; if (!v_trans) { @@ -904,7 +1182,13 @@ ggml_tensor * llama_kv_cache_unified::build_input_k_idxs(ggml_context * ctx, con ggml_tensor * llama_kv_cache_unified::build_input_v_idxs(ggml_context * ctx, const llama_ubatch & ubatch) const { const uint32_t n_tokens = ubatch.n_tokens; - ggml_tensor * v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + ggml_tensor * v_idxs; + + if (!v_trans) { + v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens); + } else { + v_idxs = ggml_new_tensor_1d(ctx, GGML_TYPE_I64, n_tokens*hparams.n_embd_v_gqa_max()); + } ggml_set_input(v_idxs); @@ -917,12 +1201,17 @@ void llama_kv_cache_unified::set_input_k_idxs(ggml_tensor * dst, const llama_uba } const uint32_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int64_t * data = (int64_t *) dst->data; - for (int64_t i = 0; i < n_tokens; ++i) { - data[i] = sinfo.idxs.at(i); + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*get_size(); + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i]; + } } } @@ -932,12 +1221,48 @@ void llama_kv_cache_unified::set_input_v_idxs(ggml_tensor * dst, const llama_uba } const uint32_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_tokens == (int64_t) sinfo.size()*sinfo.n_stream()); GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); int64_t * data = (int64_t *) dst->data; - for (int64_t i = 0; i < n_tokens; ++i) { - data[i] = sinfo.idxs.at(i); + if (!v_trans) { + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*get_size(); + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + data[s*sinfo.size() + i] = offs + sinfo.idxs[s][i]; + } + } + } else { + // note: the V cache is transposed when not using flash attention + const int64_t kv_size = get_size(); + + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa_max(); + + for (uint32_t s = 0; s < sinfo.n_stream(); ++s) { + const int64_t offs = sinfo.strm[s]*kv_size*n_embd_v_gqa; + + for (uint32_t i = 0; i < sinfo.size(); ++i) { + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + data[s*sinfo.size()*n_embd_v_gqa + i*n_embd_v_gqa + j] = offs + j*kv_size + sinfo.idxs[s][i]; + } + } + } + } +} + +void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); + + int32_t * data = (int32_t *) dst->data; + + for (uint32_t s = 0; s < n_stream; ++s) { + const auto & cells = v_cells[s]; + + for (uint32_t i = 0; i < cells.size(); ++i) { + data[s*cells.size() + i] = cells.is_empty(i) ? 0 : cells.get_shift(i); + } } } @@ -947,7 +1272,16 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); float * data = (float *) dst->data; - const int64_t n_kv = dst->ne[0]; + const int64_t n_kv = dst->ne[0]; + const int64_t n_stream = dst->ne[3]; // num streams in the current ubatch + + GGML_ASSERT(n_tokens%n_stream == 0); + + // n_tps == n_tokens_per_stream + const int64_t n_tps = n_tokens/n_stream; + const int64_t n_tps_pad = GGML_PAD(n_tps, GGML_KQ_MASK_PAD); + + std::fill(data, data + ggml_nelements(dst), -INFINITY); // Use only the previous KV cells of the correct sequence for each token of the ubatch. 
// It's assumed that if a token in the batch has multiple sequences, they are equivalent. @@ -961,70 +1295,57 @@ void llama_kv_cache_unified::set_input_kq_mask(ggml_tensor * dst, const llama_ub // xxxxx----- // xxxxx----- // To visualize the mask, see https://github.com/ggml-org/llama.cpp/pull/12615 + // TODO: optimize this section for (uint32_t h = 0; h < 1; ++h) { - for (uint32_t i = 0; i < n_tokens; ++i) { - const llama_seq_id seq_id = ubatch->seq_id[i][0]; + for (uint32_t s = 0; s < n_stream; ++s) { + for (uint32_t ii = 0; ii < n_tps; ++ii) { + const uint32_t i = s*n_tps + ii; - const llama_pos p1 = ubatch->pos[i]; + const llama_seq_id seq_id = ubatch->seq_id[i][0]; - for (uint32_t j = 0; j < n_kv; ++j) { - float f = 0.0f; + const auto & cells = v_cells[seq_to_stream[seq_id]]; - bool masked = false; + const llama_pos p1 = ubatch->pos[i]; - if (cells.is_empty(j)) { - masked = true; - } else { - const llama_pos p0 = cells.pos_get(j); + const uint64_t idst = n_kv*(h*n_stream*n_tps_pad + s*n_tps_pad + ii); + + for (uint32_t j = 0; j < n_kv; ++j) { + if (cells.is_empty(j)) { + continue; + } // mask the token if not the same sequence - masked = masked || (!cells.seq_has(j, seq_id)); + if (!cells.seq_has(j, seq_id)) { + continue; + } + + const llama_pos p0 = cells.pos_get(j); // mask future tokens - masked = masked || (causal_attn && p0 > p1); + if (causal_attn && p0 > p1) { + continue; + } // apply SWA if any - masked = masked || (is_masked_swa(p0, p1)); - - if (!masked && hparams.use_alibi) { - f = -std::abs(p0 - p1); + if (is_masked_swa(p0, p1)) { + continue; } - } - - if (masked) { - f = -INFINITY; - } - - data[h*(n_kv*n_tokens) + i*n_kv + j] = f; - } - } - // mask padded tokens - if (data) { - for (uint32_t i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) { - for (uint32_t j = 0; j < n_kv; ++j) { - data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY; + data[idst + j] = hparams.use_alibi ? -std::abs(p0 - p1) : 0.0f; } } } } } -void llama_kv_cache_unified::set_input_k_shift(ggml_tensor * dst) const { - GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - - int32_t * data = (int32_t *) dst->data; - - for (uint32_t i = 0; i < cells.size(); ++i) { - data[i] = cells.is_empty(i) ? 
0 : cells.get_shift(i); - } -} - void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const { const int64_t n_tokens = ubatch->n_tokens; + GGML_ASSERT(n_stream == 1 && "TODO: support multiple streams"); + const auto & cells = v_cells[0]; + GGML_ASSERT(ggml_backend_buffer_is_host(dst->buffer)); - GGML_ASSERT(!ubatch->equal_seqs); // TODO: use ubatch->n_seqs instead of failing + GGML_ASSERT(!ubatch->equal_seqs()); // TODO: use ubatch->n_seqs instead of failing int32_t * data = (int32_t *) dst->data; @@ -1129,7 +1450,7 @@ class llm_graph_input_k_shift : public llm_graph_input_i { void set_input(const llama_ubatch * ubatch) override; - ggml_tensor * k_shift; // I32 [kv_size] + ggml_tensor * k_shift; // I32 [kv_size*n_stream] const llama_kv_cache_unified * kv_self; }; @@ -1142,20 +1463,20 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { } } -llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const { - auto res = std::make_unique(); +ggml_cgraph * llama_kv_cache_unified::build_graph_shift(llm_graph_result * res, llama_context * lctx) const { + auto * ctx = res->get_ctx(); + auto * gf = res->get_gf(); const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; auto inp = std::make_unique(this); - inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cells.size()); + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); ggml_set_input(inp->k_shift); + const auto & cparams = lctx->get_cparams(); + for (const auto & layer : layers) { const uint32_t il = layer.il; @@ -1169,7 +1490,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, cells.size(), + n_embd_head_k, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), 0); @@ -1181,18 +1502,24 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( res->add_input(std::move(inp)); - return res; + return gf; } -llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf, - const defrag_info & dinfo) const { - auto res = std::make_unique(); +ggml_cgraph * llama_kv_cache_unified::build_graph_defrag( + llm_graph_result * res, + llama_context * lctx, + const defrag_info & dinfo) const { + auto * ctx = res->get_ctx(); + auto * gf = res->get_gf(); + + GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); + + const auto & cells = v_cells[0]; const auto & ids = dinfo.ids; + const auto & cparams = lctx->get_cparams(); + #if 0 // CPU defrag // @@ -1329,10 +1656,14 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); #endif - return res; + return gf; } llama_kv_cache_unified::defrag_info llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) const { + GGML_ASSERT(n_stream == 1 && "n_stream > 1 does not support defrag"); + + const auto & cells = v_cells[0]; + const uint32_t n_layer = layers.size(); const uint32_t n_kv = cells.used_max_p1(); @@ -1478,64 +1809,94 @@ bool llama_kv_cache_unified::is_masked_swa(llama_pos p0, llama_pos p1) const { } void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { - std::vector> cell_ranges; // ranges, from inclusive, to exclusive - 
uint32_t cell_count = 0; + io.write(&n_stream, sizeof(n_stream)); - // Count the number of cells with the specified seq_id - // Find all the ranges of cells with this seq id (or all, when -1) - uint32_t cell_range_begin = cells.size(); + for (uint32_t s = 0; s < n_stream; ++s) { + cell_ranges_t cr { s, {} }; - for (uint32_t i = 0; i < cells.size(); ++i) { - if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { - ++cell_count; - if (cell_range_begin == cells.size()) { - cell_range_begin = i; - } - } else { - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, i); - cell_range_begin = cells.size(); + uint32_t cell_count = 0; + + const auto & cells = v_cells[s]; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = cells.size(); + + for (uint32_t i = 0; i < cells.size(); ++i) { + if (!cells.is_empty(i) && (seq_id == -1 || cells.seq_has(i, seq_id))) { + ++cell_count; + if (cell_range_begin == cells.size()) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != cells.size()) { + cr.data.emplace_back(cell_range_begin, i); + cell_range_begin = cells.size(); + } } } - } - if (cell_range_begin != cells.size()) { - cell_ranges.emplace_back(cell_range_begin, cells.size()); - } + if (cell_range_begin != cells.size()) { + cr.data.emplace_back(cell_range_begin, cells.size()); + } - // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count - uint32_t cell_count_check = 0; - for (const auto & range : cell_ranges) { - cell_count_check += range.second - range.first; - } - GGML_ASSERT(cell_count == cell_count_check); + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cr.data) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); - io.write(&cell_count, sizeof(cell_count)); + io.write(&cell_count, sizeof(cell_count)); - state_write_meta(io, cell_ranges, seq_id); - state_write_data(io, cell_ranges); + // skip empty streams + if (cell_count == 0) { + continue; + } + + state_write_meta(io, cr, seq_id); + state_write_data(io, cr); + } } void llama_kv_cache_unified::state_read(llama_io_read_i & io, llama_seq_id seq_id) { - uint32_t cell_count; - io.read_to(&cell_count, sizeof(cell_count)); + GGML_ASSERT(seq_id == -1 || (seq_id >= 0 && (size_t) seq_id < seq_to_stream.size())); - bool res = true; - res = res && state_read_meta(io, cell_count, seq_id); - res = res && state_read_data(io, cell_count); + uint32_t n_stream_cur; + io.read_to(&n_stream_cur, sizeof(n_stream_cur)); + if (n_stream_cur != n_stream) { + throw std::runtime_error("n_stream mismatch"); + } + + for (uint32_t s = 0; s < n_stream; ++s) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + if (cell_count == 0) { + continue; + } - if (!res) { - if (seq_id == -1) { - clear(true); - } else { - seq_rm(seq_id, -1, -1); + const uint32_t strm = seq_id == -1 ? 
s : seq_to_stream[seq_id]; + + bool res = true; + res = res && state_read_meta(io, strm, cell_count, seq_id); + res = res && state_read_data(io, strm, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(true); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); } - throw std::runtime_error("failed to restore kv cache"); } } -void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { - for (const auto & range : cell_ranges) { +void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id) const { + const auto & cells = v_cells[cr.strm]; + + for (const auto & range : cr.data) { for (uint32_t i = range.first; i < range.second; ++i) { std::vector seq_ids; @@ -1560,7 +1921,9 @@ void llama_kv_cache_unified::state_write_meta(llama_io_write_i & io, const std:: } } -void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { +void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const { + const auto & cells = v_cells[cr.strm]; + const uint32_t v_trans = this->v_trans ? 1 : 0; const uint32_t n_layer = layers.size(); @@ -1576,19 +1939,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + auto * k = layer.k_stream[cr.strm]; + // Write key type - const int32_t k_type_i = (int32_t)layer.k->type; + const int32_t k_type_i = (int32_t) k->type; io.write(&k_type_i, sizeof(k_type_i)); // Write row size of key - const uint64_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + const uint64_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); io.write(&k_size_row, sizeof(k_size_row)); // Read each range of cells of k_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * k_size_row; - io.write_tensor(layer.k, range.first * k_size_row, buf_size); + io.write_tensor(k, range.first * k_size_row, buf_size); } } @@ -1598,19 +1963,21 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[cr.strm]; + // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; io.write(&v_type_i, sizeof(v_type_i)); // Write row size of value - const uint64_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + const uint64_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); io.write(&v_size_row, sizeof(v_size_row)); // Read each range of cells of v_size length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t buf_size = range_size * v_size_row; - io.write_tensor(layer.v, range.first * v_size_row, buf_size); + io.write_tensor(v, range.first * v_size_row, buf_size); } } } else { @@ -1622,12 +1989,14 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[cr.strm]; + // Write value type - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; io.write(&v_type_i, 
sizeof(v_type_i)); // Write element size - const uint32_t v_size_el = ggml_type_size(layer.v->type); + const uint32_t v_size_el = ggml_type_size(v->type); io.write(&v_size_el, sizeof(v_size_el)); // Write GQA embedding size @@ -1636,27 +2005,31 @@ void llama_kv_cache_unified::state_write_data(llama_io_write_i & io, const std:: // For each row, we get the element values of each cell for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { // Read each range of cells of v_size_el length each into tmp_buf and write out - for (const auto & range : cell_ranges) { + for (const auto & range : cr.data) { const size_t range_size = range.second - range.first; const size_t src_offset = (range.first + j * kv_size) * v_size_el; const size_t buf_size = range_size * v_size_el; - io.write_tensor(layer.v, src_offset, buf_size); + io.write_tensor(v, src_offset, buf_size); } } } } } -bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { +bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id) { + auto & cells = v_cells[strm]; + auto & head = v_heads[strm]; + if (dest_seq_id != -1) { // single sequence - seq_rm(dest_seq_id, -1, -1); llama_batch_allocr balloc(hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1); + ubatch.seq_id_unq[0] = dest_seq_id; + for (uint32_t i = 0; i < cell_count; ++i) { llama_pos pos; uint32_t n_seq_id; @@ -1693,6 +2066,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell // keep the head at the old position because we will read the KV data into it in state_read_data() head = head_cur; + LLAMA_LOG_DEBUG("%s: head_cur = %d, head = %d, cell_count = %d, dest_seq_id = %d\n", __func__, head_cur, head, cell_count, dest_seq_id); + // DEBUG CHECK: head_cur should be our first cell, head_cur + cell_count - 1 should be our last cell (verify seq_id and pos values) // Assume that this is one contiguous block of cells GGML_ASSERT(head_cur + cell_count <= cells.size()); @@ -1738,7 +2113,10 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell return true; } -bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count) { + auto & cells = v_cells[strm]; + auto & head = v_heads[strm]; + uint32_t v_trans; uint32_t n_layer; @@ -1766,10 +2144,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + auto * k = layer.k_stream[strm]; + // Read type of key int32_t k_type_i_ref; io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); - const int32_t k_type_i = (int32_t) layer.k->type; + const int32_t k_type_i = (int32_t) k->type; if (k_type_i != k_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); return false; @@ -1778,7 +2158,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of key uint64_t k_size_row_ref; io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); - const size_t k_size_row = ggml_row_size(layer.k->type, n_embd_k_gqa); + const size_t k_size_row = ggml_row_size(k->type, n_embd_k_gqa); if (k_size_row != k_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); return false; 
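The stream-aware session format introduced by the new state_write()/state_read() pair is easiest to see in isolation: an n_stream header, then one cell_count per stream, with the per-cell metadata and K/V payloads following each non-zero count. The sketch below is a hypothetical, buffer-based rendition of just that envelope; toy_write/toy_read, put_u32/get_u32 and the raw byte vector are illustration-only names and are not part of the patch, which streams through llama_io_write_i / llama_io_read_i instead.

    // Minimal sketch (assumption: this mirrors only the per-stream envelope,
    // not the per-cell metadata or the K/V tensor payloads).
    #include <cstdint>
    #include <cstring>
    #include <stdexcept>
    #include <vector>

    static void put_u32(std::vector<uint8_t> & buf, uint32_t v) {
        const uint8_t * p = reinterpret_cast<const uint8_t *>(&v);
        buf.insert(buf.end(), p, p + sizeof(v));
    }

    static uint32_t get_u32(const std::vector<uint8_t> & buf, size_t & pos) {
        uint32_t v;
        std::memcpy(&v, buf.data() + pos, sizeof(v)); // no bounds check: sketch only
        pos += sizeof(v);
        return v;
    }

    // layout: [n_stream] then, per stream, [cell_count] (+ payload when non-zero)
    static std::vector<uint8_t> toy_write(const std::vector<uint32_t> & cells_per_stream) {
        std::vector<uint8_t> buf;
        put_u32(buf, (uint32_t) cells_per_stream.size()); // n_stream header
        for (uint32_t cell_count : cells_per_stream) {
            put_u32(buf, cell_count);                     // written even when zero
            // meta + data for this stream would follow here when cell_count > 0
        }
        return buf;
    }

    static void toy_read(const std::vector<uint8_t> & buf, uint32_t n_stream_expected) {
        size_t pos = 0;
        const uint32_t n_stream_cur = get_u32(buf, pos);
        if (n_stream_cur != n_stream_expected) {
            throw std::runtime_error("n_stream mismatch"); // same failure mode as the patch
        }
        for (uint32_t s = 0; s < n_stream_cur; ++s) {
            const uint32_t cell_count = get_u32(buf, pos);
            if (cell_count == 0) {
                continue;                                  // empty stream carries no payload
            }
            // restore meta + data for stream s here
        }
    }

Writing the cell_count even for empty streams keeps the reader advancing stream-by-stream in lockstep with the writer, so it can simply continue past streams that carried no cells.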
@@ -1786,7 +2166,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the keys for the whole cell range - ggml_backend_tensor_set(layer.k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + ggml_backend_tensor_set(k, io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); } } @@ -1796,10 +2176,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[strm]; + // Read type of value int32_t v_type_i_ref; io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return false; @@ -1808,7 +2190,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read row size of value uint64_t v_size_row_ref; io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); - const size_t v_size_row = ggml_row_size(layer.v->type, n_embd_v_gqa); + const size_t v_size_row = ggml_row_size(v->type, n_embd_v_gqa); if (v_size_row != v_size_row_ref) { LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); return false; @@ -1816,7 +2198,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell if (cell_count) { // Read and set the values for the whole cell range - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + ggml_backend_tensor_set(v, io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); } } } else { @@ -1826,10 +2208,12 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + auto * v = layer.v_stream[strm]; + // Read type of value int32_t v_type_i_ref; io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); - const int32_t v_type_i = (int32_t)layer.v->type; + const int32_t v_type_i = (int32_t) v->type; if (v_type_i != v_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); return false; @@ -1838,7 +2222,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // Read element size of value uint32_t v_size_el_ref; io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); - const size_t v_size_el = ggml_type_size(layer.v->type); + const size_t v_size_el = ggml_type_size(v->type); if (v_size_el != v_size_el_ref) { LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); return false; @@ -1856,7 +2240,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // For each row in the transposed matrix, read the values for the whole cell range for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { const size_t dst_offset = (head + j * cells.size()) * v_size_el; - ggml_backend_tensor_set(layer.v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + ggml_backend_tensor_set(v, io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); } } } @@ -1875,18 +2259,26 @@ llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv) : 
status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv) { n_kv = kv->get_size(); + const uint32_t n_stream = kv->get_n_stream(); + // create a dummy slot info - the actual data is irrelevant. we just need to build the graph sinfos.resize(1); - sinfos[0].idxs.resize(1); - sinfos[0].idxs[0] = 0; + sinfos[0].s0 = 0; + sinfos[0].s1 = n_stream - 1; + sinfos[0].idxs.resize(n_stream); + for (uint32_t s = 0; s < n_stream; ++s) { + sinfos[0].strm.push_back(s); + sinfos[0].idxs[s].resize(1, 0); + } } llama_kv_cache_unified_context::llama_kv_cache_unified_context( llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, - defrag_info dinfo) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)) { - if (!do_shift && this->dinfo.empty()) { + defrag_info dinfo, + stream_copy_info sc_info) : status(LLAMA_MEMORY_STATUS_SUCCESS), kv(kv), lctx(lctx), do_shift(do_shift), dinfo(std::move(dinfo)), sc_info(std::move(sc_info)) { + if (!do_shift && this->dinfo.empty() && this->sc_info.empty()) { status = LLAMA_MEMORY_STATUS_NO_UPDATE; } } @@ -1914,7 +2306,7 @@ bool llama_kv_cache_unified_context::apply() { // no ubatches -> this is a KV cache update if (ubatches.empty()) { - kv->update(lctx, do_shift, dinfo); + kv->update(lctx, do_shift, dinfo, sc_info); return true; } @@ -1940,12 +2332,16 @@ uint32_t llama_kv_cache_unified_context::get_n_kv() const { return n_kv; } +bool llama_kv_cache_unified_context::get_supports_set_rows() const { + return kv->get_supports_set_rows(); +} + ggml_tensor * llama_kv_cache_unified_context::get_k(ggml_context * ctx, int32_t il) const { - return kv->get_k(ctx, il, n_kv); + return kv->get_k(ctx, il, n_kv, sinfos[i_cur]); } ggml_tensor * llama_kv_cache_unified_context::get_v(ggml_context * ctx, int32_t il) const { - return kv->get_v(ctx, il, n_kv); + return kv->get_v(ctx, il, n_kv, sinfos[i_cur]); } ggml_tensor * llama_kv_cache_unified_context::cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il) const { diff --git a/examples/talk-llama/llama-kv-cache-unified.h b/examples/talk-llama/llama-kv-cache-unified.h index b8b0356e830..3e28e346c3f 100644 --- a/examples/talk-llama/llama-kv-cache-unified.h +++ b/examples/talk-llama/llama-kv-cache-unified.h @@ -35,16 +35,50 @@ class llama_kv_cache_unified : public llama_memory_i { std::vector ids; }; + struct stream_copy_info { + bool empty() const { + assert(ssrc.size() == sdst.size()); + return ssrc.empty(); + } + + std::vector ssrc; + std::vector sdst; + }; + // for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the // KV cells. 
for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]] struct slot_info { // data for ggml_set_rows using idx_vec_t = std::vector; - idx_vec_t idxs; + // number of streams: ns = s1 - s0 + 1 + llama_seq_id s0; + llama_seq_id s1; + + std::vector strm; // [ns] + std::vector idxs; // [ns] uint32_t head() const { - return idxs.at(0); + GGML_ASSERT(idxs.size() == 1); + GGML_ASSERT(!idxs[0].empty()); + + return idxs[0][0]; + } + + void resize(size_t n) { + strm.resize(n); + idxs.resize(n); + } + + size_t size() const { + GGML_ASSERT(idxs.size() == strm.size()); + GGML_ASSERT(!idxs.empty()); + + return idxs[0].size(); + } + + size_t n_stream() const { + return strm.size(); } bool empty() const { @@ -54,9 +88,6 @@ class llama_kv_cache_unified : public llama_memory_i { void clear() { idxs.clear(); } - - // TODO: implement - //std::vector seq_idxs; }; using slot_info_vec_t = std::vector; @@ -68,6 +99,7 @@ class llama_kv_cache_unified : public llama_memory_i { ggml_type type_v, bool v_trans, bool offload, + bool unified, uint32_t kv_size, uint32_t n_seq_max, uint32_t n_pad, @@ -111,7 +143,8 @@ class llama_kv_cache_unified : public llama_memory_i { // llama_kv_cache_unified specific API // - uint32_t get_size() const; + uint32_t get_size() const; + uint32_t get_n_stream() const; bool get_has_shift() const; @@ -121,9 +154,12 @@ class llama_kv_cache_unified : public llama_memory_i { uint32_t get_n_kv() const; + // TODO: temporary + bool get_supports_set_rows() const; + // get views of the current state of the cache - ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const; - ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const; + ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; + ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const; // store k_cur and v_cur in the cache based on the provided head location ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const; @@ -137,7 +173,7 @@ class llama_kv_cache_unified : public llama_memory_i { // return empty vector on failure slot_info_vec_t prepare(const std::vector & ubatches); - bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo); + bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info); // find a slot of kv cells that can hold the ubatch // if cont == true, then the slot must be continuous @@ -157,8 +193,9 @@ class llama_kv_cache_unified : public llama_memory_i { void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const; + void set_input_k_shift(ggml_tensor * dst) const; + void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const; - void set_input_k_shift (ggml_tensor * dst) const; void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const; private: @@ -172,15 +209,15 @@ class llama_kv_cache_unified : public llama_memory_i { ggml_tensor * k; ggml_tensor * v; + + std::vector k_stream; + std::vector v_stream; }; bool v_trans = true; // the value tensor is transposed - // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) - // note: this is not part of the KV state and it's only used to 
speed-up the find_slot() method - uint32_t head = 0; - const uint32_t n_seq_max = 1; + const uint32_t n_stream = 1; // required padding const uint32_t n_pad = 1; @@ -193,14 +230,24 @@ class llama_kv_cache_unified : public llama_memory_i { // env: LLAMA_SET_ROWS (temporary) // ref: https://github.com/ggml-org/llama.cpp/pull/14285 - int supports_set_rows = false; + bool supports_set_rows = false; const llama_swa_type swa_type = LLAMA_SWA_TYPE_NONE; std::vector ctxs; std::vector bufs; - llama_kv_cells_unified cells; + // the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot()) + // note: this is not part of the KV state and it's only used to speed-up the find_slot() method + std::vector v_heads; + + std::vector v_cells; + + // maps from a sequence id to a stream id + std::vector seq_to_stream; + + // pending stream copies that will be applied during the next update + stream_copy_info sc_info; std::vector layers; @@ -226,29 +273,34 @@ class llama_kv_cache_unified : public llama_memory_i { float freq_base, float freq_scale) const; - llm_graph_result_ptr build_graph_shift( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf) const; + ggml_cgraph * build_graph_shift( + llm_graph_result * res, + llama_context * lctx) const; - llm_graph_result_ptr build_graph_defrag( - const llama_cparams & cparams, - ggml_context * ctx, - ggml_cgraph * gf, + ggml_cgraph * build_graph_defrag( + llm_graph_result * res, + llama_context * lctx, const defrag_info & dinfo) const; - void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; - void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + struct cell_ranges_t { + uint32_t strm; - bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); - bool state_read_data(llama_io_read_i & io, uint32_t cell_count); + std::vector> data; // ranges, from inclusive, to exclusive + }; + + void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count); }; class llama_kv_cache_unified_context : public llama_memory_context_i { public: // some shorthands - using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; - using defrag_info = llama_kv_cache_unified::defrag_info; + using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t; + using defrag_info = llama_kv_cache_unified::defrag_info; + using stream_copy_info = llama_kv_cache_unified::stream_copy_info; // used for errors llama_kv_cache_unified_context(llama_memory_status status); @@ -262,7 +314,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { llama_kv_cache_unified * kv, llama_context * lctx, bool do_shift, - defrag_info dinfo); + defrag_info dinfo, + stream_copy_info sc_info); // used to create a batch procesing context from a batch llama_kv_cache_unified_context( @@ -288,6 +341,9 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { uint32_t get_n_kv() const; + // TODO: temporary + bool get_supports_set_rows() const; + // get views of the current state of the cache ggml_tensor * get_k(ggml_context * ctx, int32_t il) const; ggml_tensor * get_v(ggml_context * 
ctx, int32_t il) const; @@ -320,6 +376,8 @@ class llama_kv_cache_unified_context : public llama_memory_context_i { defrag_info dinfo; + stream_copy_info sc_info; + // // batch processing context // diff --git a/examples/talk-llama/llama-memory-hybrid.cpp b/examples/talk-llama/llama-memory-hybrid.cpp index 6cd10db06b7..d8e2086c875 100644 --- a/examples/talk-llama/llama-memory-hybrid.cpp +++ b/examples/talk-llama/llama-memory-hybrid.cpp @@ -38,6 +38,7 @@ llama_memory_hybrid::llama_memory_hybrid( type_v, v_trans, offload, + 1, kv_size, n_seq_max, n_pad, diff --git a/examples/talk-llama/llama-memory-recurrent.cpp b/examples/talk-llama/llama-memory-recurrent.cpp index 2c1ae67098c..c0c2ec084dc 100644 --- a/examples/talk-llama/llama-memory-recurrent.cpp +++ b/examples/talk-llama/llama-memory-recurrent.cpp @@ -446,7 +446,7 @@ bool llama_memory_recurrent::find_slot(const llama_ubatch & ubatch) { // A slot should be always be contiguous. // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.equal_seqs()); int32_t min = size - 1; int32_t max = 0; @@ -768,6 +768,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // Iterate and write all the keys first, each row is a cell // Get whole range at a time for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (r_l[il] == nullptr) continue; // Write key type const int32_t r_type_i = (int32_t)r_l[il]->type; @@ -787,6 +789,8 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: if (!s_trans) { for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (s_l[il] == nullptr) continue; // Write value type const int32_t s_type_i = (int32_t)s_l[il]->type; @@ -807,6 +811,9 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: // When v is transposed, we also need the element size and get the element ranges from each row const uint32_t mem_size = size; for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers (read_data will handle this by checking "r_l" and "s_l" for null) + if (s_l[il] == nullptr) continue; + const uint32_t n_embd_s = hparams.n_embd_s(); // Write value type @@ -951,6 +958,8 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers + if (r_l[il] == nullptr) continue; // Read type of key int32_t r_type_i_ref; @@ -978,11 +987,14 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell if (!s_trans) { for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers + if (s_l[il] == nullptr) continue; // Read type of value int32_t s_type_i_ref; io.read_to(&s_type_i_ref, sizeof(s_type_i_ref)); const int32_t s_type_i = (int32_t)s_l[il]->type; + if (s_type_i != s_type_i_ref) { LLAMA_LOG_ERROR("%s: mismatched s type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il); return false; @@ -1005,6 +1017,9 @@ bool llama_memory_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell } else { // For each layer, read the values for each cell (transposed) for (uint32_t il = 0; il < n_layer; ++il) { + // skip null layers + if (s_l[il] == nullptr) continue; + const uint32_t n_embd_s = 
hparams.n_embd_s(); // Read type of value diff --git a/examples/talk-llama/llama-model.cpp b/examples/talk-llama/llama-model.cpp index a322fc39352..71f89e19072 100644 --- a/examples/talk-llama/llama-model.cpp +++ b/examples/talk-llama/llama-model.cpp @@ -107,8 +107,10 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_A13B: return "A13B"; + case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; + case LLM_TYPE_300B_A47B: return "300B.A47B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -644,6 +646,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); + // MiniCPM uses rope by default, unlike Granite which uses it as a switch + hparams.rope_finetuned = true; + switch (hparams.n_layer) { case 52: type = LLM_TYPE_1B; break; case 40: type = LLM_TYPE_2B; break; @@ -849,6 +854,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_DREAM: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + // Dream models are primarily 7B with 28 layers + switch (hparams.n_layer) { + case 28: + type = LLM_TYPE_7B; + break; + default: + type = LLM_TYPE_UNKNOWN; + } + // Set non-causal attention for diffusion models + hparams.causal_attn = false; + } + break; case LLM_ARCH_QWEN2MOE: { ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); @@ -935,6 +955,33 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_PLAMO2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + // Load Mamba SSM parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = hparams.n_head_kv(i) == 0; + } + + switch (hparams.n_layer) { + case 16: type = LLM_TYPE_1B; break; + case 32: + if (hparams.n_embd == 2048) { + type = LLM_TYPE_2B; + } else if (hparams.n_embd == 4096) { + type = LLM_TYPE_8B; + } + break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_GPT2: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1322,7 +1369,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { // that have no expert_gating_func model parameter set hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX; } - ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul); + ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false); switch (hparams.n_layer) { case 27: type = LLM_TYPE_16B; break; @@ -1446,6 +1493,23 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_EXAONE4: + { + if (hparams.n_layer == 64) { // 32B + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + hparams.n_swa = 4096; + hparams.set_swa_pattern(4); + } + + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); + 
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 30: type = LLM_TYPE_1_2B; break; + case 64: type = LLM_TYPE_32B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: { @@ -1483,7 +1547,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_TOKEN_SHIFT_COUNT, hparams.token_shift_count, false); switch (hparams.n_layer) { - case 12: type = LLM_TYPE_190M; break; + case 12: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_190M; break; + default: type = LLM_TYPE_UNKNOWN; + } break; case 24: switch (hparams.n_embd) { case 1024: type = LLM_TYPE_450M; break; @@ -1496,7 +1564,17 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 3584: type = LLM_TYPE_7B; break; default: type = LLM_TYPE_UNKNOWN; } break; - case 32: type = LLM_TYPE_2_9B; break; // RWKV-7-World + case 32: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_2_9B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 61: + switch (hparams.n_embd) { + case 4096: type = LLM_TYPE_14B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -1607,10 +1685,20 @@ void llama_model::load_hparams(llama_model_loader & ml) { } } break; case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_ERNIE4_5_MOE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + if (arch == LLM_ARCH_ERNIE4_5_MOE) { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_INTERLEAVE_MOE_LAYER_STEP, hparams.n_moe_layer_step); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + } + switch (hparams.n_layer) { case 18: type = LLM_TYPE_0_3B; break; + case 28: type = LLM_TYPE_21B_A3B; break; + case 54: type = LLM_TYPE_300B_A47B; break; default: type = LLM_TYPE_UNKNOWN; } } break; @@ -2643,12 +2731,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_QWEN2: case LLM_ARCH_QWEN2VL: + case LLM_ARCH_DREAM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_vocab}, TENSOR_NOT_REQUIRED); // if output is NULL, init from the input tok embed if (output == NULL) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); @@ -2938,6 +3028,73 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_PLAMO2: + { + const uint32_t d_conv = hparams.ssm_d_conv; + const uint32_t d_state = hparams.ssm_d_state; + const uint32_t num_heads = hparams.ssm_dt_rank; + const uint32_t intermediate_size = hparams.ssm_d_inner; + const uint32_t head_dim = intermediate_size / num_heads; + const uint32_t qk_dim = head_dim; + const uint32_t v_dim = head_dim; + const int64_t num_attention_heads = hparams.n_head(); + const int64_t q_num_heads = num_attention_heads; + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = 
create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + bool is_mamba_layer = hparams.is_recurrent(i); + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + if (is_mamba_layer) { + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, 2 * intermediate_size}, 0); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, intermediate_size}, 0); + + layer.ssm_x = create_tensor(tn(LLM_TENSOR_SSM_X, "weight", i), {intermediate_size, dt_dim + 2*d_state}, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "weight", i), {dt_dim, num_heads}, 0); + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {num_heads}, 0); + + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {num_heads}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {num_heads}, 0); + + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {intermediate_size, n_embd}, 0); + + layer.ssm_dt_norm = create_tensor(tn(LLM_TENSOR_SSM_DT_NORM, i), {dt_dim}, 0); + layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); + layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); + } else { + const int64_t num_key_value_heads = hparams.n_head_kv(i); + const int64_t k_num_heads = num_key_value_heads; + const int64_t v_num_heads = num_key_value_heads; + const int64_t q_proj_dim = q_num_heads * qk_dim; + const int64_t k_proj_dim = k_num_heads * qk_dim; + const int64_t v_proj_dim = v_num_heads * v_dim; + + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); + } + + // All layers have post-attention norm, FFN norm, and FFN tensors + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, i), {n_embd}, 0); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, i), {n_embd}, 0); + } + } break; case LLM_ARCH_GPT2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4232,6 +4389,39 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); } } break; + case LLM_ARCH_EXAONE4: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = 
create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + } + } break; case LLM_ARCH_RWKV6: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4747,6 +4937,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_ERNIE4_5_MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -4775,9 +4966,27 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + + if (arch == LLM_ARCH_ERNIE4_5_MOE && static_cast(i) >= hparams.n_layer_dense_lead) { // MoE layers + int n_ff_exp = hparams.n_ff_exp; + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0); + + // Shared expert (if present) + if (hparams.n_ff_shexp > 0) { + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp}, 0); + } + } else { // Dense layers + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { 
n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } } } break; case LLM_ARCH_FALCON_H1: @@ -5209,6 +5418,7 @@ void llama_model::print_info() const { arch == LLM_ARCH_MAMBA2 || arch == LLM_ARCH_JAMBA || arch == LLM_ARCH_FALCON_H1 || + arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); @@ -5381,7 +5591,7 @@ ggml_tensor * llama_model::get_rope_factors(const llama_cparams & cparams, int i } struct llm_build_llama : public llm_graph_context { - llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_llama(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5457,7 +5667,7 @@ struct llm_build_llama : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -5537,7 +5747,7 @@ struct llm_build_llama : public llm_graph_context { }; struct llm_build_llama_iswa : public llm_graph_context { - llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5631,7 +5841,7 @@ struct llm_build_llama_iswa : public llm_graph_context { cb(Kcur, "Kcur_normed", il); } - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -5720,7 +5930,7 @@ struct llm_build_llama_iswa : public llm_graph_context { }; struct llm_build_deci : public llm_graph_context { - llm_build_deci(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_deci(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5808,7 +6018,7 @@ struct llm_build_deci : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } @@ -5876,7 +6086,7 @@ struct llm_build_deci : public llm_graph_context { }; struct llm_build_baichuan : public llm_graph_context { - llm_build_baichuan(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_baichuan(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -5940,7 +6150,7 @@ struct llm_build_baichuan : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), 
il); } @@ -5998,7 +6208,7 @@ struct llm_build_baichuan : public llm_graph_context { }; struct llm_build_xverse : public llm_graph_context { - llm_build_xverse(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_xverse(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6055,7 +6265,7 @@ struct llm_build_xverse : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6111,7 +6321,7 @@ struct llm_build_xverse : public llm_graph_context { }; struct llm_build_falcon : public llm_graph_context { - llm_build_falcon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_falcon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -6178,7 +6388,7 @@ struct llm_build_falcon : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6233,7 +6443,7 @@ struct llm_build_falcon : public llm_graph_context { }; struct llm_build_grok : public llm_graph_context { - llm_build_grok(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_grok(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6308,7 +6518,7 @@ struct llm_build_grok : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } @@ -6395,7 +6605,7 @@ struct llm_build_grok : public llm_graph_context { }; struct llm_build_dbrx : public llm_graph_context { - llm_build_dbrx(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_dbrx(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -6457,7 +6667,7 @@ struct llm_build_dbrx : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6520,7 +6730,7 @@ struct llm_build_dbrx : public llm_graph_context { }; struct llm_build_starcoder : public llm_graph_context { - llm_build_starcoder(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_starcoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -6571,7 +6781,7 @@ struct llm_build_starcoder : public llm_graph_context { cb(Kcur, 
"Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6629,7 +6839,7 @@ struct llm_build_starcoder : public llm_graph_context { }; struct llm_build_refact : public llm_graph_context { - llm_build_refact(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_refact(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -6670,7 +6880,7 @@ struct llm_build_refact : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -6728,7 +6938,7 @@ struct llm_build_refact : public llm_graph_context { }; struct llm_build_bert : public llm_graph_context { - llm_build_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -6827,7 +7037,7 @@ struct llm_build_bert : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); @@ -6914,7 +7124,7 @@ struct llm_build_bert : public llm_graph_context { }; struct llm_build_neo_bert : public llm_graph_context { - llm_build_neo_bert(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_neo_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -6972,7 +7182,7 @@ struct llm_build_neo_bert : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "kqv_out", il); @@ -7024,7 +7234,7 @@ struct llm_build_neo_bert : public llm_graph_context { }; struct llm_build_bloom : public llm_graph_context { - llm_build_bloom(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_bloom(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -7072,7 +7282,7 @@ struct llm_build_bloom : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7130,7 +7340,7 @@ struct llm_build_bloom : public llm_graph_context { }; struct llm_build_mpt : public llm_graph_context { - llm_build_mpt(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + 
llm_build_mpt(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -7219,7 +7429,7 @@ struct llm_build_mpt : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7278,7 +7488,7 @@ struct llm_build_mpt : public llm_graph_context { }; struct llm_build_stablelm : public llm_graph_context { - llm_build_stablelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_stablelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7365,7 +7575,7 @@ struct llm_build_stablelm : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7430,7 +7640,7 @@ struct llm_build_stablelm : public llm_graph_context { }; struct llm_build_qwen : public llm_graph_context { - llm_build_qwen(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_qwen(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7486,7 +7696,7 @@ struct llm_build_qwen : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7544,7 +7754,7 @@ struct llm_build_qwen : public llm_graph_context { }; struct llm_build_qwen2 : public llm_graph_context { - llm_build_qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_qwen2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7606,7 +7816,7 @@ struct llm_build_qwen2 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7654,6 +7864,113 @@ struct llm_build_qwen2 : public llm_graph_context { // lm_head cur = build_lora_mm(model.output, cur); + if (model.output_b != nullptr) { + cur = ggml_add(ctx0, cur, model.output_b); + } + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_dream : public llm_graph_context { + llm_build_dream(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + //copied from qwen2 + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + 
ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, + nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); res->t_logits = cur; @@ -7662,7 +7979,7 @@ struct llm_build_qwen2 : public llm_graph_context { }; struct llm_build_qwen2vl : public llm_graph_context { - llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7727,7 +8044,7 @@ struct llm_build_qwen2vl : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7783,7 +8100,7 @@ struct llm_build_qwen2vl : public llm_graph_context { }; struct llm_build_qwen2moe : public llm_graph_context { - llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : 
llm_graph_context(params) { + llm_build_qwen2moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -7854,7 +8171,7 @@ struct llm_build_qwen2moe : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -7942,7 +8259,7 @@ struct llm_build_qwen2moe : public llm_graph_context { }; struct llm_build_qwen3 : public llm_graph_context { - llm_build_qwen3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_qwen3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8007,7 +8324,7 @@ struct llm_build_qwen3 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -8063,7 +8380,7 @@ struct llm_build_qwen3 : public llm_graph_context { }; struct llm_build_qwen3moe : public llm_graph_context { - llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_qwen3moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8128,7 +8445,7 @@ struct llm_build_qwen3moe : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -8191,7 +8508,7 @@ struct llm_build_qwen3moe : public llm_graph_context { }; struct llm_build_phi2 : public llm_graph_context { - llm_build_phi2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_phi2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -8268,7 +8585,7 @@ struct llm_build_phi2 : public llm_graph_context { // ref: https://github.com/ml-explore/mlx-examples/blob/08e862336ade809bc37d1035f94b359e7d1a5152/phi2/phi2.py#L64-L66 Qcur = ggml_scale(ctx0, Qcur, 1.0f/sqrtf(float(n_embd_head))); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } @@ -8322,7 +8639,7 @@ struct llm_build_phi2 : public llm_graph_context { template<bool iswa> struct llm_build_phi3 : public llm_graph_context { - llm_build_phi3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_phi3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -8405,7 +8722,7 @@ struct llm_build_phi3 : public llm_graph_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f /
sqrtf(float(n_embd_head))); cb(Qcur, "Qcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } @@ -8480,7 +8797,7 @@ struct llm_build_phi3 : public llm_graph_context { }; struct llm_build_plamo : public llm_graph_context { - llm_build_plamo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_plamo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8539,7 +8856,7 @@ struct llm_build_plamo : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -8595,7 +8912,7 @@ struct llm_build_plamo : public llm_graph_context { }; struct llm_build_gpt2 : public llm_graph_context { - llm_build_gpt2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_gpt2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -8647,7 +8964,7 @@ struct llm_build_gpt2 : public llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -8705,7 +9022,7 @@ struct llm_build_gpt2 : public llm_graph_context { }; struct llm_build_codeshell : public llm_graph_context { - llm_build_codeshell(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_codeshell(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -8761,7 +9078,7 @@ struct llm_build_codeshell : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -8819,7 +9136,7 @@ struct llm_build_codeshell : public llm_graph_context { }; struct llm_build_orion : public llm_graph_context { - llm_build_orion(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_orion(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -8890,7 +9207,7 @@ struct llm_build_orion : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -8946,7 +9263,7 @@ struct llm_build_orion : public llm_graph_context { }; struct llm_build_internlm2 : public llm_graph_context { - llm_build_internlm2(const llama_model & model, const llm_graph_params & params, 
ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_internlm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -9017,7 +9334,7 @@ struct llm_build_internlm2 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -9073,7 +9390,7 @@ struct llm_build_internlm2 : public llm_graph_context { }; struct llm_build_minicpm3 : public llm_graph_context { - llm_build_minicpm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_minicpm3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { //TODO: if the model varies, these parameters need to be read from the model const int64_t n_embd_base = 256; const float scale_embd = 12.0f; @@ -9205,7 +9522,7 @@ struct llm_build_minicpm3 : public llm_graph_context { ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(k_states, "k_states", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } @@ -9277,7 +9594,7 @@ struct llm_build_minicpm3 : public llm_graph_context { }; struct llm_build_gemma : public llm_graph_context { - llm_build_gemma(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_gemma(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; @@ -9335,7 +9652,7 @@ struct llm_build_gemma : public llm_graph_context { Qcur = ggml_scale(ctx0, Qcur, 1.0f / sqrtf(float(n_embd_head))); cb(Qcur, "Qcur_scaled", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } @@ -9393,7 +9710,7 @@ struct llm_build_gemma : public llm_graph_context { }; struct llm_build_gemma2_iswa : public llm_graph_context { - llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_gemma2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -9450,7 +9767,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context { Qcur = ggml_scale(ctx0, Qcur, hparams.f_attention_scale); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } @@ -9523,7 +9840,7 @@ struct llm_build_gemma2_iswa : public llm_graph_context { }; struct llm_build_gemma3_iswa : public llm_graph_context { - llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_gemma3_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; ggml_tensor * cur; @@ -9592,7 +9909,7 @@ struct llm_build_gemma3_iswa : public llm_graph_context { // ref: https://github.com/google/gemma_pytorch/blob/014acb7ac4563a5f77c76d7ff98f31b568c16508/gemma/model.py#L315 Qcur = 
ggml_scale(ctx0, Qcur, hparams.f_attention_scale); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); } @@ -9661,7 +9978,6 @@ struct llm_build_gemma3_iswa : public llm_graph_context { struct llm_build_gemma3n_iswa : public llm_graph_context { const llama_model & model; - ggml_cgraph * gf; const int64_t n_embd_head; const int64_t n_embd_altup; @@ -9671,10 +9987,9 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { const int n_layer_sparsity = 10; // number of layers using activation sparsity const float f_sparsity_std_mul = 1.6448533535003662f; // std_multiplier = normal_dist.icdf(0.95) - llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) + llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), - gf(gf), n_embd_head(model.hparams.n_embd_head_k), n_embd_altup(model.hparams.n_embd_altup), n_altup(model.hparams.n_altup), @@ -9775,7 +10090,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { cb(Qcur, "Qcur_pos", il); cb(Kcur, "Kcur_pos", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, hparams.f_attention_scale, il); } else { @@ -9793,7 +10108,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur_pos", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, nullptr, nullptr, nullptr, nullptr, hparams.f_attention_scale, il); } @@ -10087,7 +10402,7 @@ struct llm_build_gemma3n_iswa : public llm_graph_context { // TODO: move up next to build_starcoder struct llm_build_starcoder2 : public llm_graph_context { - llm_build_starcoder2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_starcoder2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -10158,7 +10473,7 @@ struct llm_build_starcoder2 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -10219,7 +10534,6 @@ struct llm_graph_context_mamba : public llm_graph_context { ggml_tensor * build_mamba_layer( llm_graph_input_rs * inp, - ggml_cgraph * gf, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, @@ -10244,13 +10558,13 @@ struct llm_graph_context_mamba : public llm_graph_context { const int64_t n_seq_tokens = ubatch.n_seq_tokens; GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -10331,7 +10645,7 @@ struct llm_graph_context_mamba : public llm_graph_context { return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, 
ids); }; - ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -10358,11 +10672,10 @@ struct llm_graph_context_mamba : public llm_graph_context { ggml_tensor * build_mamba2_layer( llm_graph_input_rs * inp, - ggml_cgraph * gf, - ggml_tensor * cur, - const llama_model & model, - const llama_ubatch & ubatch, - int il) const { + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) const { const auto * mctx_cur = inp->mctx; @@ -10379,13 +10692,13 @@ struct llm_graph_context_mamba : public llm_graph_context { const int64_t n_seq_tokens = ubatch.n_seq_tokens; GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.equal_seqs()); GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs); + ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} @@ -10455,7 +10768,7 @@ struct llm_graph_context_mamba : public llm_graph_context { return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); }; - ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); // store last states ggml_build_forward_expand(gf, @@ -10491,7 +10804,7 @@ struct llm_graph_context_mamba : public llm_graph_context { }; struct llm_build_mamba : public llm_graph_context_mamba { - llm_build_mamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -10510,9 +10823,9 @@ struct llm_build_mamba : public llm_graph_context_mamba { cb(cur, "attn_norm", il); if (model.arch == LLM_ARCH_MAMBA2) { - cur = build_mamba2_layer(rs_inp, gf, cur, model, ubatch, il); + cur = build_mamba2_layer(rs_inp, cur, model, ubatch, il); } else { - cur = build_mamba_layer(rs_inp, gf, cur, model, ubatch, il); + cur = build_mamba_layer(rs_inp, cur, model, ubatch, il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10548,7 +10861,7 @@ struct llm_build_mamba : public llm_graph_context_mamba { }; struct llm_build_jamba : public llm_graph_context_mamba { - llm_build_jamba(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; @@ -10568,7 +10881,7 @@ struct llm_build_jamba : public llm_graph_context_mamba { cb(cur, "attn_norm", il); if (n_head_kv == 0) { - cur = build_mamba_layer(inp_hybrid->get_recr(), gf, cur, model, ubatch, il); + cur = build_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il); } else { // Attention @@ -10589,7 +10902,7 @@ struct llm_build_jamba : public llm_graph_context_mamba { cb(Vcur, "Vcur", il); // No RoPE :) - cur = 
build_attn(inp_hybrid->get_attn(), gf, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); + cur = build_attn(inp_hybrid->get_attn(), model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head)), il); } if (il == n_layer - 1 && inp_out_ids) { @@ -10657,7 +10970,7 @@ struct llm_build_jamba : public llm_graph_context_mamba { }; struct llm_build_command_r : public llm_graph_context { - llm_build_command_r(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_command_r(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -10745,7 +11058,7 @@ struct llm_build_command_r : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -10804,7 +11117,7 @@ struct llm_build_command_r : public llm_graph_context { }; struct llm_build_cohere2_iswa : public llm_graph_context { - llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_cohere2_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -10880,7 +11193,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -10940,7 +11253,7 @@ struct llm_build_cohere2_iswa : public llm_graph_context { // * removed bias // * removed MoE struct llm_build_olmo : public llm_graph_context { - llm_build_olmo(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_olmo(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11011,7 +11324,7 @@ struct llm_build_olmo : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11068,7 +11381,7 @@ struct llm_build_olmo : public llm_graph_context { }; struct llm_build_olmo2 : public llm_graph_context { - llm_build_olmo2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_olmo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11131,7 +11444,7 @@ struct llm_build_olmo2 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11197,7 +11510,7 @@ struct llm_build_olmo2 : public llm_graph_context { // * removed bias // * added q, k norm 
struct llm_build_olmoe : public llm_graph_context { - llm_build_olmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_olmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11264,7 +11577,7 @@ struct llm_build_olmoe : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11325,7 +11638,7 @@ struct llm_build_olmoe : public llm_graph_context { }; struct llm_build_openelm : public llm_graph_context { - llm_build_openelm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_openelm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11397,7 +11710,7 @@ struct llm_build_openelm : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Qcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11454,7 +11767,7 @@ struct llm_build_openelm : public llm_graph_context { }; struct llm_build_gptneox : public llm_graph_context { - llm_build_gptneox(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_gptneox(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -11509,7 +11822,7 @@ struct llm_build_gptneox : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11600,7 +11913,7 @@ struct llm_build_gptneox : public llm_graph_context { }; struct llm_build_arctic : public llm_graph_context { - llm_build_arctic(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_arctic(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11659,7 +11972,7 @@ struct llm_build_arctic : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -11738,7 +12051,7 @@ struct llm_build_arctic : public llm_graph_context { }; struct llm_build_deepseek : public llm_graph_context { - llm_build_deepseek(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_deepseek(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -11814,7 +12127,7 @@ struct llm_build_deepseek : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, 
"Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } @@ -11900,7 +12213,7 @@ struct llm_build_deepseek : public llm_graph_context { }; struct llm_build_deepseek2 : public llm_graph_context { - llm_build_deepseek2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_deepseek2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { bool is_lite = (hparams.n_layer == 27); const bool is_mla = (hparams.n_embd_head_k_mla != 0 && hparams.n_embd_head_v_mla != 0); @@ -12042,7 +12355,7 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(Vcur, "Vcur", il); // note: MLA with the absorption optimzation converts into MQA (ie: GQA with 1 group) - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, model.layers[il].wv_b, kq_scale, il); } else { @@ -12076,7 +12389,7 @@ struct llm_build_deepseek2 : public llm_graph_context { cb(Kcur, "Kcur", il); // note: MLA without the absorption optimization converts into MHA (ie: GQA with full n_head groups) - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); } @@ -12163,7 +12476,7 @@ struct llm_build_deepseek2 : public llm_graph_context { }; struct llm_build_bitnet : public llm_graph_context { - llm_build_bitnet(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_bitnet(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -12243,7 +12556,7 @@ struct llm_build_bitnet : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, NULL, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); @@ -12323,7 +12636,7 @@ struct llm_build_bitnet : public llm_graph_context { }; struct llm_build_t5_enc : public llm_graph_context { - llm_build_t5_enc(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_t5_enc(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -12366,7 +12679,7 @@ struct llm_build_t5_enc : public llm_graph_context { ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b_enc ? 
model.layers[il].attn_rel_b_enc : model.layers[0].attn_rel_b_enc; ggml_tensor * kq_b = build_pos_bias(pos_bucket_enc, attn_rel_b); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo_enc, nullptr, Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); @@ -12424,7 +12737,7 @@ struct llm_build_t5_enc : public llm_graph_context { }; struct llm_build_t5_dec : public llm_graph_context { - llm_build_t5_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_t5_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; //const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -12472,7 +12785,7 @@ struct llm_build_t5_dec : public llm_graph_context { ggml_tensor * attn_rel_b = model.layers[il].attn_rel_b ? model.layers[il].attn_rel_b : model.layers[0].attn_rel_b; ggml_tensor * kq_b = build_pos_bias(pos_bucket_dec, attn_rel_b); - cur = build_attn(inp_attn_self, gf, + cur = build_attn(inp_attn_self, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, kq_b, nullptr, 1.0f, il); cb(cur, "kqv_out", il); @@ -12504,7 +12817,7 @@ struct llm_build_t5_dec : public llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_outputs_enc); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_outputs_enc); - cur = build_attn(inp_attn_cross, gf, + cur = build_attn(inp_attn_cross, model.layers[il].wo_cross, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f, il); cb(cur, "kqv_out", il); @@ -12594,7 +12907,7 @@ struct llm_build_t5_dec : public llm_graph_context { }; struct llm_build_jais : public llm_graph_context { - llm_build_jais(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_jais(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -12636,7 +12949,7 @@ struct llm_build_jais : public llm_graph_context { Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/float(n_embd_head), il); } @@ -12689,7 +13002,7 @@ struct llm_build_jais : public llm_graph_context { }; struct llm_build_chatglm : public llm_graph_context { - llm_build_chatglm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_chatglm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -12768,7 +13081,7 @@ struct llm_build_chatglm : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -12822,7 +13135,7 @@ struct llm_build_chatglm : public llm_graph_context { }; struct llm_build_glm4 : public llm_graph_context { - llm_build_glm4(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_glm4(const llama_model & model, const llm_graph_params & 
params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -12901,7 +13214,7 @@ struct llm_build_glm4 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -12973,7 +13286,7 @@ struct llm_build_glm4 : public llm_graph_context { }; struct llm_build_nemotron : public llm_graph_context { - llm_build_nemotron(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_nemotron(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13045,7 +13358,7 @@ struct llm_build_nemotron : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -13102,7 +13415,7 @@ struct llm_build_nemotron : public llm_graph_context { }; struct llm_build_exaone : public llm_graph_context { - llm_build_exaone(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -13176,7 +13489,7 @@ struct llm_build_exaone : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -13232,32 +13545,168 @@ struct llm_build_exaone : public llm_graph_context { } }; -struct llm_build_rwkv6_base : public llm_graph_context { - const llama_model & model; +template<bool iswa> +struct llm_build_exaone4 : public llm_graph_context { + llm_build_exaone4(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_k; - llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) { - } + GGML_ASSERT(n_embd_head == hparams.n_embd_head_v); + GGML_ASSERT(n_embd_head == hparams.n_rot); - ggml_tensor * build_rwkv6_channel_mix( - const llama_layer * layer, - ggml_tensor * cur, - ggml_tensor * x_prev, - llm_arch arch) const { - ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); - switch (arch) { - case LLM_ARCH_RWKV6: - { - ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur); - ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur); + ggml_tensor * cur; + ggml_tensor * inpL; - ggml_tensor * r = ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr)); - ggml_tensor * k = ggml_sqr( - ctx0, - ggml_relu( - ctx0, - build_lora_mm(layer->channel_mix_key, xk) - ) - ); + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + using inp_attn_type = std::conditional_t<iswa, llm_graph_input_attn_kv_unified_iswa, llm_graph_input_attn_kv_unified>; + inp_attn_type * inp_attn = nullptr; + + if constexpr (iswa) { + inp_attn = build_attn_inp_kv_unified_iswa(); + } else { +
inp_attn = build_attn_inp_kv_unified(); + } + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // use RoPE for SWA layers or non-SWA models + const bool use_rope = hparams.is_swa(il) || hparams.swa_type == LLAMA_SWA_TYPE_NONE; + + cur = inpL; + + // self-attention + { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + cb(Kcur, "Kcur_normed", il); + + if (use_rope) { + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + } + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + cur = build_norm(cur, + model.layers[il].attn_post_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_ffn(ffn_inp, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = build_norm(cur, + model.layers[il].ffn_post_norm, NULL, + LLM_NORM_RMS, -1); + cb(cur, "ffn_post_norm", -1); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + +struct llm_build_rwkv6_base : public llm_graph_context { + const llama_model & model; + + llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) { + } + + ggml_tensor * build_rwkv6_channel_mix( + const llama_layer * layer, + ggml_tensor * cur, + ggml_tensor * x_prev, + llm_arch arch) const { + ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur); + switch (arch) { + case LLM_ARCH_RWKV6: + { + ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur); + ggml_tensor * xr = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_r), cur); + + ggml_tensor * r = 
ggml_sigmoid(ctx0, build_lora_mm(layer->channel_mix_receptance, xr)); + ggml_tensor * k = ggml_sqr( + ctx0, + ggml_relu( + ctx0, + build_lora_mm(layer->channel_mix_key, xk) + ) + ); cur = ggml_mul(ctx0, r, build_lora_mm(layer->channel_mix_value, k)); } break; default: @@ -13269,7 +13718,6 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * build_rwkv6_time_mix( llm_graph_input_rs * inp, - ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * x_prev, const llama_ubatch & ubatch, @@ -13396,7 +13844,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { } ggml_tensor * wkv_state = build_rs( - inp, gf, mctx_cur->get_s_l(il), + inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs); ggml_tensor * wkv_output; @@ -13442,7 +13890,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { }; struct llm_build_rwkv6 : public llm_build_rwkv6_base { - llm_build_rwkv6(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) { + llm_build_rwkv6(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { GGML_ASSERT(hparams.token_shift_count == 2); ggml_tensor * cur; @@ -13463,7 +13911,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); @@ -13478,7 +13926,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il); + cur = build_rwkv6_time_mix(rs_inp, att_norm, x_prev, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -13543,7 +13991,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { // ref: https://huggingface.co/recursal/QRWKV6-32B-Instruct-Preview-v0.1/blob/main/modeling_rwkv6qwen2.py struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { - llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv6_base(model, params) { + llm_build_rwkv6qwen2(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv6_base(model, params) { GGML_ASSERT(n_embd == hparams.n_embd_r()); ggml_tensor * cur; @@ -13563,7 +14011,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); cb(att_norm, "attn_norm", il); @@ -13575,7 +14023,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(rs_inp, gf, att_norm, x_prev, ubatch, il); + cur = build_rwkv6_time_mix(rs_inp, att_norm, x_prev, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], 
(n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -13665,7 +14113,6 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor * build_rwkv7_time_mix( llm_graph_input_rs * inp, - ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * x_prev, ggml_tensor *& first_layer_value, @@ -13751,7 +14198,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); ggml_tensor * wkv_state = build_rs( - inp, gf, mctx_cur->get_s_l(il), + inp, mctx_cur->get_s_l(il), hparams.n_embd_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -13798,7 +14245,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { }; struct llm_build_rwkv7 : public llm_build_rwkv7_base { - llm_build_rwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) { + llm_build_rwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { GGML_ASSERT(hparams.token_shift_count == 2); ggml_tensor * cur; @@ -13820,7 +14267,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); ggml_tensor * ffn_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], n_embd * ggml_element_size(token_shift)); @@ -13835,7 +14282,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il); + cur = build_rwkv7_time_mix(rs_inp, att_norm, x_prev, v_first, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -13894,7 +14341,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { struct llm_build_arwkv7 : public llm_build_rwkv7_base { - llm_build_arwkv7(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_build_rwkv7_base(model, params) { + llm_build_arwkv7(const llama_model & model, const llm_graph_params & params) : llm_build_rwkv7_base(model, params) { GGML_ASSERT(n_embd == hparams.n_embd_r()); ggml_tensor * cur; @@ -13915,7 +14362,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { const llama_layer * layer = &model.layers[il]; inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); - ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, gf, ubatch, il); + ggml_tensor * token_shift = build_rwkv_token_shift_load(rs_inp, ubatch, il); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); cb(att_norm, "attn_norm", il); @@ -13927,7 +14374,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(rs_inp, gf, att_norm, x_prev, v_first, ubatch, il); + cur = build_rwkv7_time_mix(rs_inp, att_norm, x_prev, v_first, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, 
build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -13984,8 +14431,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { struct llm_build_granite : public llm_graph_context { llm_build_granite( const llama_model & model, - const llm_graph_params & params, - ggml_cgraph * gf) + const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -14019,7 +14465,7 @@ struct llm_build_granite : public llm_graph_context { // self-attention cur = build_attention_layer( - gf, cur, inp_pos, inp_attn, + cur, inp_pos, inp_attn, model, n_embd_head, il); if (il == n_layer - 1 && inp_out_ids) { @@ -14055,7 +14501,6 @@ struct llm_build_granite : public llm_graph_context { } ggml_tensor * build_attention_layer( - ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv_unified * inp_attn, @@ -14110,7 +14555,7 @@ struct llm_build_granite : public llm_graph_context { cb(Vcur, "Vcur", il); const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -14198,11 +14643,9 @@ struct llm_build_granite : public llm_graph_context { }; struct llm_build_granite_hybrid : public llm_graph_context_mamba { - llm_build_granite_hybrid( const llama_model & model, - const llm_graph_params & params, - ggml_cgraph * gf) : + const llm_graph_params & params) : llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -14234,11 +14677,11 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { if (hparams.is_recurrent(il)) { // ssm layer // - cur = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il); + cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); } else { // attention layer // cur = build_attention_layer( - gf, cur, inp_pos, inp->get_attn(), model, + cur, inp_pos, inp->get_attn(), model, n_embd_head, il); } @@ -14277,7 +14720,6 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } ggml_tensor * build_attention_layer( - ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv_unified * inp_attn, @@ -14332,7 +14774,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { cb(Vcur, "Vcur", il); const float kq_scale = hparams.f_attention_scale == 0.0f ? 
1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale; - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -14426,7 +14868,7 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { // * removed bias // * removed MoE struct llm_build_chameleon : public llm_graph_context { - llm_build_chameleon(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_chameleon(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -14517,7 +14959,7 @@ struct llm_build_chameleon : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, nullptr, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -14603,7 +15045,7 @@ struct llm_build_chameleon : public llm_graph_context { }; struct llm_build_wavtokenizer_dec : public llm_graph_context { - llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_wavtokenizer_dec(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -14755,7 +15197,7 @@ struct llm_build_wavtokenizer_dec : public llm_graph_context { }; struct llm_build_plm : public llm_graph_context { - llm_build_plm(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_plm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const float kq_scale = 1.0f/sqrtf(float(hparams.n_embd_head_k)); const uint32_t n_embd_head_qk_rope = hparams.n_rot; @@ -14873,7 +15315,7 @@ struct llm_build_plm : public llm_graph_context { ggml_tensor * k_states = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(k_states, "k_states", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, q_states, k_states, v_states, nullptr, nullptr, kq_scale, il); } @@ -14927,7 +15369,7 @@ struct llm_build_plm : public llm_graph_context { }; struct llm_build_bailingmoe : public llm_graph_context { - llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_bailingmoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { ggml_tensor * cur; ggml_tensor * inpL; @@ -14996,7 +15438,7 @@ struct llm_build_bailingmoe : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_rot)), il); } @@ -15071,7 +15513,7 @@ struct llm_build_bailingmoe : public llm_graph_context { }; struct llm_build_dots1 : public llm_graph_context { - llm_build_dots1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_dots1(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -15136,7 +15578,7 @@ struct 
llm_build_dots1 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -15221,7 +15663,7 @@ struct llm_build_dots1 : public llm_graph_context { }; struct llm_build_ernie4_5 : public llm_graph_context { - llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_ernie4_5(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -15291,7 +15733,7 @@ struct llm_build_ernie4_5 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); } @@ -15350,8 +15792,178 @@ struct llm_build_ernie4_5 : public llm_graph_context { } }; +struct llm_build_ernie4_5_moe : public llm_graph_context { + llm_build_ernie4_5_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv_unified(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + GGML_ASSERT(hparams.n_moe_layer_step > 0 && "Ernie 4.5 MoE requires n_moe_layer_step > 0"); + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + // norm + { + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + } + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + if (model.layers[il].bq) { + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur", il); + } + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + if (model.layers[il].bk) { + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur", il); + } + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + if (model.layers[il].bv) { + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur", il); + } + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + Qcur, Kcur, Vcur, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + 
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + bool is_moe_layer = static_cast(il) >= hparams.n_layer_dense_lead && (il + 1) % hparams.n_moe_layer_step == 0; + + if (!is_moe_layer) { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + } else { + // MoE branch + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + ggml_tensor * moe_out = build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, + il); + cb(moe_out, "ffn_moe_out", il); + + // Shared expert (if present) + if (hparams.n_ff_shexp > 0) { + ggml_tensor * ffn_shexp = build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + } else { + cur = moe_out; + } + cb(cur, "ffn_out", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + struct llm_build_falcon_h1 : public llm_graph_context_mamba { - llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context_mamba(params) { + llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; @@ -15407,7 +16019,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba { cb(Kcur, "Kcur-post-rope", il); cb(Vcur, "Vcur-post-rope", il); - ggml_tensor * attn_out = build_attn(inp->get_attn(), gf, + ggml_tensor * attn_out = build_attn(inp->get_attn(), model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(attn_out, "attn_out", il); @@ -15418,7 +16030,7 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba { // Mamba2 layer cb(cur, "ssm_in", il); - ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), gf, cur, model, ubatch, il); + ggml_tensor * ssm_out = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il); cb(ssm_out, "ssm_out", il); // // Aggregation @@ -15476,8 +16088,321 @@ struct llm_build_falcon_h1 : public llm_graph_context_mamba { } }; +struct llm_build_plamo2 : public llm_graph_context_mamba { + llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { + ggml_tensor * cur; + ggml_tensor * inpL; + + // {n_embd, n_tokens} + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "embedding_output", 
-1); + + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_hybrid = build_inp_mem_hybrid(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * residual = inpL; + + // ggml_graph_add_node(gf, model.layers[il].attn_norm); + // cb(model.layers[il].attn_norm, "attn_norm", il); + + // pre_mixer_norm + cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + + // check if this layer is Mamba or Attention + bool is_mamba_layer = hparams.is_recurrent(il); + + if (is_mamba_layer) { + // PLaMo-2 Mamba layer + cur = build_plamo2_mamba_layer(inp_hybrid->get_recr(), cur, model, ubatch, il); + } else { + // PLaMo-2 Attention layer + cur = build_plamo2_attn_layer(inp_hybrid->get_attn(), inp_pos, cur, model, il); + } + + // post_mixer_norm + cur = build_norm(cur, model.layers[il].attn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "attn_post_norm", il); + + // residual connection + cur = ggml_add(ctx0, cur, residual); + cb(cur, "attn_residual", il); + residual = cur; + + // pre-ffn norm + cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_pre_norm", il); + + // feed-forward network + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + NULL, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SWIGLU, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // post ffn norm + cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); + cb(cur, "ffn_post_norm", il); + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + residual = ggml_get_rows(ctx0, residual, inp_out_ids); + } + + // residual connection + cur = ggml_add(ctx0, cur, residual); + cb(cur, "ffn_residual", il); + + inpL = cur; + } + + cur = inpL; + + // final norm + cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); + cb(cur, "result_norm", -1); + + // lm_head + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + + // Explicitly mark as output tensor to ensure proper backend assignment + ggml_set_output(cur); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } + +private: + ggml_tensor * build_plamo2_attn_layer( + llm_graph_input_attn_kv_unified * inp, + ggml_tensor * inp_pos, + ggml_tensor * cur, + const llama_model & model, + int il) { + + // self-attention + { + // PLaMo-2 uses combined QKV tensor + ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur); + cb(qkv, "wqkv", il); + + // split QKV tensor into Q, K, V + const int64_t n_embd_head_q = hparams.n_embd_head_k; + const int64_t n_embd_head_k = hparams.n_embd_head_k; + const int64_t n_embd_head_v = hparams.n_embd_head_v; + int32_t n_head_kv = hparams.n_head_kv(il); + + const int64_t q_offset = 0; + const int64_t k_offset = n_embd_head_q * n_head; + const int64_t v_offset = k_offset + n_embd_head_k * n_head_kv; + + ggml_tensor * Qcur = ggml_view_3d(ctx0, qkv, n_embd_head_q, n_head, n_tokens, n_embd_head_q * sizeof(float), qkv->nb[1], q_offset * ggml_element_size(qkv)); + ggml_tensor * Kcur = ggml_view_3d(ctx0, qkv, n_embd_head_k, n_head_kv, n_tokens, n_embd_head_k * sizeof(float), qkv->nb[1], k_offset * ggml_element_size(qkv)); + ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, qkv, n_embd_head_v * n_head_kv, n_tokens, qkv->nb[1], v_offset * ggml_element_size(qkv))); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head_v, n_head_kv, n_tokens); + + Qcur = 
build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il); + } + + cb(cur, "attn_out", il); + + return cur; + } + + ggml_tensor * build_plamo2_mamba_layer( + llm_graph_input_rs * inp, + ggml_tensor * cur, + const llama_model & model, + const llama_ubatch & ubatch, + int il) { + + const auto * mctx_cur = inp->mctx; + + const auto kv_head = mctx_cur->get_head(); + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_heads = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_heads; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + ggml_tensor * conv = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // in_proj: {n_embd, 2*d_inner} @ {n_embd, n_seq_tokens, n_seqs} => {2*d_inner, n_seq_tokens, n_seqs} + ggml_tensor * zx = build_lora_mm(model.layers[il].ssm_in, cur); + cb(zx, "mamba_in_proj", il); + // {8192, 5, 1, 1} -> {8192, 1, 5, 1} + zx = ggml_permute(ctx0, zx, 0, 2, 1, 3); + zx = ggml_cont(ctx0, zx); + zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs); + cb(zx, "mamba_in_proj_out", il); + + // split into z and x + // => {head_dim * n_heads, n_seq_tokens, n_seqs} + ggml_tensor * x = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], head_dim*ggml_element_size(zx)); + x = ggml_cont(ctx0, x); + x = ggml_reshape_3d(ctx0, x, head_dim * n_heads, n_seq_tokens, n_seqs); + // x = ggml_permute(ctx0, x, 0, 2, 1, 3); + cb(x, "mamba_x_split", il); + + ggml_tensor * z = ggml_view_4d(ctx0, zx, head_dim, n_heads, n_seq_tokens, n_seqs, zx->nb[1], zx->nb[2], zx->nb[3], 0); + cb(z, "mamba_z_split", il); + + // conv1d + { + // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0); + cb(conv_x, "mamba_conv1d_input", il); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner, n_seqs, + conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); + cb(conv_states_all, 
"mamba_conv1d_state", il); + + // 1D convolution + x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + cb(x, "mamba_conv1d", il); + + x = ggml_silu(ctx0, x); + cb(x, "mamba_conv1d_silu", il); + } + + // SSM + { + // bcdt_proj: {d_inner, dt_rank + 2*d_state} @ {d_inner, n_seq_tokens, n_seqs} => {dt_rank + 2*d_state, n_seq_tokens, n_seqs} + ggml_tensor * x_bcdt = build_lora_mm(model.layers[il].ssm_x, x); + cb(x_bcdt, "mamba_bcdt_proj", il); + + // split into dt, B, C + const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + ggml_tensor * B = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], 0); + ggml_tensor * C = ggml_view_3d(ctx0, x_bcdt, d_state, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*d_state); + ggml_tensor * dt = ggml_view_3d(ctx0, x_bcdt, dt_dim, n_seq_tokens, n_seqs, x_bcdt->nb[1], x_bcdt->nb[2], ggml_element_size(x_bcdt)*(2*d_state)); + cb(B, "mamba_B_raw", il); + cb(C, "mamba_C_raw", il); + cb(dt, "mamba_dt_raw", il); + + // Apply RMS norm to dt, B, C (PLaMo-2 specific) + B = build_norm(B, model.layers[il].ssm_b_norm, NULL, LLM_NORM_RMS, il); + C = build_norm(C, model.layers[il].ssm_c_norm, NULL, LLM_NORM_RMS, il); + dt = build_norm(dt, model.layers[il].ssm_dt_norm, NULL, LLM_NORM_RMS, il); + cb(B, "mamba_B_normed", il); + cb(C, "mamba_C_normed", il); + cb(dt, "mamba_dt_normed", il); + + // dt_proj: {dt_rank, d_inner} @ {dt_rank, n_seq_tokens, n_seqs} => {d_inner, n_seq_tokens, n_seqs} + dt = build_lora_mm(model.layers[il].ssm_dt, dt); + dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + cb(dt, "mamba_dt_proj", il); + + ggml_tensor * A = ggml_reshape_2d(ctx0, model.layers[il].ssm_a, 1, n_heads); + cb(A, "mamba_A", il); + + x = ggml_view_4d(ctx0, x, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); + B = ggml_view_4d(ctx0, B, d_state, 1, n_seq_tokens, n_seqs, d_state * B->nb[0], B->nb[1], B->nb[2], 0); + C = ggml_view_4d(ctx0, C, d_state, 1, n_seq_tokens, n_seqs, d_state * C->nb[0], C->nb[1], C->nb[2], 0); + + // use the states and the indices provided by build_recurrent_state + // (this is necessary in order to properly use the states before they are overwritten, + // while avoiding to make unnecessary copies of the states) + auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) { + ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_heads, mctx_cur->get_size()); + + // Custom operator to optimize the parallel associative scan + // as described in the Annex D of the Mamba paper. 
+ // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids); + }; + + ggml_tensor * y_ssm = build_rs(inp, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows); + cb(y_ssm, "mamba_ssm_scan", il); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, n_heads*head_dim*d_state*n_seqs, n_heads*head_dim*n_seq_tokens*n_seqs*ggml_element_size(y_ssm)), + ggml_view_1d(ctx0, ssm_states_all, n_heads*head_dim*d_state*n_seqs, kv_head*n_seqs*n_heads*head_dim*d_state*ggml_element_size(ssm_states_all)))); + cb(ssm_states_all, "mamba_ssm_states", il); + + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0); + cb(y, "mamba_y_view", il); + + // Add D parameter and apply gating with z + // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} + ggml_tensor * D = ggml_reshape_2d(ctx0, model.layers[il].ssm_d, 1, n_heads); + y = ggml_add(ctx0, y, ggml_mul(ctx0, x, D)); + cb(y, "mamba_y_add_d", il); + + y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); + cb(y, "mamba_y_swiglu_z", il); + + // out_proj: {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + y = ggml_view_3d(ctx0, y, head_dim * n_heads, n_seq_tokens, n_seqs, y->nb[2], y->nb[3], 0); + cur = build_lora_mm(model.layers[il].ssm_out, y); + cb(cur, "mamba_out_proj", il); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + cb(cur, "mamba_out", il); + + return cur; + } +}; + struct llm_build_arcee : public llm_graph_context { - llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_arcee(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -15553,7 +16478,7 @@ struct llm_build_arcee : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -15612,7 +16537,7 @@ struct llm_build_arcee : public llm_graph_context { }; struct llm_build_hunyuan_moe : public llm_graph_context { - llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_hunyuan_moe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -15698,7 +16623,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context { LLM_NORM_RMS, il); cb(Qcur, "Qcur_norm", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -15773,7 +16698,7 @@ struct llm_build_hunyuan_moe : public llm_graph_context { }; struct llm_build_smollm3 : public llm_graph_context { - llm_build_smollm3(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { + llm_build_smollm3(const llama_model 
& model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -15850,7 +16775,7 @@ struct llm_build_smollm3 : public llm_graph_context { cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - cur = build_attn(inp_attn, gf, + cur = build_attn(inp_attn, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il); cb(cur, "attn_out", il); @@ -15912,7 +16837,7 @@ struct llm_build_smollm3 : public llm_graph_context { struct llm_build_lfm2 : public llm_graph_context { const llama_model & model; - llm_build_lfm2(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) { + llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) { ggml_tensor * cur = build_inp_embd(model.tok_embd); cb(cur, "model.embed_tokens", -1); @@ -15927,8 +16852,8 @@ struct llm_build_lfm2 : public llm_graph_context { cb(cur, "model.layers.{}.operator_norm", il); cur = hparams.is_recurrent(il) ? - build_shortconv_block(gf, cur, inp_hybrid->get_recr(), il) : - build_attn_block(gf, cur, inp_pos, inp_hybrid->get_attn(), il) ; + build_shortconv_block(cur, inp_hybrid->get_recr(), il) : + build_attn_block(cur, inp_pos, inp_hybrid->get_attn(), il) ; if (il == n_layer - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); @@ -15971,8 +16896,7 @@ struct llm_build_lfm2 : public llm_graph_context { return cur; } - ggml_tensor * build_attn_block(ggml_cgraph * gf, - ggml_tensor * cur, + ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv_unified * inp_attn, int il) const { @@ -16009,7 +16933,7 @@ struct llm_build_lfm2 : public llm_graph_context { ext_factor, attn_factor, beta_fast, beta_slow ); - cur = build_attn(inp_attn, gf, model.layers[il].wo, NULL, + cur = build_attn(inp_attn, model.layers[il].wo, NULL, q, k, v, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); cb(cur, "model.layers.{}.self_attn.out_proj", il); @@ -16017,11 +16941,22 @@ struct llm_build_lfm2 : public llm_graph_context { return cur; } - ggml_tensor * build_shortconv_block(ggml_cgraph * gf, - ggml_tensor * cur, + ggml_tensor * build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il) { - const auto * mctx_cur = static_cast(mctx)->get_recr(); + const auto * mctx_cur = static_cast(mctx)->get_recr(); + const uint32_t kv_head = mctx_cur->get_head(); + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + GGML_ASSERT(hparams.n_shortconv_l_cache > 1); + const uint32_t d_conv = hparams.n_shortconv_l_cache - 1; + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur); cb(bcx, "model.layers.{}.conv.in_proj", il); @@ -16029,38 +16964,48 @@ struct llm_build_lfm2 : public llm_graph_context { constexpr auto n_chunks = 3; GGML_ASSERT(bcx->ne[0] % n_chunks == 0); auto const chunk_size = bcx->ne[0] / n_chunks; - auto * b = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 0 * chunk_size * ggml_element_size(bcx)); - auto * c = ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 1 * chunk_size * ggml_element_size(bcx)); - auto * x = 
ggml_view_2d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->nb[1], 2 * chunk_size * ggml_element_size(bcx)); + auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 0*chunk_size*ggml_element_size(bcx)); + auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 1*chunk_size*ggml_element_size(bcx)); + auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], 2*chunk_size*ggml_element_size(bcx)); auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x)); - // read conv state directly, with build_rs generation is slower - ggml_tensor * conv_state = mctx_cur->get_r_l(il); - const int64_t n_seqs = ubatch.n_seqs; - ggml_tensor * conv = build_rs(inp_recr, gf, conv_state, hparams.n_embd_r(), n_seqs); - conv = ggml_reshape_3d(ctx0, conv_state, hparams.n_shortconv_l_cache - 1, hparams.n_embd, n_seqs); + // read conv state + auto * conv_state = mctx_cur->get_r_l(il); + auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs); + auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs); bx = ggml_concat(ctx0, conv, bx, 0); GGML_ASSERT(bx->ne[0] > conv->ne[0]); - auto * new_conv = ggml_view_2d(ctx0, bx, conv->ne[0], bx->ne[1], bx->nb[1], (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx)); + // last d_conv columns is a new conv state + auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], (bx->ne[0] - conv->ne[0])*ggml_element_size(bx)); GGML_ASSERT(ggml_are_same_shape(conv, new_conv)); - // write conv state - ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, conv_state)); + // write new conv conv state + ggml_build_forward_expand( + gf, + ggml_cpy( + ctx0, + new_conv, + ggml_view_1d( + ctx0, + conv_state, + ggml_nelements(new_conv), + kv_head*d_conv*n_embd*ggml_element_size(new_conv) + ) + ) + ); auto * conv_kernel = model.layers[il].shortconv.conv; - GGML_ASSERT(hparams.n_shortconv_l_cache > 0); - - // construct ssm_conv op - ggml_tensor * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); + auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); cb(conv_out, "model.layers.{}.conv.conv", il); auto * y = ggml_mul(ctx0, c, conv_out); - y = build_lora_mm(model.layers[il].shortconv.out_proj, y); cb(y, "model.layers.{}.conv.out_proj", il); + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs); return y; } @@ -16078,6 +17023,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: case LLM_ARCH_WAVTOKENIZER_DEC: + case LLM_ARCH_DREAM: { res = nullptr; } break; @@ -16118,7 +17064,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, } else { const auto padding = llama_kv_cache_unified::get_padding(cparams); - cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); + uint32_t n_ctx_per_stream = cparams.n_ctx; + + if (!cparams.kv_unified) { + n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max; + n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); + + cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max; + } else { + n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding); + + cparams.n_ctx = n_ctx_per_stream; + } LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); @@ -16132,7 +17089,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, !cparams.flash_attn, 
cparams.offload_kqv, params.swa_full, - cparams.n_ctx, + cparams.kv_unified, + n_ctx_per_stream, cparams.n_seq_max, cparams.n_ubatch, padding); @@ -16146,7 +17104,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, params.type_v, !cparams.flash_attn, cparams.offload_kqv, - cparams.n_ctx, + cparams.kv_unified, + n_ctx_per_stream, cparams.n_seq_max, padding, hparams.n_swa, @@ -16159,227 +17118,233 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, return res; } -llm_graph_result_ptr llama_model::build_graph( - const llm_graph_params & params, - ggml_cgraph * gf, - llm_graph_type type) const { +ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { std::unique_ptr llm; switch (arch) { case LLM_ARCH_LLAMA: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_LLAMA4: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_DECI: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BAICHUAN: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_FALCON: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GROK: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_STARCODER: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_REFACT: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_NEO_BERT: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BLOOM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_MPT: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_STABLELM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_QWEN: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_QWEN2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; + case LLM_ARCH_DREAM: + { + llm = std::make_unique(*this, params); + } + break; case LLM_ARCH_QWEN2VL: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_QWEN2MOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_QWEN3: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_QWEN3MOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_PHI2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: { if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { - llm = std::make_unique> (*this, params, gf); + llm = std::make_unique> (*this, params); } else { - llm = std::make_unique>(*this, params, 
gf); + llm = std::make_unique>(*this, params); } } break; case LLM_ARCH_PLAMO: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_PLAMO2: + { + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GPT2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_CODESHELL: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_ORION: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_INTERNLM2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_MINICPM3: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GEMMA: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GEMMA2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GEMMA3: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GEMMA3N: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_STARCODER2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_MAMBA: case LLM_ARCH_MAMBA2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_JAMBA: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_XVERSE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_COMMAND_R: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_COHERE2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_DBRX: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_OLMO: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_OLMO2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_OLMOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_OPENELM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GPTNEOX: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_ARCTIC: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_CHATGLM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GLM4: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BITNET: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_T5: { - switch (type) { + switch 
(params.gtype) { case LLM_GRAPH_TYPE_ENCODER: - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); break; case LLM_GRAPH_TYPE_DEFAULT: case LLM_GRAPH_TYPE_DECODER: - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); break; default: GGML_ABORT("invalid graph type"); @@ -16387,99 +17352,111 @@ llm_graph_result_ptr llama_model::build_graph( } break; case LLM_ARCH_T5ENCODER: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_JAIS: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_NEMOTRON: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_EXAONE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_EXAONE4: + { + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_RWKV6: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_RWKV6QWEN2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_RWKV7: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_ARWKV7: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GRANITE: case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_MINICPM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_GRANITE_HYBRID: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_CHAMELEON: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_WAVTOKENIZER_DEC: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_PLM: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_BAILINGMOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_DOTS1: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_ARCEE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_ERNIE4_5: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_ERNIE4_5_MOE: + { + llm = std::make_unique(*this, params); } break; case LLM_ARCH_HUNYUAN_MOE: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_SMOLLM3: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_FALCON_H1: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_LFM2: { - llm = std::make_unique(*this, params, gf); + llm = std::make_unique(*this, params); } break; default: GGML_ABORT("fatal error"); } // add on pooling layer - llm->build_pooling(gf, cls, cls_b, cls_out, cls_out_b); + llm->build_pooling(cls, cls_b, cls_out, cls_out_b); - return std::move(llm->res); + return llm->res->get_gf(); } // @@ -16628,6 +17605,7 @@ 
llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_SMOLLM3: case LLM_ARCH_ARCEE: case LLM_ARCH_ERNIE4_5: + case LLM_ARCH_ERNIE4_5_MOE: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -16642,6 +17620,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: case LLM_ARCH_QWEN2: + case LLM_ARCH_DREAM: case LLM_ARCH_QWEN2MOE: case LLM_ARCH_QWEN3: case LLM_ARCH_QWEN3MOE: @@ -16651,6 +17630,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_PHI3: case LLM_ARCH_PHIMOE: case LLM_ARCH_PLAMO: + case LLM_ARCH_PLAMO2: case LLM_ARCH_GEMMA: case LLM_ARCH_GEMMA2: case LLM_ARCH_GEMMA3: @@ -16662,6 +17642,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ORION: case LLM_ARCH_NEMOTRON: case LLM_ARCH_EXAONE: + case LLM_ARCH_EXAONE4: case LLM_ARCH_MINICPM3: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: diff --git a/examples/talk-llama/llama-model.h b/examples/talk-llama/llama-model.h index 027a7f0c3e2..094e23808a8 100644 --- a/examples/talk-llama/llama-model.h +++ b/examples/talk-llama/llama-model.h @@ -99,8 +99,10 @@ enum llm_type { LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_A13B, + LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_235B_A22B, + LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_E2B, LLM_TYPE_E4B, }; @@ -452,10 +454,7 @@ struct llama_model { llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; // TODO: move this to new llm_arch_model_i interface - llm_graph_result_ptr build_graph( - const llm_graph_params & params, - ggml_cgraph * gf, - llm_graph_type type) const; + ggml_cgraph * build_graph(const llm_graph_params & params) const; private: struct impl; diff --git a/examples/talk-llama/llama-quant.cpp b/examples/talk-llama/llama-quant.cpp index 4dbd1e30991..a00af7a1d17 100644 --- a/examples/talk-llama/llama-quant.cpp +++ b/examples/talk-llama/llama-quant.cpp @@ -884,8 +884,7 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std:: if (std::regex pattern(tname); std::regex_search(tensor_name, pattern)) { if (qtype != new_type) { LLAMA_LOG_DEBUG("(overriding %s) ", ggml_type_name(new_type)); - new_type = qtype; - break; // if two or more types are specified for the tensor, first match wins + new_type = qtype; // if two or more types are specified for the same tensor, the last match wins } } } diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp index e0e578d6394..e8bae645088 100644 --- a/examples/talk-llama/llama-vocab.cpp +++ b/examples/talk-llama/llama-vocab.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -404,6 +405,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_KIMI_K2: + regex_exprs = { + // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp + // The custom handler implements all K2 patterns with proper Han character exclusion + "\\p{Han}+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_SUPERBPE: regex_exprs = { "\\p{N}+", 
@@ -1196,6 +1204,284 @@ struct llm_tokenizer_rwkv_session { const llm_tokenizer_rwkv & tokenizer; }; +struct llm_tokenizer_plamo2 : llm_tokenizer { + llm_tokenizer_plamo2(const llama_vocab & vocab) { + build(vocab); + } + + void build(const llama_vocab & vocab) { + // Reset internal structures + tokens_.clear(); + bytes_.assign(256, 0); + to_suffix_id_.clear(); + table_.clear(); + + // Build token list and byte mapping + std::unordered_map suffix_to_score; + std::unordered_map token_to_id; + + for (size_t token_id = 0; token_id < vocab.n_tokens(); ++token_id) { + const auto & entry = vocab.get_token_data(token_id); + tokens_.push_back(entry.text); + token_to_id[entry.text] = static_cast(token_id); + + // Handle byte tokens + if (vocab.is_byte(token_id)) { + if (entry.text.length() == 6 && entry.text.substr(0, 3) == "<0x" && entry.text.back() == '>') { + std::string hex_str = entry.text.substr(3, 2); + int byte_val = std::stoi(hex_str, nullptr, 16); + bytes_[byte_val] = static_cast(token_id); + } + continue; + } + + // Add token and all its suffixes to suffix_to_score + suffix_to_score[entry.text] = entry.score; + + // Extract suffixes character by character (UTF-8 aware) + std::vector cpts = unicode_cpts_from_utf8(entry.text); + for (size_t i = 1; i < cpts.size(); ++i) { + std::string suffix; + for (size_t j = i; j < cpts.size(); ++j) { + suffix += unicode_cpt_to_utf8(cpts[j]); + } + if (suffix_to_score.find(suffix) == suffix_to_score.end()) { + suffix_to_score[suffix] = std::numeric_limits::quiet_NaN(); + } + } + } + + // Check that all byte tokens are set + for (int i = 0; i < 256; ++i) { + if (bytes_[i] == 0) { + throw std::runtime_error("Byte token for <0x" + std::to_string(i) + "> is not set"); + } + } + + // Build suffix list in lexicographical order of reversed strings + std::vector suffixes; + for (const auto & pair : suffix_to_score) { + suffixes.push_back(pair.first); + } + suffixes.push_back(""); // Empty suffix + + std::sort(suffixes.begin(), suffixes.end(), [](const std::string & a, const std::string & b) { + std::string rev_a(a.rbegin(), a.rend()); + std::string rev_b(b.rbegin(), b.rend()); + return rev_a < rev_b; + }); + + // Build suffix_to_id and to_suffix_id_ + std::unordered_map suffix_to_id; + int32_t num_pieces = 0; + + for (const auto & suffix : suffixes) { + suffix_to_id[suffix] = num_pieces; + if (!suffix.empty()) { + std::vector cpts = unicode_cpts_from_utf8(suffix); + + std::string remaining; + for (size_t i = 1; i < cpts.size(); ++i) { + remaining += unicode_cpt_to_utf8(cpts[i]); + } + + int64_t piece_code = (static_cast(cpts[0]) << 32) | suffix_to_id[remaining]; + to_suffix_id_[piece_code] = num_pieces; + + // Count number of pieces for this suffix + int32_t pieces_for_suffix = 1; // sentinel row + for (int32_t piece_length = static_cast(cpts.size()); piece_length > 0; --piece_length) { + std::string piece; + for (int32_t i = 0; i < piece_length; ++i) { + piece += unicode_cpt_to_utf8(cpts[i]); + } + if (suffix_to_score.find(piece) != suffix_to_score.end()) { + pieces_for_suffix++; + } + } + num_pieces += pieces_for_suffix; + } else { + num_pieces++; // Empty suffix contributes one piece (sentinel row) + } + } + + // Build flattened table + table_.resize(num_pieces, std::vector(4, 0)); + int32_t table_idx = 0; + + for (const auto & suffix : suffixes) { + // Add all prefixes of the suffix to the table (in decreasing order of length) + std::vector cpts = unicode_cpts_from_utf8(suffix); + for (int32_t piece_length = static_cast(cpts.size()); piece_length > 0; 
--piece_length) { + std::string piece; + for (int32_t i = 0; i < piece_length; ++i) { + piece += unicode_cpt_to_utf8(cpts[i]); + } + + auto score_it = suffix_to_score.find(piece); + if (score_it == suffix_to_score.end()) { + continue; + } + + table_[table_idx][TABLE_PIECE_LENGTH] = piece_length; + auto token_it = token_to_id.find(piece); + table_[table_idx][TABLE_TOKEN_ID] = (token_it != token_to_id.end()) ? token_it->second : -1; + + float score = score_it->second; + table_[table_idx][TABLE_SCORE] = std::isfinite(score) ? + static_cast(std::round(score * 1e4)) : INVALID_SCORE; + table_[table_idx][TABLE_PIECE_ID] = suffix_to_id[piece]; + + table_idx++; + } + + // Add sentinel row + table_[table_idx][TABLE_PIECE_LENGTH] = 1; + table_[table_idx][TABLE_TOKEN_ID] = -1; + table_[table_idx][TABLE_SCORE] = UNKNOWN_SCORE; + table_idx++; + } + } + + std::vector encode(const std::string & text) const { + std::vector unicode_data = unicode_cpts_from_utf8(text); + // Skip the first code point if it is a BOM (Byte Order Mark) + if (!unicode_data.empty() && unicode_data[0] == 0xFEFF) { + unicode_data.erase(unicode_data.begin()); + } + + if (unicode_data.empty()) { + return {}; + } + + const size_t data_len = unicode_data.size(); + + // Initialize scores array (dynamic programming) + std::vector scores(data_len + 1, static_cast(1) << 60); + scores[data_len] = 0; + + // Path array to track best tokenization + std::vector> path(data_len + 1, std::vector(3, 0)); + + int32_t suffix_id = 0; + + // Process from end to beginning + for (int i = static_cast(data_len) - 1; i >= 0; --i) { + uint32_t c = unicode_data[i]; + + // Find next suffix ID + for (size_t p = suffix_id; p < table_.size(); ++p) { + int64_t piece_code = (static_cast(c) << 32) | table_[p][TABLE_PIECE_ID]; + auto it = to_suffix_id_.find(piece_code); + suffix_id = (it != to_suffix_id_.end()) ? 
it->second : 0; + + if (suffix_id > 0 || table_[p][TABLE_SCORE] == UNKNOWN_SCORE) { + break; + } + } + + // Update best path + for (size_t p = suffix_id; p < table_.size(); ++p) { + int32_t score = table_[p][TABLE_SCORE]; + if (score > INVALID_SCORE) { + int32_t piece_length = table_[p][TABLE_PIECE_LENGTH]; + int64_t s = scores[i + piece_length] - score; + + if (s < scores[i]) { + scores[i] = s; + path[i][PATH_TOKEN_LENGTH] = piece_length; + path[i][PATH_TOKEN_ID] = table_[p][TABLE_TOKEN_ID]; + path[i][PATH_NUM_TOKENS] = path[i + piece_length][PATH_NUM_TOKENS] + 1; + + if (score == UNKNOWN_SCORE) { + // Add UTF-8 byte count + path[i][PATH_NUM_TOKENS] += (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); + } + } + } + + if (score == UNKNOWN_SCORE) { + break; + } + } + } + + // Decode the best path + std::vector token_ids; + token_ids.reserve(path[0][PATH_NUM_TOKENS]); + + int pos = 0; + while (pos < static_cast(data_len)) { + if (path[pos][PATH_TOKEN_ID] >= 0) { + token_ids.push_back(path[pos][PATH_TOKEN_ID]); + } else { + // Fall back to byte tokens + uint32_t c = unicode_data[pos]; + int s = 1 + (c >= 0x80) + (c >= 0x800) + (c >= 0x10000); + + for (int i = 0; i < s; ++i) { + uint8_t b; + if (s == 1) { + b = c; + } else { + if (i == 0) { + b = (0xF00 >> s) & 0xFF; + } else { + b = 0x80; + } + } + token_ids.push_back(bytes_[b | ((c >> ((s - i - 1) * 6)) & 0x3F)]); + } + } + + assert(path[pos][PATH_TOKEN_LENGTH] > 0); + pos += path[pos][PATH_TOKEN_LENGTH]; + } + + return token_ids; + } +private: + // Constants for table structure + static constexpr int32_t TABLE_PIECE_LENGTH = 0; + static constexpr int32_t TABLE_TOKEN_ID = 1; + static constexpr int32_t TABLE_SCORE = 2; + static constexpr int32_t TABLE_PIECE_ID = 3; + + // Constants for path array + static constexpr int32_t PATH_TOKEN_LENGTH = 0; + static constexpr int32_t PATH_TOKEN_ID = 1; + static constexpr int32_t PATH_NUM_TOKENS = 2; + + // Score constants + static constexpr int32_t INVALID_SCORE = -20000000; + static constexpr int32_t UNKNOWN_SCORE = -10000000; + + // List of tokens in the vocabulary + std::vector tokens_; + + // Mapping from byte code point to token ID (for byte fallback) + std::vector bytes_; + + // Mapping from piece code to suffix ID + std::unordered_map to_suffix_id_; + + // Flattened table representing the Trie structure + // Each row contains: [piece_length, token_id, score, piece_id] + std::vector> table_; +}; + +struct llm_tokenizer_plamo2_session { + llm_tokenizer_plamo2_session(const llm_tokenizer_plamo2 & tokenizer) : tokenizer(tokenizer) {} + + void tokenize(const std::string & text, std::vector & output) { + std::vector tokens = tokenizer.encode(text); + output.insert(output.end(), tokens.begin(), tokens.end()); + } + +private: + const llm_tokenizer_plamo2 & tokenizer; +}; + // // impl // @@ -1499,6 +1785,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { special_unk_id = LLAMA_TOKEN_NULL; special_sep_id = LLAMA_TOKEN_NULL; special_pad_id = LLAMA_TOKEN_NULL; + } else if (tokenizer_model == "plamo2") { + type = LLAMA_VOCAB_TYPE_PLAMO2; + + // PLaMo-2 default special tokens (these will be overridden by model config) + special_bos_id = 1; // <|plamo:bos|> + special_eos_id = 2; // <|plamo:eos|> + special_unk_id = 0; // <|plamo:unk|> + special_sep_id = LLAMA_TOKEN_NULL; + special_pad_id = 3; // <|plamo:pad|> + special_mask_id = LLAMA_TOKEN_NULL; } else { throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str())); } @@ -1629,6 +1925,9 @@ void 
llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } else if ( tokenizer_pre == "exaone") { pre_type = LLAMA_VOCAB_PRE_TYPE_EXAONE; + } else if ( + tokenizer_pre == "exaone4") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; } else if ( tokenizer_pre == "chameleon") { pre_type = LLAMA_VOCAB_PRE_TYPE_CHAMELEON; @@ -1665,6 +1964,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "hunyuan") { pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN; clean_spaces = false; + } else if ( + tokenizer_pre == "kimi-k2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -2145,13 +2448,14 @@ enum llama_vocab_type llama_vocab::impl::get_type() const { std::string llama_vocab::impl::type_name() const{ switch (type) { - case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; - case LLAMA_VOCAB_TYPE_SPM: return "SPM"; - case LLAMA_VOCAB_TYPE_BPE: return "BPE"; - case LLAMA_VOCAB_TYPE_WPM: return "WPM"; - case LLAMA_VOCAB_TYPE_UGM: return "UGM"; - case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; - default: return "unknown"; + case LLAMA_VOCAB_TYPE_NONE: return "no vocab"; + case LLAMA_VOCAB_TYPE_SPM: return "SPM"; + case LLAMA_VOCAB_TYPE_BPE: return "BPE"; + case LLAMA_VOCAB_TYPE_WPM: return "WPM"; + case LLAMA_VOCAB_TYPE_UGM: return "UGM"; + case LLAMA_VOCAB_TYPE_RWKV: return "RWKV"; + case LLAMA_VOCAB_TYPE_PLAMO2: return "PLaMo2"; + default: return "unknown"; } } @@ -2234,6 +2538,9 @@ void llama_vocab::impl::init_tokenizer(enum llama_vocab_type type) { case LLAMA_VOCAB_TYPE_RWKV: tokenizer = std::make_unique(vocab); break; + case LLAMA_VOCAB_TYPE_PLAMO2: + tokenizer = std::make_unique(vocab); + break; default: GGML_ABORT("unsupported vocab type"); } @@ -2566,6 +2873,23 @@ std::vector llama_vocab::impl::tokenize( if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); +#ifdef PRETOKENIZERDEBUG + LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); +#endif + + session.tokenize(text, output); + } else { // if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_TOKEN) + output.push_back(fragment.token); + } + } + } break; + case LLAMA_VOCAB_TYPE_PLAMO2: + { + llm_tokenizer_plamo2_session session(*static_cast(tokenizer.get())); + for (const auto & fragment : fragment_buffer) { + if (fragment.type == FRAGMENT_BUFFER_VARIANT_TYPE_RAW_TEXT) { + std::string text = fragment.raw_text.substr(fragment.offset, fragment.length); + #ifdef PRETOKENIZERDEBUG LLAMA_LOG_WARN("TT: (%ld %ld %ld) '%s'\n", text.length(), fragment.offset, fragment.length, text.c_str()); #endif @@ -2664,6 +2988,24 @@ int32_t llama_vocab::impl::token_to_piece(llama_token token, char * buf, int32_t memcpy(buf, result.data(), result.size()); return (int)result.size(); } + case LLAMA_VOCAB_TYPE_PLAMO2: { + // PLaMo-2 uses similar token handling as BPE/SPM + if (vocab.is_byte(token)) { + // Handle byte tokens like <0xXX> + if (token_text.length() == 6 && token_text.substr(0, 3) == "<0x" && token_text.back() == '>') { + int hex_val = std::stoi(token_text.substr(3, 2), nullptr, 16); + if (length < 1) { + return -1; + } + buf[0] = static_cast(hex_val); + return 1; + } + } + + // Normal token - just copy the text + std::string result = token_text; + return _try_copy(result.data(), result.size()); + } default: GGML_ABORT("fatal error"); } @@ -2908,6 +3250,12 @@ llama_token 
llama_vocab::byte_to_token(uint8_t ch) const { case LLAMA_VOCAB_TYPE_BPE: { return pimpl->token_to_id.at(unicode_byte_to_utf8(ch)); } + case LLAMA_VOCAB_TYPE_PLAMO2: { + // PLaMo-2 uses byte tokens in format <0xXX> + char hex_str[8]; + snprintf(hex_str, sizeof(hex_str), "<0x%02X>", ch); + return pimpl->token_to_id.at(hex_str); + } default: GGML_ABORT("fatal error"); } @@ -3009,6 +3357,10 @@ llama_token llama_vocab::token_fim_sep() const { return pimpl->special_fim_sep_id; } +llama_token llama_vocab::token_mask() const { + return pimpl->special_mask_id; +} + bool llama_vocab::get_add_space_prefix() const { return pimpl->add_space_prefix; } @@ -3249,6 +3601,10 @@ llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) { return vocab->token_fim_sep(); } +llama_token llama_vocab_mask(const struct llama_vocab* vocab) { + return vocab->token_mask(); +} + // deprecated const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) { return llama_vocab_get_text(vocab, token); @@ -3385,4 +3741,3 @@ int32_t llama_detokenize( bool unparse_special) { return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special); } - diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h index 46a1ccecb51..842b129e861 100644 --- a/examples/talk-llama/llama-vocab.h +++ b/examples/talk-llama/llama-vocab.h @@ -45,6 +45,7 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, }; struct LLM_KV; @@ -100,6 +101,7 @@ struct llama_vocab { llama_token token_sep() const; llama_token token_nl () const; llama_token token_pad() const; + llama_token token_mask() const; llama_token token_prefix() const; llama_token token_middle() const; diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h index f73b1ab65fe..6f454a508a0 100644 --- a/examples/talk-llama/llama.h +++ b/examples/talk-llama/llama.h @@ -71,12 +71,13 @@ extern "C" { typedef int32_t llama_seq_id; enum llama_vocab_type { - LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab - LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback - LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE - LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece - LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram - LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_NONE = 0, // For models without vocab + LLAMA_VOCAB_TYPE_SPM = 1, // LLaMA tokenizer based on byte-level BPE with byte fallback + LLAMA_VOCAB_TYPE_BPE = 2, // GPT-2 tokenizer based on byte-level BPE + LLAMA_VOCAB_TYPE_WPM = 3, // BERT tokenizer based on WordPiece + LLAMA_VOCAB_TYPE_UGM = 4, // T5 tokenizer based on Unigram + LLAMA_VOCAB_TYPE_RWKV = 5, // RWKV tokenizer based on greedy tokenization + LLAMA_VOCAB_TYPE_PLAMO2 = 6, // PLaMo-2 tokenizer based on Aho-Corasick with dynamic programming }; enum llama_rope_type { @@ -334,6 +335,9 @@ extern "C" { bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) // NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases // ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573 + bool kv_unified; // use a unified buffer across the input sequences when computing the attention + // try to disable when n_seq_max > 1 for improved performance when the 
sequences do not share a large prefix + // ref: https://github.com/ggml-org/llama.cpp/pull/14363 }; // model quantization parameters @@ -724,7 +728,7 @@ extern "C" { // - lazily on next llama_decode() // p0 < 0 : [0, p1] // p1 < 0 : [p0, inf) - DEPRECATED(void llama_kv_self_seq_div( + DEPRECATED(LLAMA_API void llama_kv_self_seq_div( struct llama_context * ctx, llama_seq_id seq_id, llama_pos p0, @@ -952,6 +956,7 @@ extern "C" { // in the order they have appeared in the batch. // Rows: number of tokens for which llama_batch.logits[i] != 0 // Cols: n_vocab + // TODO: deprecate in favor of llama_get_logits_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522) LLAMA_API float * llama_get_logits(struct llama_context * ctx); // Logits for the ith token. For positive indices, Equivalent to: @@ -966,6 +971,7 @@ extern "C" { // in the order they have appeared in the batch. // shape: [n_outputs*n_embd] // Otherwise, returns NULL. + // TODO: deprecate in favor of llama_get_embeddings_ith() (ref: https://github.com/ggml-org/llama.cpp/pull/14853#issuecomment-3113143522) LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); // Get the embeddings for the ith token. For positive indices, Equivalent to: @@ -1004,6 +1010,7 @@ extern "C" { LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding + LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab); LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab); @@ -1389,6 +1396,7 @@ extern "C" { int32_t n_p_eval; int32_t n_eval; + int32_t n_reused; // number of times a ggml compute graph had been reused }; struct llama_perf_sampler_data { diff --git a/examples/talk-llama/unicode.cpp b/examples/talk-llama/unicode.cpp index 43a4581b961..65f36651715 100644 --- a/examples/talk-llama/unicode.cpp +++ b/examples/talk-llama/unicode.cpp @@ -557,6 +557,178 @@ static std::vector unicode_regex_split_stl(const std::string & text, con return bpe_offsets; } +// K2 system regex patterns (from tokenization_kimi.py): +// [\p{Han}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+ +static std::vector unicode_regex_split_custom_kimi_k2(const std::string & text, const std::vector & offsets) { + std::vector bpe_offsets; + bpe_offsets.reserve(offsets.size()); + + const auto cpts = unicode_cpts_from_utf8(text); + + size_t start = 0; + for (auto offset : offsets) { + const size_t offset_ini = start; + const size_t offset_end = start + offset; + assert(offset_end <= cpts.size()); + start = offset_end; + + static const uint32_t OUT_OF_RANGE = 0xFFFFFFFF; + auto _get_cpt = [&] (const size_t pos) -> uint32_t { + return (offset_ini <= pos && pos < offset_end) ? cpts[pos] : OUT_OF_RANGE; + }; + + auto _get_flags = [&] (const size_t pos) -> unicode_cpt_flags { + return (offset_ini <= pos && pos < offset_end) ? 
unicode_cpt_flags_from_cpt(cpts[pos]) : unicode_cpt_flags{}; + }; + + size_t _prev_end = offset_ini; + auto _add_token = [&] (const size_t end) -> size_t { + assert(_prev_end <= end && end <= offset_end); + size_t len = end - _prev_end; + if (len > 0) { + bpe_offsets.push_back(len); + } + _prev_end = end; + return len; + }; + + for (size_t pos = offset_ini; pos < offset_end; /*pos++*/ ) { + const uint32_t cpt = _get_cpt(pos); + const auto flags = _get_flags(pos); + + // Pattern 1: [\p{Han}]+ (Chinese characters) + if (unicode_cpt_is_han(cpt)) { + while (unicode_cpt_is_han(_get_cpt(pos))) { + pos++; + } + _add_token(pos); + continue; + } + + // Pattern 2 & 3: Letter words excluding Han characters with optional contractions + // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+(?:'s|'t|'re|'ve|'m|'ll|'d)? + // [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]+[\p{Ll}\p{Lm}\p{Lo}\p{M}&&[^\p{Han}]]*(?:'s|'t|'re|'ve|'m|'ll|'d)? + // Check if current char is a letter OR if current char could be a leading char and next char is a letter + bool is_letter_pattern = (flags.is_letter && !unicode_cpt_is_han(cpt)) || + (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number) && + _get_flags(pos + 1).is_letter && !unicode_cpt_is_han(_get_cpt(pos + 1))); + + if (is_letter_pattern) { + // Handle optional leading non-letter/non-number character + bool has_leading_char = false; + if (!(cpt == '\r' || cpt == '\n' || flags.is_letter || flags.is_number)) { + has_leading_char = true; + pos++; + } + + // Match letter sequence (excluding Han characters) + bool has_letters = false; + while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) { + has_letters = true; + pos++; + } + + // Only proceed if we found letters (after potentially skipping leading char) + if (has_letters || (!has_leading_char && _get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos)))) { + if (!has_letters) pos++; // consume the first letter if we didn't already + + // Continue consuming letters + while (_get_flags(pos).is_letter && !unicode_cpt_is_han(_get_cpt(pos))) { + pos++; + } + + // Check for optional contractions (?:'s|'t|'re|'ve|'m|'ll|'d) + if (_get_cpt(pos) == '\'' && pos + 1 < offset_end) { + uint32_t cpt_next = unicode_tolower(_get_cpt(pos + 1)); + if (cpt_next == 's' || cpt_next == 't' || cpt_next == 'm' || cpt_next == 'd') { + pos += 2; + } else if (pos + 2 < offset_end) { + uint32_t cpt_next_next = unicode_tolower(_get_cpt(pos + 2)); + if ((cpt_next == 'r' && cpt_next_next == 'e') || + (cpt_next == 'v' && cpt_next_next == 'e') || + (cpt_next == 'l' && cpt_next_next == 'l')) { + pos += 3; + } + } + } + + _add_token(pos); + continue; + } else if (has_leading_char) { + // We consumed a leading char but found no letters, backtrack + pos--; + } + } + + // Pattern 4: \p{N}{1,3} (numbers 1-3 digits) + if (flags.is_number) { + size_t ini = pos; + while (_get_flags(pos).is_number) { + if (++pos - ini >= 3) { + _add_token(pos); + ini = pos; + } + } + _add_token(pos); + continue; + } + + // Pattern 5: ?[^\s\p{L}\p{N}]+[\r\n]* (optional space + non-word chars + optional newlines) + auto flags2 = (cpt == ' ' ? 
_get_flags(pos + 1) : flags); + if (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) { + pos += (cpt == ' '); + while (!(flags2.is_whitespace || flags2.is_letter || flags2.is_number) && flags2.as_uint()) { + flags2 = _get_flags(++pos); + } + // Match optional [\r\n]* + uint32_t cpt2 = _get_cpt(pos); + while (cpt2 == '\r' || cpt2 == '\n') { + cpt2 = _get_cpt(++pos); + } + _add_token(pos); + continue; + } + + // Count whitespace characters + size_t num_whitespaces = 0; + size_t last_end_r_or_n = 0; + while (_get_flags(pos + num_whitespaces).is_whitespace) { + uint32_t cpt2 = _get_cpt(pos + num_whitespaces); + if (cpt2 == '\r' || cpt2 == '\n') { + last_end_r_or_n = pos + num_whitespaces + 1; + } + num_whitespaces++; + } + + // Pattern 6: \s*[\r\n]+ (whitespace with newlines) + if (last_end_r_or_n > 0) { + pos = last_end_r_or_n; + _add_token(pos); + continue; + } + + // Pattern 7: \s+(?!\S) (trailing whitespace) + if (num_whitespaces > 1 && _get_cpt(pos + num_whitespaces) != OUT_OF_RANGE) { + pos += num_whitespaces - 1; + _add_token(pos); + continue; + } + + // Pattern 8: \s+ (general whitespace) + if (num_whitespaces > 0) { + pos += num_whitespaces; + _add_token(pos); + continue; + } + + // No matches - consume single character + _add_token(++pos); + } + } + + return bpe_offsets; +} + static std::vector unicode_regex_split_custom(const std::string & text, const std::string & regex_expr, const std::vector & offsets) { std::vector bpe_offsets; @@ -567,6 +739,9 @@ static std::vector unicode_regex_split_custom(const std::string & text, regex_expr == "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+") { bpe_offsets = unicode_regex_split_custom_llama3(text, offsets); + } else if (regex_expr == "\\p{Han}+") { + // K2's first pattern - handle all K2 patterns together + bpe_offsets = unicode_regex_split_custom_kimi_k2(text, offsets); } return bpe_offsets; @@ -672,6 +847,38 @@ uint32_t unicode_tolower(uint32_t cpt) { return cpt; // Return the original code point if no lowercase mapping is found } +bool unicode_cpt_is_han(uint32_t cpt) { + // Han character ranges (Chinese/CJK characters) + // CJK Unified Ideographs (most common) + if (cpt >= 0x4E00 && cpt <= 0x9FFF) return true; + + // CJK Extension A + if (cpt >= 0x3400 && cpt <= 0x4DBF) return true; + + // CJK Extension B + if (cpt >= 0x20000 && cpt <= 0x2A6DF) return true; + + // CJK Extension C + if (cpt >= 0x2A700 && cpt <= 0x2B73F) return true; + + // CJK Extension D + if (cpt >= 0x2B740 && cpt <= 0x2B81F) return true; + + // CJK Extension E + if (cpt >= 0x2B820 && cpt <= 0x2CEAF) return true; + + // CJK Extension F + if (cpt >= 0x2CEB0 && cpt <= 0x2EBEF) return true; + + // CJK Compatibility Ideographs + if (cpt >= 0xF900 && cpt <= 0xFAFF) return true; + + // CJK Compatibility Ideographs Supplement + if (cpt >= 0x2F800 && cpt <= 0x2FA1F) return true; + + return false; +} + std::vector unicode_regex_split(const std::string & text, const std::vector & regex_exprs) { // unicode categories static const std::map k_ucat_enum = { diff --git a/examples/talk-llama/unicode.h b/examples/talk-llama/unicode.h index c27098df7d4..0a5fa2a78ce 100644 --- a/examples/talk-llama/unicode.h +++ b/examples/talk-llama/unicode.h @@ -63,4 +63,6 @@ uint8_t unicode_utf8_to_byte(const std::string & utf8); uint32_t unicode_tolower(uint32_t cpt); +bool unicode_cpt_is_han(uint32_t cpt); + std::vector unicode_regex_split(const 
std::string & text, const std::vector & regex_exprs); diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index de6d789c98a..20467c54da1 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -131,7 +131,7 @@ option(GGML_RVV "ggml: enable rvv" ON) option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF) option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) option(GGML_VXE "ggml: enable vxe" ON) -option(GGML_NNPA "ggml: enable nnpa" ON) +option(GGML_NNPA "ggml: enable nnpa" OFF) # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877 option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") @@ -174,6 +174,8 @@ option(GGML_HIP_GRAPHS "ggml: use HIP graph, experimental, option(GGML_HIP_NO_VMM "ggml: do not try to use HIP VMM" ON) option(GGML_HIP_ROCWMMA_FATTN "ggml: enable rocWMMA for FlashAttention" OFF) option(GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 "ggml: enable rocWMMA FlashAttention on GFX12" OFF) +option(GGML_MUSA_GRAPHS "ggml: use MUSA graph, experimental, unstable" OFF) +option(GGML_MUSA_MUDNN_COPY "ggml: enable muDNN for accelerated copy" OFF) option(GGML_VULKAN "ggml: use Vulkan" OFF) option(GGML_VULKAN_CHECK_RESULTS "ggml: run Vulkan op checks" OFF) option(GGML_VULKAN_DEBUG "ggml: enable Vulkan debug output" OFF) diff --git a/ggml/cmake/ggml-config.cmake.in b/ggml/cmake/ggml-config.cmake.in index 8c2dc31c6da..fe34cda4e01 100644 --- a/ggml/cmake/ggml-config.cmake.in +++ b/ggml/cmake/ggml-config.cmake.in @@ -1,152 +1,189 @@ - -@GGML_VARIABLES_EXPANDED@ - @PACKAGE_INIT@ -set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@") -set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@") -#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@") - -find_package(Threads REQUIRED) - -find_library(GGML_LIBRARY ggml - REQUIRED - HINTS ${GGML_LIB_DIR} - NO_CMAKE_FIND_ROOT_PATH) - -add_library(ggml::ggml UNKNOWN IMPORTED) -set_target_properties(ggml::ggml - PROPERTIES - IMPORTED_LOCATION "${GGML_LIBRARY}") - -find_library(GGML_BASE_LIBRARY ggml-base - REQUIRED - HINTS ${GGML_LIB_DIR} - NO_CMAKE_FIND_ROOT_PATH) - -add_library(ggml::ggml-base UNKNOWN IMPORTED) -set_target_properties(ggml::ggml-base - PROPERTIES - IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") +@GGML_VARIABLES_EXPANDED@ +# Find all dependencies before creating any target. 
+include(CMakeFindDependencyMacro) +find_dependency(Threads) if (NOT GGML_SHARED_LIB) + set(GGML_CPU_INTERFACE_LINK_LIBRARIES "") + set(GGML_CPU_INTERFACE_LINK_OPTIONS "") + if (APPLE AND GGML_ACCELERATE) - find_library(ACCELERATE_FRAMEWORK Accelerate REQUIRED) + find_library(ACCELERATE_FRAMEWORK Accelerate) + if(NOT ACCELERATE_FRAMEWORK) + set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0) + return() + endif() list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${ACCELERATE_FRAMEWORK}) endif() - if (GGML_OPENMP) - find_package(OpenMP REQUIRED) + if (GGML_OPENMP_ENABLED) + find_dependency(OpenMP) list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES OpenMP::OpenMP_C OpenMP::OpenMP_CXX) endif() if (GGML_CPU_HBM) - find_library(memkind memkind REQUIRED) + find_library(memkind memkind) + if(NOT memkind) + set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0) + return() + endif() list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES memkind) endif() if (GGML_BLAS) - find_package(BLAS REQUIRED) + find_dependency(BLAS) list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES ${BLAS_LIBRARIES}) list(APPEND GGML_CPU_INTERFACE_LINK_OPTIONS ${BLAS_LINKER_FLAGS}) endif() if (GGML_CUDA) - find_package(CUDAToolkit REQUIRED) + set(GGML_CUDA_INTERFACE_LINK_LIBRARIES "") + find_dependency(CUDAToolkit) + if (GGML_STATIC) + list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $) + if (WIN32) + list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $ $) + else() + list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $ $) + endif() + endif() + if (NOT GGML_CUDA_NO_VMM) + list(APPEND GGML_CUDA_INTERFACE_LINK_LIBRARIES $) + endif() endif() if (GGML_METAL) - find_library(FOUNDATION_LIBRARY Foundation REQUIRED) - find_library(METAL_FRAMEWORK Metal REQUIRED) - find_library(METALKIT_FRAMEWORK MetalKit REQUIRED) + find_library(FOUNDATION_LIBRARY Foundation) + find_library(METAL_FRAMEWORK Metal) + find_library(METALKIT_FRAMEWORK MetalKit) + if(NOT FOUNDATION_LIBRARY OR NOT METAL_FRAMEWORK OR NOT METALKIT_FRAMEWORK) + set(${CMAKE_FIND_PACKAGE_NAME}_FOUND 0) + return() + endif() + set(GGML_METAL_INTERFACE_LINK_LIBRARIES + ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK}) + endif() - list(APPEND GGML_METAL_INTERFACE_LINK_LIBRARIES - ${FOUNDATION_LIBRARY} ${METAL_FRAMEWORK} ${METALKIT_FRAMEWORK}) + if (GGML_OPENCL) + find_dependency(OpenCL) + set(GGML_OPENCL_INTERFACE_LINK_LIBRARIES $) endif() if (GGML_VULKAN) - find_package(Vulkan REQUIRED) - list(APPEND GGML_VULKAN_INTERFACE_LINK_LIBRARIES Vulkan::Vulkan) + find_dependency(Vulkan) + set(GGML_VULKAN_INTERFACE_LINK_LIBRARIES $) endif() if (GGML_HIP) - find_package(hip REQUIRED) - find_package(hipblas REQUIRED) - find_package(rocblas REQUIRED) - list(APPEND GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas) + find_dependency(hip) + find_dependency(hipblas) + find_dependency(rocblas) + set(GGML_HIP_INTERFACE_LINK_LIBRARIES hip::host roc::rocblas roc::hipblas) endif() if (GGML_SYCL) + set(GGML_SYCL_INTERFACE_LINK_LIBRARIES "") find_package(DNNL) if (${DNNL_FOUND} AND GGML_SYCL_TARGET STREQUAL "INTEL") list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES DNNL::dnnl) endif() if (WIN32) - find_package(IntelSYCL REQUIRED) - find_package(MKL REQUIRED) + find_dependency(IntelSYCL) + find_dependency(MKL) list(APPEND GGML_SYCL_INTERFACE_LINK_LIBRARIES IntelSYCL::SYCL_CXX MKL::MKL MKL::MKL_SYCL) endif() endif() endif() -set(_ggml_all_targets "") -foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS}) - string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}") - string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx) 
+set_and_check(GGML_INCLUDE_DIR "@PACKAGE_GGML_INCLUDE_INSTALL_DIR@") +set_and_check(GGML_LIB_DIR "@PACKAGE_GGML_LIB_INSTALL_DIR@") +#set_and_check(GGML_BIN_DIR "@PACKAGE_GGML_BIN_INSTALL_DIR@") - find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend} +if(NOT TARGET ggml::ggml) + find_package(Threads REQUIRED) + + find_library(GGML_LIBRARY ggml REQUIRED HINTS ${GGML_LIB_DIR} NO_CMAKE_FIND_ROOT_PATH) - message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}") + add_library(ggml::ggml UNKNOWN IMPORTED) + set_target_properties(ggml::ggml + PROPERTIES + IMPORTED_LOCATION "${GGML_LIBRARY}") + + find_library(GGML_BASE_LIBRARY ggml-base + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) - add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED) - set_target_properties(ggml::${_ggml_backend} + add_library(ggml::ggml-base UNKNOWN IMPORTED) + set_target_properties(ggml::ggml-base PROPERTIES - INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}" - IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" - IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}" - INTERFACE_COMPILE_FEATURES c_std_90 - POSITION_INDEPENDENT_CODE ON) - - string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}") - if(is_cpu_variant) - list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") - set_target_properties(ggml::${_ggml_backend} - PROPERTIES - INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}") + IMPORTED_LOCATION "${GGML_BASE_LIBRARY}") - if(GGML_CPU_INTERFACE_LINK_OPTIONS) - set_target_properties(ggml::${_ggml_backend} - PROPERTIES - INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}") - endif() + set(_ggml_all_targets "") + foreach(_ggml_backend ${GGML_AVAILABLE_BACKENDS}) + string(REPLACE "-" "_" _ggml_backend_pfx "${_ggml_backend}") + string(TOUPPER "${_ggml_backend_pfx}" _ggml_backend_pfx) + + find_library(${_ggml_backend_pfx}_LIBRARY ${_ggml_backend} + REQUIRED + HINTS ${GGML_LIB_DIR} + NO_CMAKE_FIND_ROOT_PATH) - else() - list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") + message(STATUS "Found ${${_ggml_backend_pfx}_LIBRARY}") + + add_library(ggml::${_ggml_backend} UNKNOWN IMPORTED) set_target_properties(ggml::${_ggml_backend} PROPERTIES - INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}") + INTERFACE_INCLUDE_DIRECTORIES "${GGML_INCLUDE_DIR}" + IMPORTED_LINK_INTERFACE_LANGUAGES "CXX" + IMPORTED_LOCATION "${${_ggml_backend_pfx}_LIBRARY}" + INTERFACE_COMPILE_FEATURES c_std_90 + POSITION_INDEPENDENT_CODE ON) + + string(REGEX MATCH "^ggml-cpu" is_cpu_variant "${_ggml_backend}") + if(is_cpu_variant) + list(APPEND GGML_CPU_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_LIBRARIES "${GGML_CPU_INTERFACE_LINK_LIBRARIES}") - if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS) + if(GGML_CPU_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${GGML_CPU_INTERFACE_LINK_OPTIONS}") + endif() + + else() + list(APPEND ${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES "ggml::ggml-base") set_target_properties(ggml::${_ggml_backend} PROPERTIES - INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}") + INTERFACE_LINK_LIBRARIES "${${_ggml_backend_pfx}_INTERFACE_LINK_LIBRARIES}") + + if(${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS) + set_target_properties(ggml::${_ggml_backend} + PROPERTIES + INTERFACE_LINK_OPTIONS "${${_ggml_backend_pfx}_INTERFACE_LINK_OPTIONS}") + endif() endif() - endif() - list(APPEND 
_ggml_all_targets ggml::${_ggml_backend}) -endforeach() + list(APPEND _ggml_all_targets ggml::${_ggml_backend}) + endforeach() -list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}") -set_target_properties(ggml::ggml - PROPERTIES - INTERFACE_LINK_LIBRARIES "${GGML_INTERFACE_LINK_LIBRARIES}") + list(APPEND GGML_INTERFACE_LINK_LIBRARIES ggml::ggml-base "${_ggml_all_targets}") + set_target_properties(ggml::ggml + PROPERTIES + INTERFACE_LINK_LIBRARIES "${GGML_INTERFACE_LINK_LIBRARIES}") -add_library(ggml::all INTERFACE IMPORTED) -set_target_properties(ggml::all - PROPERTIES - INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}") + add_library(ggml::all INTERFACE IMPORTED) + set_target_properties(ggml::all + PROPERTIES + INTERFACE_LINK_LIBRARIES "${_ggml_all_targets}") + +endif() check_required_components(ggml) diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index b7498b8d402..eaf41e5a6c8 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -647,6 +647,7 @@ struct ggml_backend_sched { // pipeline parallelism support int n_copies; int cur_copy; + int next_copy; ggml_backend_event_t events[GGML_SCHED_MAX_BACKENDS][GGML_SCHED_MAX_COPIES]; struct ggml_tensor * graph_inputs[GGML_SCHED_MAX_SPLIT_INPUTS]; int n_graph_inputs; @@ -1433,8 +1434,6 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } } - sched->cur_copy = (sched->cur_copy + 1) % sched->n_copies; - return GGML_STATUS_SUCCESS; } @@ -1535,10 +1534,10 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) { bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs); - ggml_backend_sched_split_graph(sched, measure_graph); - ggml_backend_sched_synchronize(sched); + ggml_backend_sched_split_graph(sched, measure_graph); + if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) { return false; } @@ -1550,6 +1549,10 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs); + GGML_ASSERT(!sched->is_alloc); + + sched->cur_copy = sched->next_copy; + sched->next_copy = (sched->next_copy + 1) % sched->n_copies; ggml_backend_sched_split_graph(sched, graph); @@ -1590,7 +1593,7 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { // if the graph is not already allocated, always use copy 0 after a synchronization // this ensures that during generation the same copy is used every time, // which avoids changes in the graph that could cause CUDA or other graphs to be disabled - sched->cur_copy = 0; + sched->next_copy = 0; } } diff --git a/ggml/src/ggml-cann/acl_tensor.cpp b/ggml/src/ggml-cann/acl_tensor.cpp index f311864d486..8ffac31dd66 100755 --- a/ggml/src/ggml-cann/acl_tensor.cpp +++ b/ggml/src/ggml-cann/acl_tensor.cpp @@ -77,6 +77,8 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne, for (int i = 0; i < final_dims; i++) { acl_storage_len += (acl_ne[i] - 1) * acl_stride[i]; } + size_t elem_offset = offset / ggml_element_size(tensor); + acl_storage_len += elem_offset; // Reverse ne and stride. 
std::reverse(acl_ne, acl_ne + final_dims); @@ -84,7 +86,7 @@ aclTensor* ggml_cann_create_tensor(const ggml_tensor* tensor, int64_t* ne, aclTensor* acl_tensor = aclCreateTensor( acl_ne, final_dims, ggml_cann_type_mapping(tensor->type), acl_stride, - offset / ggml_element_size(tensor), format, &acl_storage_len, 1, + elem_offset, format, &acl_storage_len, 1, tensor->data); return acl_tensor; diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 4d5c2c18252..d616c491ae9 100755 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -99,7 +99,7 @@ void bcast_shape(ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, aclT } } -void ggml_cann_unary_op( +void ggml_cann_op_unary( std::function unary_op, ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src = dst->src[0]; @@ -111,6 +111,42 @@ void ggml_cann_unary_op( ggml_cann_release_resources(ctx, acl_src, acl_dst); } +void ggml_cann_op_unary_gated( + std::function unary_op, + ggml_backend_cann_context& ctx, ggml_tensor* dst) { + ggml_tensor* src0 = dst->src[0]; + ggml_tensor* src1 = dst->src[1]; + + GGML_ASSERT(ggml_is_contiguous_1(src0)); + GGML_ASSERT(ggml_is_contiguous_1(dst)); + const int32_t swapped = ggml_get_op_params_i32(dst, 1); + + aclTensor* acl_dst = ggml_cann_create_tensor(dst); + aclTensor *acl_src0 = nullptr, *acl_src1 = nullptr; + if(src1) { + GGML_ASSERT(ggml_is_contiguous_1(src1)); + GGML_ASSERT(src0->type == src1->type); + + acl_src0 = ggml_cann_create_tensor(src0); + acl_src1 = ggml_cann_create_tensor(src1); + } else { + int64_t ne[] = {src0->ne[0] / 2, src0->ne[1], src0->ne[2], src0->ne[3]}; + size_t nb[] = {src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]}; + acl_src0 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, 0); + acl_src1 = ggml_cann_create_tensor(src0, ne, nb, GGML_MAX_DIMS, ACL_FORMAT_ND, ne[0] * ggml_element_size(src0)); + if (swapped) { + std::swap(acl_src0, acl_src1); + } + } + + unary_op(ctx, acl_src0, acl_dst); + GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1); + + ggml_cann_release_resources(ctx, acl_src0, acl_dst); + if(src1) + ggml_cann_release_resources(ctx, acl_src1); +} + /** * @brief Repeats elements of a tensor along each dimension according to the * specified repeat array. @@ -1785,8 +1821,27 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx, size_t transpose_nb[] = {bcast_weight_nb[1], bcast_weight_nb[0], bcast_weight_nb[2], bcast_weight_nb[3], bcast_weight_nb[4], bcast_weight_nb[5]}; - aclTensor* acl_weight_tensor = - ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims); + aclTensor* acl_weight_tensor; + + bool weightToNZ = false; +#ifdef ASCEND_310P + weightToNZ = (getenv("GGML_CANN_WEIGHT_NZ") != nullptr); +#endif + if (weightToNZ && is_matmul_weight(weight)) { + int64_t acl_stride[2] = {1, transpose_ne[1]}; + + // Reverse ne. 
+ std::reverse(transpose_ne, transpose_ne + n_dims); + + std::vector storageDims = {transpose_ne[0], transpose_ne[1]}; + + acl_weight_tensor = aclCreateTensor( + transpose_ne, n_dims, ggml_cann_type_mapping(weight->type), acl_stride, + 0, ACL_FORMAT_FRACTAL_NZ, storageDims.data(), 2, weight->data); + } else { + acl_weight_tensor = + ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_ND); + } aclTensor* acl_dst = ggml_cann_create_tensor(dst, bcast_dst_ne, bcast_dst_nb, n_dims); diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index 80ce80baea0..8deaf7ea1db 100755 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -23,6 +23,7 @@ #ifndef CANN_ACLNN_OPS #define CANN_ACLNN_OPS +#include #include #include #include @@ -1020,6 +1021,37 @@ inline void ggml_cann_async_memset(ggml_backend_cann_context & ctx, void * buffe */ void ggml_cann_mul_mat_id(ggml_backend_cann_context& ctx, ggml_tensor* dst); +/** + * @brief Check whether a tensor is a weight tensor for matrix multiplication. + * + * @details Checks whether the given tensor serves as weight parameters in matrix multiplication operations, + * typically within neural network layers. The function maintains a static set of canonical weight + * naming suffixes from Transformer-based architectures. Uses substring matching to identify weight + * tensors even with hierarchical naming patterns. + * + * @param tensor Pointer to the target ggml_tensor object (const-qualified). + */ +static bool is_matmul_weight(const ggml_tensor* tensor) { + std::string name = ggml_get_name(tensor); + static const std::unordered_set weight_suffixes{ + "output.weight", + "attn_q.weight", + "attn_k.weight", + "attn_v.weight", + "attn_output.weight", + "ffn_gate.weight", + "ffn_up.weight", + "ffn_down.weight" + }; + + for (const auto& suffix : weight_suffixes) { + if (name.find(suffix) != std::string::npos) { + return true; + } + } + return false; +} + /** * @brief Applies a element-wise operation to two input tensors using the CANN * backend. @@ -1066,7 +1098,7 @@ void ggml_cann_binary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) { * @param dst The destination tensor. Its src[0] is treated as the input tensor. */ template - void ggml_cann_unary_op(ggml_backend_cann_context& ctx, ggml_tensor* dst) { + void ggml_cann_op_unary(ggml_backend_cann_context& ctx, ggml_tensor* dst) { ggml_tensor* src = dst->src[0]; aclTensor* acl_src = ggml_cann_create_tensor(src); @@ -1077,49 +1109,125 @@ template } /** - * @brief Applies a unary operation to a ggml tensor using the CANN backend. + * @brief Applies a unary operation to a ggml tensor using the CANN backend. * - * @details This function performs a unary operation on the input tensor using - * a user-provided lambda or callable object `unary_op`, which accepts the CANN - * context and two ACL tensors (source and destination). Internally, this function - * creates ACL representations of the ggml tensors and invokes the unary operation. - * The result is stored in the destination tensor `dst`. This utility abstracts the - * common boilerplate of tensor conversion and cleanup when implementing unary ops. + * @details This function applies a unary operation to the input tensor using + * a user-provided lambda or callable `unary_op`. The lambda receives the + * CANN backend context and two ACL tensors: the source and the destination. * - * @param unary_op A callable that performs the unary operation using CANN APIs. 
- * @param ctx The CANN context used for operations. - * @param dst The destination tensor where the result will be stored. - * The source tensor is retrieved from `dst->src[0]`. + * Internally, this function handles the conversion from GGML tensors to ACL tensors, + * calls the provided unary op, and manages resource cleanup. The input is assumed + * to be `dst->src[0]`, and the result is written to `dst`. + * + * This utility simplifies writing unary op wrappers by abstracting tensor preparation. + * + * @param unary_op A callable that performs the unary operation using CANN ACL APIs. + * @param ctx The CANN context for operation execution. + * @param dst The destination ggml_tensor where the result will be stored. + * The input tensor is assumed to be `dst->src[0]`. + * + * @see GGML_CANN_CALL_OP_UNARY + */ +void ggml_cann_op_unary( + std::function unary_op, + ggml_backend_cann_context& ctx, ggml_tensor* dst); + +/** + * @brief Applies a gated (GLU-style) unary operation using the CANN backend. + * + * @details This function performs a gated activation such as GEGLU or ReGLU. + * It supports two input modes: + * + * 1. **Dual input mode**: `dst->src[0]` and `dst->src[1]` are both valid tensors. + * These are used directly as the value and gate tensors. + * + * 2. **Packed input mode**: Only `dst->src[0]` is valid, and it is assumed to + * contain a concatenation of value and gate along the first dimension. This tensor + * will be split into two equal halves to form the value and gate inputs. + * + * The function applies a user-provided unary operation (e.g., GELU) to the value tensor, + * then multiplies the result in-place with the gate tensor: + * + * @code + * dst = unary_op(value) * gate; + * @endcode + * + * The `swapped` parameter (from `dst->op_params[1]`) allows flipping the + * order of value/gate in the packed input case. + * + * @param unary_op A callable that performs the unary operation using CANN ACL APIs. + * It receives (ctx, acl_value_tensor, acl_output_tensor). + * @param ctx The CANN context used for execution. + * @param dst The destination ggml_tensor. Source tensors are in `dst->src[0]` and optionally `src[1]`. + * + * @see GGML_CANN_CALL_OP_UNARY_GATED */ -void ggml_cann_unary_op( +void ggml_cann_op_unary_gated( std::function unary_op, ggml_backend_cann_context& ctx, ggml_tensor* dst); /** - * @brief Helper macro to invoke a unary ACL operation using ggml_cann_unary_op. + * @brief Helper macro to call a unary ACL operator via ggml_cann_op_unary. * - * This macro defines an inline lambda wrapping a specific ACL operation name, - * and passes it to the templated ggml_cann_unary_op function. It simplifies - * calling unary ops by hiding the lambda boilerplate. + * This macro wraps the specified ACLNN unary operator name into a lambda expression, + * and passes it to `ggml_cann_op_unary`, which handles the common logic for executing + * unary ops in the CANN backend. * - * Internally, the lambda will call: + * Internally, this macro expands to a lambda like: * @code - * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); + * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { + * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); + * }; * @endcode * + * This lambda is then passed to `ggml_cann_op_unary`, which applies the operation. + * * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP. 
* - * @see ggml_cann_unary_op + * @see ggml_cann_op_unary * @see GGML_CANN_CALL_ACLNN_OP */ -#define GGML_CANN_CALL_UNARY_OP(OP_NAME) \ +#define GGML_CANN_CALL_OP_UNARY(OP_NAME) \ do { \ auto lambda = [](ggml_backend_cann_context& ctx, \ aclTensor* acl_src, \ aclTensor* acl_dst) { \ GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ }; \ - ggml_cann_unary_op(lambda, ctx, dst); \ + ggml_cann_op_unary(lambda, ctx, dst); \ } \ while (0) + +/** + * @brief Helper macro to call a gated unary ACL operator via ggml_cann_op_unary_gated. + * + * This macro wraps the specified ACLNN unary operator name into a lambda expression, + * and passes it to `ggml_cann_op_unary_gated`, which handles the common logic for + * executing gated unary ops in the CANN backend. + * + * Internally, this macro expands to a lambda like: + * @code + * [](ggml_backend_cann_context& ctx, aclTensor* acl_src, aclTensor* acl_dst) { + * GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); + * }; + * @endcode + * + * This lambda is then passed to `ggml_cann_op_unary_gated`, which applies the operation. + * + * @param OP_NAME The name of the ACL unary operator to invoke via GGML_CANN_CALL_ACLNN_OP. + * + * @see ggml_cann_op_unary_gated + * @see GGML_CANN_CALL_ACLNN_OP + */ +#define GGML_CANN_CALL_OP_UNARY_GATED(OP_NAME) \ + do { \ + auto lambda = [](ggml_backend_cann_context& ctx, \ + aclTensor* acl_src, \ + aclTensor* acl_dst) { \ + GGML_CANN_CALL_ACLNN_OP(ctx, OP_NAME, acl_src, acl_dst); \ + }; \ + ggml_cann_op_unary_gated(lambda, ctx, dst); \ + } \ + while (0) + #endif // CANN_ACLNN_OPS diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index e5e11d4cdce..c6edb6b61bb 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -1115,6 +1116,63 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor( return GGML_STATUS_SUCCESS; } +static int CreateAclTensorWeight(const void *hostData, const std::vector &shape, void **deviceAddr, + aclDataType dataType, aclTensor **tensor) +{ + uint64_t size = 1; + for (auto i : shape) { + size *= i; + } + + const aclIntArray *mat2Size = aclCreateIntArray(shape.data(), shape.size()); + ACL_CHECK(aclnnCalculateMatmulWeightSizeV2(mat2Size, dataType, &size)); + + size *= sizeof(int16_t); + + ACL_CHECK(aclrtMalloc(deviceAddr, size, ACL_MEM_MALLOC_HUGE_FIRST)); + aclrtMemcpy(*deviceAddr, size, hostData, size, ACL_MEMCPY_HOST_TO_DEVICE); + + std::vector strides(shape.size(), 1); + for (int64_t i = shape.size() - 2; i >= 0; i--) { + strides[i] = shape[i + 1] * strides[i + 1]; + } + + *tensor = aclCreateTensor(shape.data(), shape.size(), dataType, strides.data(), 0, aclFormat::ACL_FORMAT_ND, + shape.data(), shape.size(), *deviceAddr); + return 0; +} + +static void weight_format_to_nz(ggml_tensor *tensor, const void *data, size_t offset) { + aclrtStream stream; + ACL_CHECK(aclrtCreateStream(&stream)); + + std::vector weightTransposedShape = {tensor->ne[1], tensor->ne[0]}; + void *weightTransposedDeviceAddr = nullptr; + aclTensor *weightTransposed = nullptr; + CreateAclTensorWeight(data, weightTransposedShape, &weightTransposedDeviceAddr, + ggml_cann_type_mapping(tensor->type), &weightTransposed); + + uint64_t workspaceSize = 0; + aclOpExecutor *executor; + void *workspaceAddr = nullptr; + + // TransMatmulWeight + ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, &workspaceSize, &executor)); + std::unique_ptr workspaceAddrPtrTrans(nullptr, 
aclrtFree); + if (workspaceSize > 0) { + ACL_CHECK(aclrtMalloc(&workspaceAddr, workspaceSize, ACL_MEM_MALLOC_HUGE_FIRST)); + workspaceAddrPtrTrans.reset(workspaceAddr); + } + ACL_CHECK(aclnnTransMatmulWeight(workspaceAddr, workspaceSize, executor, stream)); + + size_t size = ggml_nelements(tensor) * ggml_element_size(tensor); + + aclrtMemcpy((char *)tensor->data + offset, size, + weightTransposedDeviceAddr, size, ACL_MEMCPY_HOST_TO_DEVICE); + ACL_CHECK(aclDestroyTensor(weightTransposed)); + aclrtFree(weightTransposedDeviceAddr); +} + // TODO: need handle tensor which has paddings. /** * @brief Set tensor data in a CANN buffer. @@ -1139,9 +1197,16 @@ static void ggml_backend_cann_buffer_set_tensor( // For acl, synchronous functions use this default stream. // Why aclrtSynchronizeDevice? + bool weightToNZ = false; +#ifdef ASCEND_310P + weightToNZ = (getenv("GGML_CANN_WEIGHT_NZ") != nullptr); +#endif if (!need_transform(tensor->type)) { ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); + if (weightToNZ && is_matmul_weight((const ggml_tensor*)tensor)) { + weight_format_to_nz(tensor, data, offset); + } } else { void *transform_buffer = malloc(size); ggml_backend_cann_transform(tensor, data, transform_buffer); @@ -1616,16 +1681,18 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, case GGML_OP_UNARY: switch (ggml_get_unary_op(dst)) { case GGML_UNARY_OP_ABS: - GGML_CANN_CALL_UNARY_OP(Abs); + GGML_CANN_CALL_OP_UNARY(Abs); break; case GGML_UNARY_OP_NEG: - GGML_CANN_CALL_UNARY_OP(Neg); + GGML_CANN_CALL_OP_UNARY(Neg); break; case GGML_UNARY_OP_GELU: - GGML_CANN_CALL_UNARY_OP(Gelu); + case GGML_UNARY_OP_GELU_ERF: + // aclnnGelu internally uses the erf-based approximation. + GGML_CANN_CALL_OP_UNARY(Gelu); break; case GGML_UNARY_OP_SILU: - GGML_CANN_CALL_UNARY_OP(Silu); + GGML_CANN_CALL_OP_UNARY(Silu); break; case GGML_UNARY_OP_GELU_QUICK: { auto lambda = [](ggml_backend_cann_context& ctx, @@ -1633,31 +1700,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, aclTensor* acl_dst) { GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); }; - ggml_cann_unary_op(lambda, ctx, dst); + ggml_cann_op_unary(lambda, ctx, dst); } break; case GGML_UNARY_OP_TANH: - GGML_CANN_CALL_UNARY_OP(Tanh); + GGML_CANN_CALL_OP_UNARY(Tanh); break; case GGML_UNARY_OP_RELU: - GGML_CANN_CALL_UNARY_OP(Relu); + GGML_CANN_CALL_OP_UNARY(Relu); break; case GGML_UNARY_OP_SIGMOID: - GGML_CANN_CALL_UNARY_OP(Sigmoid); + GGML_CANN_CALL_OP_UNARY(Sigmoid); break; case GGML_UNARY_OP_HARDSIGMOID: - GGML_CANN_CALL_UNARY_OP(Hardsigmoid); + GGML_CANN_CALL_OP_UNARY(Hardsigmoid); break; case GGML_UNARY_OP_HARDSWISH: - GGML_CANN_CALL_UNARY_OP(Hardswish); + GGML_CANN_CALL_OP_UNARY(Hardswish); break; case GGML_UNARY_OP_EXP: - GGML_CANN_CALL_UNARY_OP(Exp); + GGML_CANN_CALL_OP_UNARY(Exp); break; case GGML_UNARY_OP_ELU: ggml_cann_elu(ctx, dst); break; case GGML_UNARY_OP_SGN: - GGML_CANN_CALL_UNARY_OP(Sign); + GGML_CANN_CALL_OP_UNARY(Sign); break; case GGML_UNARY_OP_STEP: ggml_cann_step(ctx, dst); @@ -1666,6 +1733,31 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, return false; } break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_REGLU: + GGML_CANN_CALL_OP_UNARY_GATED(Relu); + break; + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_GEGLU_ERF: + // aclnnGelu internally uses the erf-based approximation. 
+ GGML_CANN_CALL_OP_UNARY_GATED(Gelu); + break; + case GGML_GLU_OP_SWIGLU: + GGML_CANN_CALL_OP_UNARY_GATED(Silu); + break; + case GGML_GLU_OP_GEGLU_QUICK: { + auto lambda = [](ggml_backend_cann_context& ctx, + aclTensor* acl_src, + aclTensor* acl_dst) { + GGML_CANN_CALL_ACLNN_OP(ctx, GeluV2, acl_src, 0, acl_dst); + }; + ggml_cann_op_unary_gated(lambda, ctx, dst); + } break; + default: + return false; + } + break; case GGML_OP_NORM: ggml_cann_norm(ctx, dst); break; @@ -1708,7 +1800,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_binary_op(ctx, dst); break; case GGML_OP_SQRT: - GGML_CANN_CALL_UNARY_OP(Sqrt); + GGML_CANN_CALL_OP_UNARY(Sqrt); break; case GGML_OP_CLAMP: ggml_cann_clamp(ctx, dst); @@ -1753,16 +1845,16 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context& ctx, ggml_cann_argmax(ctx, dst); break; case GGML_OP_COS: - ggml_cann_unary_op(ctx, dst); + ggml_cann_op_unary(ctx, dst); break; case GGML_OP_SIN: - ggml_cann_unary_op(ctx, dst); + ggml_cann_op_unary(ctx, dst); break; case GGML_OP_CONV_TRANSPOSE_1D: ggml_cann_conv_transpose_1d(ctx, dst); break; case GGML_OP_LOG: - GGML_CANN_CALL_UNARY_OP(Log); + GGML_CANN_CALL_OP_UNARY(Log); break; case GGML_OP_MEAN: ggml_cann_mean(ctx, dst); @@ -2036,10 +2128,23 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, case GGML_UNARY_OP_ELU: case GGML_UNARY_OP_SGN: case GGML_UNARY_OP_STEP: + case GGML_UNARY_OP_GELU_ERF: return true; default: return false; } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return true; + default: + return false; + } + break; case GGML_OP_MUL_MAT: { switch (op->src[0]->type) { case GGML_TYPE_F16: diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 66a5ad8d2ed..f188d1638dc 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -70,10 +70,12 @@ function(ggml_add_cpu_backend_variant_impl tag_name) if (GGML_OPENMP) find_package(OpenMP) if (OpenMP_FOUND) + set(GGML_OPENMP_ENABLED "ON" CACHE INTERNAL "") target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_OPENMP) target_link_libraries(${GGML_CPU_NAME} PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX) else() + set(GGML_OPENMP_ENABLED "OFF" CACHE INTERNAL "") message(WARNING "OpenMP not found") endif() endif() @@ -456,6 +458,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) list(APPEND ARCH_FLAGS -march=z16) elseif (${S390X_M} MATCHES "9175|9176") # NOTE: Only available from GCC 15.1.0 onwards. Any z17 machine with compile issues must first verify their GCC version. + # binutils must also be updated to the latest for the -march=z17 flag to work. Otherwise, use -march=arch15. 
message(STATUS "z17 target") list(APPEND ARCH_FLAGS -march=z17) else() @@ -494,9 +497,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.9.0") + set(KLEIDIAI_COMMIT_TAG "v1.11.0") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "2a8e1bb55d201557553545536489a017") + set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) diff --git a/ggml/src/ggml-cpu/arch/loongarch/quants.c b/ggml/src/ggml-cpu/arch/loongarch/quants.c index 9e33fb32286..7908da4d16b 100644 --- a/ggml/src/ggml-cpu/arch/loongarch/quants.c +++ b/ggml/src/ggml-cpu/arch/loongarch/quants.c @@ -544,7 +544,7 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i __m128 max4 = __lsx_vfmax_s( lasx_extractf128( max_abs, 1 ), lasx_extractf128( max_abs, 0) ); max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) ); __m128 tmp = max4; - max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 )); + max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x1 )); const float max_scalar = ((v4f32)max4)[0]; // Quantize these floats diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.cpp b/ggml/src/ggml-cpu/kleidiai/kernels.cpp index 910fd0ee4e7..ddd29d002d1 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kernels.cpp @@ -22,9 +22,94 @@ #include "kai_common.h" +#include "simd-mappings.h" + #include "kernels.h" #define NELEMS(x) sizeof(x) / sizeof(*x) + +static const size_t INT4_PER_BYTE = 2; +static const size_t INT4_BITS = 4; +static const int Q4_0_ZERO_POINT = 8; +const size_t INT4_PER_UINT16 = 4; + +static void dequantize_row_qsi4c32pscalef16( + const void *packed_data, + int32_t row_idx, + int64_t nc, + float *out, + size_t nr_pack, + size_t packed_row_stride, + size_t kr, + size_t bl, + size_t num_bytes_multiplier +) { + size_t group_idx = row_idx / nr_pack; + size_t row_in_group = row_idx % nr_pack; + const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride; + size_t num_blocks = nc / bl; + const uint8_t *block_ptr = packed_group; + + for (size_t b = 0; b < num_blocks; ++b) { + uint16_t scale_f16 = *((const uint16_t *)(block_ptr + row_in_group * num_bytes_multiplier)); + float scale = GGML_CPU_FP16_TO_FP32(scale_f16); + + const uint8_t *segment_ptr = block_ptr + nr_pack * num_bytes_multiplier; + size_t num_segments = bl / kr; + size_t num_bytes_per_segment = kr / INT4_PER_BYTE; + + for (size_t s = 0; s < num_segments; ++s) { + const uint8_t *seg_base = segment_ptr + s * nr_pack * num_bytes_per_segment; + const uint8_t *qbytes = seg_base + row_in_group * num_bytes_per_segment; + for (size_t k = 0; k < num_bytes_per_segment; ++k) { + uint8_t byte = qbytes[k] ^ 0x88; + int x0 = (byte & 0x0F) - Q4_0_ZERO_POINT; + int x1 = (byte >> INT4_BITS) - Q4_0_ZERO_POINT; + out[b * bl + s * num_bytes_per_segment + k] = x0 * scale; + out[b * bl + s * num_bytes_per_segment + k + bl/2] = x1 * scale; + } + } + block_ptr += nr_pack * num_bytes_multiplier + num_segments * nr_pack * num_bytes_per_segment; + } +} + +static void dequantize_row_qsi4c32ps1s0scalef16( + const void *packed_data, + int32_t row_idx, + int64_t k, + float *out, + size_t nr, + size_t packed_row_stride, + size_t kr, + size_t bl, + size_t num_bytes_multiplier +) { + const size_t 
num_blocks = k / bl; + const size_t bl4 = bl / INT4_PER_UINT16; + + size_t group_idx = row_idx / nr; + size_t row_in_group = row_idx % nr; + + const uint8_t *packed_group = (const uint8_t *)packed_data + group_idx * packed_row_stride; + const uint16_t *qdata = (const uint16_t *)packed_group; + const uint16_t *scales = (const uint16_t *)(packed_group + packed_row_stride - (nr * num_blocks * num_bytes_multiplier)); + + for (size_t block_idx = 0; block_idx < num_blocks; ++block_idx) { + uint16_t scale_f16 = scales[row_in_group + block_idx * nr]; + float scale = GGML_CPU_FP16_TO_FP32(scale_f16); + + for (size_t bl4_idx = 0; bl4_idx < bl4; ++bl4_idx) { + uint16_t q = qdata[(block_idx * bl4 + bl4_idx) * nr + row_in_group]; + + for (size_t qidx = 0; qidx < INT4_PER_UINT16; ++qidx) { + int v = ((q >> (qidx * 4)) & 0xF) - Q4_0_ZERO_POINT; + out[block_idx * bl + bl4_idx * INT4_BITS + qidx] = v * scale; + } + } + } + GGML_UNUSED(kr); +} + static ggml_kleidiai_kernels gemm_gemv_kernels[] = { #if defined(__ARM_FEATURE_SME) { @@ -63,8 +148,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon, + /* .to_float = */ dequantize_row_qsi4c32ps1s0scalef16, }, /* .required_cpu = */ CPU_FEATURE_SME, /* .lhs_type = */ GGML_TYPE_F32, @@ -107,8 +194,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, - /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, + /* .packed_stride = */ NULL, + /* .pack_func = */ kai_run_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme, + /* .to_float = */ NULL, }, /* .required_cpu = */ CPU_FEATURE_SME, /* .lhs_type = */ GGML_TYPE_F32, @@ -154,8 +243,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, @@ -200,8 +291,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .packed_stride = */ 
kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, @@ -247,8 +340,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD | CPU_FEATURE_I8MM, /* .lhs_type = */ GGML_TYPE_F32, @@ -293,8 +388,10 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = { /* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32, }, /* .rhs_info = */ { - /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, - /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .packed_size = */ kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .packed_stride = */ kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .pack_func = */ kai_run_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0, + /* .to_float = */ dequantize_row_qsi4c32pscalef16, }, /* .required_cpu = */ CPU_FEATURE_DOTPROD, /* .lhs_type = */ GGML_TYPE_F32, diff --git a/ggml/src/ggml-cpu/kleidiai/kernels.h b/ggml/src/ggml-cpu/kleidiai/kernels.h index 3b268d4a22a..bc8f33405d1 100644 --- a/ggml/src/ggml-cpu/kleidiai/kernels.h +++ b/ggml/src/ggml-cpu/kleidiai/kernels.h @@ -71,12 +71,15 @@ struct rhs_packing_info { std::function, std::function > packed_size; + size_t (*packed_stride)(size_t k, size_t nr, size_t kr, size_t bl); std::variant< std::function, std::function > pack_func; + void (*to_float)(const void *packed_data, int32_t row_idx, int64_t nc, float *out, size_t nr_pack, size_t packed_row_stride, + size_t kr, size_t bl, size_t num_bytes_multiplier); }; struct ggml_kleidiai_kernels { diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index fafe45e6c5c..3a513a55d76 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -40,6 +40,17 @@ struct ggml_kleidiai_context { ggml_kleidiai_kernels * kernels; } static ctx = { CPU_FEATURE_NONE, NULL }; +static const char* cpu_feature_to_string(cpu_feature f) { + switch (f) { + case CPU_FEATURE_NONE: return "NONE"; + case CPU_FEATURE_DOTPROD: return "DOTPROD"; + case CPU_FEATURE_I8MM: return "I8MM"; + case CPU_FEATURE_SVE: return "SVE"; + case CPU_FEATURE_SME: return "SME"; + default: return "UNKNOWN"; + } +} + static void init_kleidiai_context(void) { ggml_critical_section_start(); @@ -62,6 +73,11 @@ static void init_kleidiai_context(void) { ctx.features |= ggml_cpu_has_sme() ? 
CPU_FEATURE_SME : CPU_FEATURE_NONE; } ctx.kernels = ggml_kleidiai_select_kernels_q4_0(ctx.features); +#ifndef NDEBUG + if (ctx.kernels) { + GGML_LOG_DEBUG("kleidiai: using kernel with CPU feature %s\n", cpu_feature_to_string(ctx.kernels->required_cpu)); + } +#endif } ggml_critical_section_end(); } @@ -102,6 +118,9 @@ static void transpose_f32kxn_f16nxk(size_t n, size_t k, float * dst, const uint1 class tensor_traits : public ggml::cpu::tensor_traits { bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + if (op->op != GGML_OP_MUL_MAT) { + return false; + } ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); GGML_ASSERT(kernels); kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; @@ -135,6 +154,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { } else if (dst->src[0]->type == GGML_TYPE_F16) { return compute_forward_kv_cache(params, dst); } + } else if (dst->op == GGML_OP_GET_ROWS) { + if (dst->src[0]->type == GGML_TYPE_Q4_0) { + return compute_forward_get_rows(params, dst); + } } return false; } @@ -270,6 +293,8 @@ class tensor_traits : public ggml::cpu::tensor_traits { } bool compute_forward_q4_0(struct ggml_compute_params * params, struct ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0); + const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -342,8 +367,49 @@ class tensor_traits : public ggml::cpu::tensor_traits { return true; } + bool compute_forward_get_rows(struct ggml_compute_params * params, struct ggml_tensor * dst) { + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_Q4_0); + GGML_ASSERT(ctx.kernels); + + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + rhs_packing_info * rhs_info = &ctx.kernels->rhs_info; + kernel_info * kernel = &ctx.kernels->gemm; + + const int64_t nc = ne00; + const int64_t nr = ggml_nelements(src1); + + const size_t block_rows = kernel->get_nr(); + const size_t kr = kernel->get_kr(); + + const size_t num_bytes_multiplier = sizeof(uint16_t); + const size_t packed_stride = rhs_info->packed_stride(nc, block_rows, kr, QK4_0); + + const int ith = params->ith; + const int nth = params->nth; + + const int dr = (nr + nth - 1) / nth; + const int ir0 = dr * ith; + const int ir1 = MIN(ir0 + dr, nr); + + for (int64_t i = ir0; i < ir1; ++i) { + GGML_ASSERT(src1->type == GGML_TYPE_I32); + int64_t row_idx = ((const int32_t *)src1->data)[i]; + GGML_ASSERT(row_idx >= 0 && row_idx < src0->ne[1]); + + float *out = (float *)((char *)dst->data + i * nb1); + rhs_info->to_float(src0->data, row_idx, nc, out, block_rows, packed_stride, kr, QK4_0, num_bytes_multiplier); + } + + return true; + } + public: int repack(struct ggml_tensor * tensor, const void * data, size_t data_size) { + GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0); GGML_ASSERT(ctx.kernels); const size_t n = tensor->ne[1]; const size_t k = tensor->ne[0]; @@ -351,17 +417,12 @@ class tensor_traits : public ggml::cpu::tensor_traits { size_t kr = ctx.kernels->gemm.get_kr(); size_t sr = ctx.kernels->gemm.get_sr(); -#ifndef NDEBUG - const size_t repacked_size = variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0); - GGML_ASSERT(repacked_size <= data_size && "repacked size larger than the packed size!"); -#endif struct kai_rhs_pack_qs4cxs1s0_param params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; variant_call(ctx.kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, QK4_0, (const 
uint8_t*)data, nullptr, tensor->data, 0, ¶ms); return 0; - GGML_UNUSED(data_size); } }; @@ -375,8 +436,8 @@ static ggml::cpu::tensor_traits * get_tensor_traits(ggml_backend_buffer_t, struc static enum ggml_status ggml_backend_cpu_kleidiai_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { tensor->extra = (void *) ggml::cpu::kleidiai::get_tensor_traits(buffer, tensor); - GGML_UNUSED(buffer); return GGML_STATUS_SUCCESS; + GGML_UNUSED(buffer); } static void ggml_backend_cpu_kleidiai_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, @@ -418,18 +479,35 @@ static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alignment(ggml_backend_b GGML_UNUSED(buft); } +static size_t ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { + GGML_ASSERT(tensor->type == GGML_TYPE_Q4_0); + GGML_ASSERT(ctx.kernels); + + const size_t n = tensor->ne[1]; + const size_t k = tensor->ne[0]; + const size_t nr = ctx.kernels->gemm.get_nr(); + const size_t kr = ctx.kernels->gemm.get_kr(); + + return variant_call(ctx.kernels->rhs_info.packed_size, n, k, nr, kr, QK4_0); + + GGML_UNUSED(buft); +} + namespace ggml::cpu::kleidiai { class extra_buffer_type : ggml::cpu::extra_buffer_type { bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { - if (op->op == GGML_OP_MUL_MAT && + if ((op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) && op->src[0]->type == GGML_TYPE_Q4_0 && op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) { + if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) { + return false; + } if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { return false; } - if (op->src[1]->type == GGML_TYPE_F32 && + if ((op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_I32) && ggml_ne(op->src[1], 2) == 1 && ggml_ne(op->src[1], 3) == 1) { return true; } @@ -438,7 +516,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { } ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { - if (op->op == GGML_OP_MUL_MAT) { + if (op->op == GGML_OP_MUL_MAT || op->op == GGML_OP_GET_ROWS) { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { return (ggml::cpu::tensor_traits *) op->src[0]->extra; } @@ -469,7 +547,7 @@ ggml_backend_buffer_type_t ggml_backend_cpu_kleidiai_buffer_type(void) { /* .alloc_buffer = */ ggml_backend_cpu_kleidiai_buffer_type_alloc_buffer, /* .get_alignment = */ ggml_backend_cpu_kleidiai_buffer_type_get_alignment, /* .get_max_size = */ nullptr, // defaults to SIZE_MAX - /* .get_alloc_size = */ nullptr, // defaults to ggml_nbytes + /* .get_alloc_size = */ ggml_backend_cpu_kleidiai_buffer_type_get_alloc_size, /* .is_host = */ nullptr, }, /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 72ee93a5abc..74c1c029b94 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -14,7 +14,6 @@ #include #include #include -#include // for qsort #include // for GGML_ASSERT #include "repack.h" diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index c9ff4aa321b..98ed29bc9c1 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -102,12 +102,12 @@ if (CUDAToolkit_FOUND) if (GGML_STATIC) 
if (WIN32) # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library - target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas CUDA::cublasLt) + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas) else () - target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static) endif() else() - target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas CUDA::cublasLt) + target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas) endif() if (GGML_CUDA_NO_VMM) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 1a2708ec9df..cdc3bb5ae76 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -56,7 +56,7 @@ #define GGML_CUDA_CC_GCN4 (GGML_CUDA_CC_OFFSET_AMD + 0x803) // Tonga, Fiji, Polaris, minimum for fast fp16 #define GGML_CUDA_CC_VEGA (GGML_CUDA_CC_OFFSET_AMD + 0x900) // Vega56/64, minimum for fp16 dual issue #define GGML_CUDA_CC_VEGA20 (GGML_CUDA_CC_OFFSET_AMD + 0x906) // MI50/Radeon VII, minimum for dp4a -#define GGML_CUDA_CC_CDNA (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers +#define GGML_CUDA_CC_CDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x908) // MI100, minimum for MFMA, acc registers #define GGML_CUDA_CC_CDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x910) // MI210, minimum acc register renameing #define GGML_CUDA_CC_CDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x942) // MI300 @@ -72,8 +72,9 @@ #define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3) #define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4) #define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4) -#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA) -#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA1) +#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA1 && cc < GGML_CUDA_CC_RDNA1) +#define GGML_CUDA_CC_IS_CDNA3(cc) (cc >= GGML_CUDA_CC_CDNA3 && cc < GGML_CUDA_CC_RDNA1) // Moore Threads #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 @@ -226,6 +227,10 @@ typedef float2 dfloat2; #define FP16_MMA_AVAILABLE #endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3) +#define AMD_MFMA_AVAILABLE +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && defined(CDNA3) + #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING #define NEW_MMA_AVAILABLE #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING @@ -288,6 +293,11 @@ static bool fp32_mma_hardware_available(const int cc) { return GGML_CUDA_CC_IS_CDNA(cc); } +// AMD CDNA3 matrix cores.. Will add support for other CDNA generations later. +static bool amd_mfma_available(const int cc) { + return cc >= GGML_CUDA_CC_OFFSET_AMD && GGML_CUDA_CC_IS_CDNA3(cc); +} + // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later. 
static bool new_mma_available(const int cc) { return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING; @@ -765,7 +775,7 @@ struct ggml_tensor_extra_gpu { }; -#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) +#if (defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS)) || defined(GGML_MUSA_GRAPHS) #define USE_CUDA_GRAPH #endif diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index eeaa14bf579..15c927861f0 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -6,24 +6,33 @@ #define CUDA_Q8_0_NE_ALIGN 2048 template -static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k) { - const int64_t i = (int64_t)2*(blockDim.x*blockIdx.x + threadIdx.x); +static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t s01, const int64_t s02, const int64_t s03) { + const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x); - if (i >= k) { + if (i00 >= ne00) { return; } - const int64_t ib = i/qk; // block index - const int64_t iqs = (i%qk)/qr; // quant index - const int64_t iybs = i - i%qk; // y block start index + const int64_t i01 = blockIdx.y; + const int64_t i02 = blockIdx.z % ne02; + const int64_t i03 = blockIdx.z / ne02; + + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index const int64_t y_offset = qr == 1 ? 1 : qk/2; // dequantize dfloat2 v; dequantize_kernel(vx, ib, iqs, v); - y[iybs + iqs + 0] = v.x; - y[iybs + iqs + y_offset] = v.y; + const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = float(v.x); + y[iy0 + y_offset] = float(v.y); } template @@ -457,9 +466,17 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst } template -static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { - const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE); - dequantize_block<<>>(vx, y, k); +static void dequantize_block_cuda(const void * vx, dst_t * y, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, + const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { + const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03); + dequantize_block<<>> + (vx, y, ne00, ne01, ne02, s01, s02, s03); +} + +template +static void dequantize_block_cont_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t k, cudaStream_t stream) { + dequantize_block_cuda(vx, y, k, 1, 1, 1, k/qk, k/qk, k/qk, stream); } static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int64_t k, cudaStream_t stream) { @@ -624,14 +641,14 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { case GGML_TYPE_Q4_1: return dequantize_row_q4_1_cuda; case GGML_TYPE_Q5_0: - return dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q5_1: - return dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q8_0: if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) { return dequantize_block_q8_0_f16_cuda; } - return 
dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; case GGML_TYPE_Q3_K: @@ -676,11 +693,11 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { case GGML_TYPE_Q4_1: return dequantize_row_q4_1_cuda; case GGML_TYPE_Q5_0: - return dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q5_1: - return dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q8_0: - return dequantize_block_cuda; + return dequantize_block_cont_cuda; case GGML_TYPE_Q2_K: return dequantize_row_q2_K_cuda; case GGML_TYPE_Q3_K: @@ -722,6 +739,16 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; case GGML_TYPE_BF16: return convert_unary_cuda; default: @@ -733,6 +760,16 @@ to_bf16_nc_cuda_t ggml_get_to_bf16_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F32: return convert_unary_cuda; + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; case GGML_TYPE_F16: return convert_unary_cuda; default: @@ -744,6 +781,16 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { switch (type) { case GGML_TYPE_F16: return convert_unary_cuda; + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; case GGML_TYPE_BF16: return convert_unary_cuda; default: diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh index e7a0bd2f1a0..410c12b7ba5 100644 --- a/ggml/src/ggml-cuda/cpy-utils.cuh +++ b/ggml/src/ggml-cuda/cpy-utils.cuh @@ -2,24 +2,13 @@ #include "ggml-common.h" -static __device__ __forceinline__ void convert_f32_f32(const float * src, float * dst) { - *dst = *src; -} - -static __device__ __forceinline__ void convert_f32_f16(const float * src, half * dst) { - *dst = __float2half(*src); -} - -static __device__ __forceinline__ void convert_f32_bf16(const float * src, nv_bfloat16 * dst) { - *dst = *src; -} - -static __device__ __forceinline__ void convert_f16_f16(const half * src, half * dst) { - *dst = *src; -} - -static __device__ __forceinline__ void convert_f16_f32(const half * src, float * dst) { - *dst = *src; +template +static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) { + if constexpr (std::is_same_v) { + *dst = *src; + } else { + *dst = float(*src); + } } static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) { @@ -230,22 +219,7 @@ static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) { quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti); } -static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) { - convert_f32_f32((const float *)cxi, (float *)cdsti); -} - -static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) { - convert_f32_f16((const float *)cxi, (half *)cdsti); -} - 
-static __device__ void cpy_1_f32_bf16(const char * cxi, char * cdsti) { - convert_f32_bf16((const float *)cxi, (nv_bfloat16 *)cdsti); -} - -static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) { - convert_f16_f16((const half *)cxi, (half *)cdsti); -} - -static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) { - convert_f16_f32((const half *)cxi, (float *)cdsti); +template +static __device__ void cpy_1_flt(const char * cxi, char * cdsti) { + convert_flt((const src_t *)cxi, (dst_t *)cdsti); } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index e7d0da08705..f9bb025643c 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -1,17 +1,17 @@ #include "cpy.cuh" #include "dequantize.cuh" #include "cpy-utils.cuh" -#ifdef GGML_USE_MUSA +#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) #include "ggml-musa/mudnn.cuh" -#endif // GGML_USE_MUSA +#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY typedef void (*cpy_kernel_t)(const char * cx, char * cdst); template -static __global__ void cpy_f32_f16(const char * cx, char * cdst_direct, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { +static __global__ void cpy_flt(const char * cx, char * cdst_direct, const int ne, + const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, + const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, + const int nb12, const int nb13, char ** cdst_indirect, int graph_cpynode_index) { const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { @@ -121,7 +121,7 @@ static __global__ void cpy_q_f32(const char * cx, char * cdst_direct, const int // Copy destination pointers to GPU to be available when pointer indirection is in use void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream) { -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) if (cuda_graph->dest_ptrs_size < host_dest_ptrs_size) { // (re-)allocate GPU memory for destination pointers CUDA_CHECK(cudaStreamSynchronize(stream)); if (cuda_graph->dest_ptrs_d != nullptr) { @@ -139,43 +139,14 @@ void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_des #endif } -static void ggml_cpy_f16_f32_cuda( +template +static void ggml_cpy_flt_cuda( const char * cx, char * cdst, const int ne, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); -} - -static void ggml_cpy_f32_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, 
const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); -} - -static void ggml_cpy_f32_bf16_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); -} - -static void ggml_cpy_f32_f16_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> + cpy_flt><<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } @@ -307,16 +278,6 @@ static void ggml_cpy_f32_iq4_nl_cuda( (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); } -static void ggml_cpy_f16_f16_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream, char ** cdst_indirect, int & graph_cpynode_index) { - - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; - cpy_f32_f16<<>> - (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++); -} - void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -353,7 +314,7 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg char ** dest_ptrs_d = nullptr; int graph_cpynode_index = -1; -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { dest_ptrs_d = ctx.cuda_graph->dest_ptrs_d; graph_cpynode_index = ctx.cuda_graph->graph_cpynode_index; @@ -363,20 +324,20 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg #endif if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1)); -#ifdef GGML_USE_MUSA +#if defined(GGML_USE_MUSA) && defined(GGML_MUSA_MUDNN_COPY) if (src0->type == GGML_TYPE_F32 || 
src0->type == GGML_TYPE_F16) { CUDA_CHECK(mudnnMemcpyAsync(ctx, src1, src0)); } else -#endif // GGML_USE_MUSA +#endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY { CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - ggml_cpy_f32_bf16_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { @@ -403,14 +364,22 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { ggml_cpy_q5_1_f32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, 
nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); } -#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) +#if defined(GGML_CUDA_USE_GRAPHS) || defined(GGML_HIP_GRAPHS) || defined(GGML_MUSA_GRAPHS) if(ctx.cuda_graph->use_cpy_indirection && !disable_indirection_for_this_node) { ctx.cuda_graph->graph_cpynode_index = graph_cpynode_index; } @@ -430,11 +399,11 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { return nullptr; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_f32_f16; + return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { - return (void*) cpy_f32_f16; + return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_f32_f16; + return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) { return (void*) cpy_f32_q; } else if (src0->type == GGML_TYPE_Q8_0 && src1->type == GGML_TYPE_F32) { @@ -458,9 +427,17 @@ void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { } else if (src0->type == GGML_TYPE_Q5_1 && src1->type == GGML_TYPE_F32) { return (void*) cpy_q_f32, QK5_1>; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { - return (void*) cpy_f32_f16; + return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_BF16) { + return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { - return (void*) cpy_f32_f16; + return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { + return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) { + return (void*) cpy_flt>; + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { + return (void*) cpy_flt>; } else { GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 9122fca6cf9..95e704e393c 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -23,33 +23,13 @@ typedef void (* fattn_kernel_t)( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int 
ne1, - const int ne2, - const int ne3); + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33); typedef half (*vec_dot_KQ_f16_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); @@ -745,33 +725,58 @@ void launch_fattn( size_t nb23 = V ? V->nb[3] : nb13; if (need_f16_K && K->type != GGML_TYPE_F16) { - GGML_ASSERT(ggml_is_contiguously_allocated(K)); - K_f16.alloc(ggml_nelements(K)); - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); - to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); - K_data = (char *) K_f16.ptr; - const size_t bs = ggml_blck_size(K->type); const size_t ts = ggml_type_size(K->type); - nb11 = nb11*bs*sizeof(half)/ts; - nb12 = nb12*bs*sizeof(half)/ts; - nb13 = nb13*bs*sizeof(half)/ts; + K_f16.alloc(ggml_nelements(K)); + if (ggml_is_contiguously_allocated(K)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(K->type); + to_fp16(K_data, K_f16.ptr, ggml_nelements(K), main_stream); + + nb11 = nb11*bs*sizeof(half)/ts; + nb12 = nb12*bs*sizeof(half)/ts; + nb13 = nb13*bs*sizeof(half)/ts; + } else { + GGML_ASSERT(K->nb[0] == ts); + to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(K->type); + const int64_t s01 = nb11 / ts; + const int64_t s02 = nb12 / ts; + const int64_t s03 = nb13 / ts; + to_fp16(K_data, K_f16.ptr, K->ne[0], K->ne[1], K->ne[2], K->ne[3], s01, s02, s03, main_stream); + + nb11 = K->ne[0] * sizeof(half); + nb12 = K->ne[1] * nb11; + nb13 = K->ne[2] * nb12; + } + K_data = (char *) K_f16.ptr; } if (V && need_f16_V && V->type != GGML_TYPE_F16) { - GGML_ASSERT(ggml_is_contiguously_allocated(V)); - V_f16.alloc(ggml_nelements(V)); - to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); - to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); - V_data = (char *) V_f16.ptr; - const size_t bs = ggml_blck_size(V->type); const size_t ts = ggml_type_size(V->type); - nb21 = nb21*bs*sizeof(half)/ts; - nb22 = nb22*bs*sizeof(half)/ts; - nb23 = nb23*bs*sizeof(half)/ts; + V_f16.alloc(ggml_nelements(V)); + if (ggml_is_contiguously_allocated(V)) { + to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(V->type); + to_fp16(V_data, V_f16.ptr, ggml_nelements(V), main_stream); + V_data = (char *) V_f16.ptr; + + nb21 = nb21*bs*sizeof(half)/ts; + nb22 = nb22*bs*sizeof(half)/ts; + nb23 = nb23*bs*sizeof(half)/ts; + } else { + GGML_ASSERT(V->nb[0] == ts); + to_fp16_nc_cuda_t to_fp16 = ggml_get_to_fp16_nc_cuda(V->type); + const int64_t s01 = nb21 / ts; + const int64_t s02 = nb22 / ts; + const int64_t s03 = nb23 / ts; + to_fp16(V_data, V_f16.ptr, V->ne[0], V->ne[1], V->ne[2], V->ne[3], s01, s02, s03, main_stream); + + nb21 = V->ne[0] * sizeof(half); + nb22 = V->ne[1] * nb21; + nb23 = V->ne[2] * nb22; + } + V_data = (char *) V_f16.ptr; } int parallel_blocks = 1; @@ -867,14 +872,11 @@ void launch_fattn( mask ? ((const char *) mask->data) : nullptr, !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, - Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], - mask ? 
mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, - mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0, - Q->nb[1], Q->nb[2], Q->nb[3], - nb11, nb12, nb13, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, - KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, + mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0 ); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 6fa2e77299e..83cf872f68a 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -408,7 +408,6 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int stride_K, const int stride_V, const int stride_mask, - const int jt, half2 * const __restrict__ tile_Q, half2 * const __restrict__ tile_K, half2 * const __restrict__ tile_V, @@ -455,7 +454,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( cp_async_wait_all(); __syncthreads(); flash_attn_ext_f16_load_tile - (V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V); + (V_h2 + int64_t(k_VKQ_0)*stride_V, tile_V, nbatch_V2, stride_V); } else { constexpr bool use_cp_async = nstages == 1; if (ncols2 > 1 || mask_h2) { @@ -471,7 +470,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (nstages <= 1) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile - (K_h2 + k_VKQ_0*stride_K + k0_start, tile_K, k0_diff, stride_K); + (K_h2 + int64_t(k_VKQ_0)*stride_K + k0_start, tile_K, k0_diff, stride_K); if (use_cp_async) { cp_async_wait_all(); } @@ -715,7 +714,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( (mask_h2 + (k_VKQ_0 + c::nbatch_fa)/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + (k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); + (K_h2 + int64_t(k_VKQ_0 + c::nbatch_fa)*stride_K, tile_K, nbatch_K2, stride_K); } } @@ -732,7 +731,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( if (nstages <= 1 && i0_start < reusable_cutoff) { constexpr bool use_cp_async = nstages == 1; flash_attn_ext_f16_load_tile - (V_h2 + k_VKQ_0*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); + (V_h2 + int64_t(k_VKQ_0)*stride_V + i0_start/2, tile_V, i0_diff/2, stride_V); if (use_cp_async) { cp_async_wait_all(); } @@ -771,8 +770,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup); GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_K); GGML_UNUSED(stride_V); - GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); - GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K); + GGML_UNUSED(stride_mask); GGML_UNUSED(tile_K); GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B); GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum); GGML_UNUSED(kb0); GGML_UNUSED(tile_Q); @@ -920,7 +918,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( (mask_h2 + kb0_start*c::nbatch_fa/2, tile_mask, stride_mask); } flash_attn_ext_f16_load_tile - (K_h2 + kb0_start*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); + (K_h2 + int64_t(kb0_start)*c::nbatch_fa*stride_K, tile_K, nbatch_K2, stride_K); } // Iterate over ne11 == previous tokens: @@ 
-928,13 +926,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( constexpr bool last_iter = false; flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); } { // kb0_start is always < kb0_stop so the last iter can be executed unconditionally. constexpr bool last_iter = true; flash_attn_ext_f16_iter (Q_f2, K_h2, V_h2, mask_h2, dstk, dstk_fixup, scale, slope, logit_softcap, - ne01, ne02, stride_K, stride_V, stride_mask, jt, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); + ne01, ne02, stride_K, stride_V, stride_mask, tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0_stop-1); } // With multi-stage loading there is no __syncthreads at the end of the iter, @@ -1214,33 +1212,13 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) // Skip unused kernel variants for faster compilation: @@ -1352,15 +1330,16 @@ static __global__ void flash_attn_ext_f16( ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel); #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); - GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); - GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10); - GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); - GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); - GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + 
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-tile-f16.cu b/ggml/src/ggml-cuda/fattn-tile-f16.cu index 1f141328845..7661c21efbb 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f16.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f16.cu @@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) // Skip unused kernel variants for faster compilation: @@ -127,7 +107,7 @@ static __global__ void flash_attn_tile_ext_f16( for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) { const int k_KQ = k_KQ_0 + threadIdx.x; - KV_tmp[i_KQ][k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; + KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ]; } } @@ -221,7 +201,7 @@ static __global__ void flash_attn_tile_ext_f16( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - KV_tmp[k][i] = V_h2[(k_VKQ_0 + k)*stride_KV2 + i]; + KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; } } @@ -300,8 +280,7 @@ static __global__ void flash_attn_tile_ext_f16( GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(nb23); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-tile-f32.cu b/ggml/src/ggml-cuda/fattn-tile-f32.cu index a4965583cef..11778bb9611 100644 --- a/ggml/src/ggml-cuda/fattn-tile-f32.cu +++ b/ggml/src/ggml-cuda/fattn-tile-f32.cu @@ -21,33 +21,13 @@ static __global__ void flash_attn_tile_ext_f32( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const 
int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: @@ -57,17 +37,16 @@ static __global__ void flash_attn_tile_ext_f32( #endif // FP16_MMA_AVAILABLE if (use_logit_softcap && !(D == 128 || D == 256)) { GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); NO_DEVICE_CODE; return; } @@ -135,7 +114,7 @@ static __global__ void flash_attn_tile_ext_f32( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) { - const half2 tmp = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; + const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x]; KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp); KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp); } @@ -231,8 +210,9 @@ static __global__ void flash_attn_tile_ext_f32( for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) { const int i = i0 + threadIdx.x; - KV_tmp2[k*(D/2) + i].x = __low2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); - KV_tmp2[k*(D/2) + i].y = __high2float(V_h2[(k_VKQ_0 + k)*stride_KV2 + i]); + const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i]; + KV_tmp2[k*(D/2) + i].x = __low2float(tmp); + KV_tmp2[k*(D/2) + i].y = __high2float(tmp); } } @@ -302,17 +282,16 @@ static __global__ void flash_attn_tile_ext_f32( } #else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); - 
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); - GGML_UNUSED(ne31); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-vec-f16.cuh b/ggml/src/ggml-cuda/fattn-vec-f16.cuh index b2d469938ab..e9b5c306365 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f16.cuh @@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) // Skip unused kernel variants for faster compilation: @@ -191,13 +171,16 @@ static __global__ void flash_attn_vec_ext_f16( half2 VKQ[ncols] = {{0.0f, 0.0f}}; + K += blockIdx.y*D * nb11; + V += blockIdx.y*D * nb21; + maskh += blockIdx.y*D; for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { // Calculate KQ tile and keep track of new maximum KQ values: if (mask) { #pragma unroll for (int j = 0; j < ncols; ++j) { - maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + k_VKQ_0 + tid]; + maskh_shared[j*D + tid] = slopeh*maskh[j*ne11 + tid]; } __syncthreads(); @@ -244,7 +227,7 @@ static __global__ void flash_attn_vec_ext_f16( #pragma unroll for (int j = 0; j < ncols; ++j) { - half sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); + half sum = vec_dot_KQ(K + i_KQ*nb11, Q_h2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum((float)sum); if (use_logit_softcap) { @@ -300,14 +283,18 @@ static __global__ void flash_attn_vec_ext_f16( } half2 V_k; - reinterpret_cast(V_k.x) = dequantize_1_v(V + (k_VKQ_0 + k0 + 0)*nb21, tid); - reinterpret_cast(V_k.y) = dequantize_1_v(V + (k_VKQ_0 + k0 + 1)*nb21, tid); + reinterpret_cast(V_k.x) = dequantize_1_v(V + (k0 + 0)*nb21, tid); + reinterpret_cast(V_k.y) = dequantize_1_v(V + (k0 + 1)*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_k*KQ2[j*(D/2) + k0/2]; } } + K += gridDim.y*D * nb11; + V += gridDim.y*D * nb21; + maskh += gridDim.y*D; + __syncthreads(); } @@ -342,17 +329,16 @@ static __global__ void flash_attn_vec_ext_f16( } 
#else GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); - GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale); - GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); + GGML_UNUSED(dst); GGML_UNUSED(dst_meta); + GGML_UNUSED(scale); GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1); GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); - GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); - GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11); - GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32); - GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); - GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); - GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); + GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); + GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); + GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); + GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); + GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); + GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE) } diff --git a/ggml/src/ggml-cuda/fattn-vec-f32.cuh b/ggml/src/ggml-cuda/fattn-vec-f32.cuh index 405b6f5106e..6a4bdc0ff9a 100644 --- a/ggml/src/ggml-cuda/fattn-vec-f32.cuh +++ b/ggml/src/ggml-cuda/fattn-vec-f32.cuh @@ -18,33 +18,13 @@ static __global__ void flash_attn_vec_ext_f32( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: @@ -59,8 +39,7 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); - GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1); - GGML_UNUSED(ne2); GGML_UNUSED(ne3); + GGML_UNUSED(nb23); NO_DEVICE_CODE; return; } @@ -198,13 +177,16 @@ static __global__ void flash_attn_vec_ext_f32( float VKQ[ncols] = {0.0f}; + K += blockIdx.y*D * nb11; + V += blockIdx.y*D * nb21; + maskh += blockIdx.y*D; for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) { // Calculate KQ tile and keep track of new maximum KQ 
values: if (mask) { #pragma unroll for (int j = 0; j < ncols; ++j) { - maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + k_VKQ_0 + tid]); + maskf_shared[j*D + tid] = slope*__half2float(maskh[j*ne11 + tid]); } __syncthreads(); @@ -246,7 +228,7 @@ static __global__ void flash_attn_vec_ext_f32( #pragma unroll for (int j = 0; j < ncols; ++j) { - float sum = vec_dot_KQ(K + (k_VKQ_0 + i_KQ)*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); + float sum = vec_dot_KQ(K + i_KQ*nb11, Q_f2[j], Q_i32[j], Q_ds[j]); sum = warp_reduce_sum(sum); if (use_logit_softcap) { @@ -297,13 +279,17 @@ static __global__ void flash_attn_vec_ext_f32( break; } - const float V_ki = dequantize_1_v(V + (k_VKQ_0 + k)*nb21, tid); + const float V_ki = dequantize_1_v(V + k*nb21, tid); #pragma unroll for (int j = 0; j < ncols; ++j) { VKQ[j] += V_ki*KQ[j*D + k]; } } + K += gridDim.y*D * nb11; + V += gridDim.y*D * nb21; + maskh += gridDim.y*D; + __syncthreads(); } @@ -348,7 +334,6 @@ static __global__ void flash_attn_vec_ext_f32( GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // FLASH_ATTN_AVAILABLE } diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 741b8781d29..c9b083bed01 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -37,33 +37,13 @@ static __global__ void flash_attn_ext_f16( const float m1, const uint32_t n_head_log2, const float logit_softcap, - const int ne00, - const int ne01, - const int ne02, - const int ne03, - const int ne10, - const int ne11, - const int ne12, - const int ne13, - const int ne31, - const int ne32, - const int ne33, - const int nb31, - const int nb32, - const int nb33, - const int nb01, - const int nb02, - const int nb03, - const int nb11, - const int nb12, - const int nb13, - const int nb21, - const int nb22, - const int nb23, - const int ne0, - const int ne1, - const int ne2, - const int ne3) { + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { #if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { @@ -197,7 +177,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; - wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); @@ -344,7 +324,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + 
frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); @@ -451,7 +431,6 @@ static __global__ void flash_attn_ext_f16( GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23); - GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3); NO_DEVICE_CODE; #endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 6bc0096cc65..d9f1613051d 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -280,22 +280,12 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size; const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV); - if (GGML_CUDA_CC_IS_AMD(cc)) { #if defined(GGML_HIP_ROCWMMA_FATTN) - if (fp16_mma_available(cc)) { - ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); - return; - } -#endif // defined(GGML_HIP_ROCWMMA_FATTN) - - // On AMD the tile kernels perform poorly, use the vec kernel instead: - if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { - ggml_cuda_flash_attn_ext_vec_f16(ctx, dst); - } else { - ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); - } + if (GGML_CUDA_CC_IS_AMD(cc) && fp16_mma_available(cc)) { + ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst); return; } +#endif // defined(GGML_HIP_ROCWMMA_FATTN) if (!fast_fp16_available(cc)) { if (Q->ne[1] <= 8 || Q->ne[0] == 256) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index dfc50ef0daf..03c380897cd 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -55,6 +55,7 @@ #include #include #include +#include #include #include #include @@ -2765,6 +2766,39 @@ static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { } #endif +static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + + //rms norm only supports F32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + + //if rms norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { + return false; + } + + //rms_norm kernel assumes contigous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } + + return true; +} + static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { 
// flag used to determine whether it is an integrated_gpu @@ -2774,6 +2808,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch. if (!use_cuda_graph || cuda_graph_update_required) { + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2781,6 +2816,12 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); + if (!disable_fusion && ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } #ifndef NDEBUG assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device)); for (int j = 0; j < GGML_MAX_SRC; j++) { @@ -3242,13 +3283,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g { ggml_type src0_type = op->src[0]->type; ggml_type src1_type = op->src[1]->type; - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { - return true; - } - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) { - return true; - } - if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + if ((src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_BF16 || src0_type == GGML_TYPE_F16) && + (src1_type == GGML_TYPE_F32 || src1_type == GGML_TYPE_BF16 || src1_type == GGML_TYPE_F16) + ) { return true; } if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) { @@ -3284,12 +3321,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { return true; } - if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { - return true; - } - if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { - return true; - } if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) { return true; } @@ -3370,7 +3401,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return op->src[0]->ne[1] % 128 == 0; } case GGML_OP_CONT: - return op->src[0]->type != GGML_TYPE_BF16; + return true; case GGML_OP_DIAG_MASK_INF: return true; case GGML_OP_SOFT_MAX: diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu index 86a54e42bb7..5bb85b4807b 100644 --- a/ggml/src/ggml-cuda/im2col.cu +++ b/ggml/src/ggml-cuda/im2col.cu @@ -10,7 +10,7 @@ static __global__ void im2col_kernel( return; } - const int64_t ksize = OW * (KH > 1 ? KW : 1); + const int64_t ksize = OW * KH; const int64_t kx = i / ksize; const int64_t kd = kx * ksize; const int64_t ky = (i - kd) / OW; diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index 2af63355a19..d6817d804d2 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -12,7 +12,8 @@ // The methods get_i and get_j can be used to get the physical 32 bit index of the lth element of a thread within a tile. // All matrix tiles have ne physical 32 bit elements per warp. // -// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. +// As described in the PTX documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes. 
+// The API in this file also assumes that the pointers for load_generic are aligned to 16 bytes, unaligned pointers are considered undefined behavior. #include "common.cuh" @@ -66,7 +67,44 @@ namespace ggml_cuda_mma { struct tile { static constexpr int I = I_; static constexpr int J = J_; - static constexpr int ne = I * J / WARP_SIZE; + +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + static constexpr int ne = I * J / 64; + T x[ne] = {0}; + + static __device__ __forceinline__ int get_i(const int l) { + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + return threadIdx.x % 16; + } else if constexpr (I == 16 && J == 8) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 4) { + return threadIdx.x % 32; + } else if constexpr (I == 16 && J == 16) { + return 4 * (threadIdx.x / 16) + l; + } else if constexpr (I == 32 && J == 32) { + return 4 * (threadIdx.x / 32) + 8 * (l / 4) + (l % 4); + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } + + static __device__ __forceinline__ int get_j(const int l) { + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> + return (2 * ((threadIdx.x / 16) % 2) + l); + } else if constexpr (I == 16 && J == 8) { + return 2 * (threadIdx.x / 16) + l; + } else if constexpr (I == 32 && J == 4) { + return 2 * (threadIdx.x / 32) + l; + } else if constexpr (I == 16 && J == 16) { + return threadIdx.x % 16; + } else if constexpr (I == 32 && J == 32) { + return threadIdx.x % 32; + } else { + static_assert(I == -1 && J == -1, "template specialization not implemented"); + } + } +#else + static constexpr int ne = I * J / 32; T x[ne] = {0}; static __device__ __forceinline__ int get_i(const int l) { @@ -94,6 +132,7 @@ namespace ggml_cuda_mma { static_assert(I == -1 && J == -1, "template specialization not implemented"); } } +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) }; template @@ -148,10 +187,23 @@ namespace ggml_cuda_mma { template static __device__ __forceinline__ void load_generic(tile & t, const T * __restrict__ xs0, const int stride) { +#if defined(AMD_MFMA_AVAILABLE) + if constexpr (I == 64 && J == 2) { // Special tile size to load <16, 4> as <16, 8> +#pragma unroll + for (int l = 0; l < t.ne; ++l) { + t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; + } + } else { + int64_t * xi = (int64_t *) t.x; + const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I)); + xi[0] = xs[0]; + } +#else #pragma unroll for (int l = 0; l < t.ne; ++l) { t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)]; } +#endif // defined(AMD_MFMA_AVAILABLE) } template @@ -186,7 +238,7 @@ namespace ggml_cuda_mma { template static __device__ __forceinline__ void load_ldmatrix( tile<16, 8, T> & t, const T * __restrict__ xs0, const int stride) { -#ifdef NEW_MMA_AVAILABLE +#if defined(NEW_MMA_AVAILABLE) int * xi = (int * ) t.x; const int * xs = (const int *) xs0 + (threadIdx.x % t.I) * stride + (threadIdx.x / t.I) * (t.J / 2); asm volatile("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];" @@ -393,4 +445,60 @@ namespace ggml_cuda_mma { NO_DEVICE_CODE; #endif // NEW_MMA_AVAILABLE } + + static __device__ __forceinline__ void mma( + tile<16, 16, int> & D, const tile<16, 8, int> & A, const tile<16, 8, int> & B) { +#if defined(AMD_MFMA_AVAILABLE) + using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int; + int32x4_t * acc = (int32x4_t *) D.x; +#if defined(CDNA3) + acc[0] = 
__builtin_amdgcn_mfma_i32_16x16x32_i8(((int64_t *) A.x)[0], + ((int64_t *) B.x)[0], + acc[0], + 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA) + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[0], + B.x[0], + acc[0], + 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_16x16x16i8(A.x[1], + B.x[1], + acc[0], + 0, 0, 0); +#endif // defined(CDNA3) +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } + + static __device__ __forceinline__ void mma( + tile<32, 32, int> & D, const tile<32, 4, int> & A, const tile<32, 4, int> & B) { +#if defined(AMD_MFMA_AVAILABLE) + using int32x16_t = __attribute__((__vector_size__(16 * sizeof(int)))) int; + int32x16_t * acc = (int32x16_t *) D.x; +#if defined(CDNA3) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x16_i8(((int64_t *) A.x)[0], + ((int64_t *) B.x)[0], + acc[0], + 0, 0, 0); +#elif defined(CDNA2) || defined(CDNA) + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[0], + B.x[0], + acc[0], + 0, 0, 0); + acc[0] = __builtin_amdgcn_mfma_i32_32x32x8i8(A.x[1], + B.x[1], + acc[0], + 0, 0, 0); +#endif // defined(CDNA3) +#else + GGML_UNUSED(D); + GGML_UNUSED(A); + GGML_UNUSED(B); + NO_DEVICE_CODE; +#endif // AMD_MFMA_AVAILABLE + } } diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 2db5b4ab0f0..e2fd0c1c254 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -109,7 +109,8 @@ void ggml_cuda_mul_mat_q( const int64_t s03 = src0->nb[3] / ts_src0; const int64_t s3 = dst->nb[3] / ts_dst; - const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA; + const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc))); if (!ids) { const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 + @@ -250,8 +251,9 @@ void ggml_cuda_op_mul_mat_q( // The stream-k decomposition is only faster for recent NVIDIA GPUs. // Also its fixup needs to allocate a temporary buffer in the memory pool. // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer. - const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && - ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11; + const bool use_stream_k = ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) + || (GGML_CUDA_CC_IS_AMD(cc) && GGML_CUDA_CC_IS_CDNA3(cc))) + && src1_ncols == ne11; const mmq_args args = { src0_dd_i, src0->type, (const int *) src1_ddq_i, nullptr, nullptr, dst_dd_i, ne00, row_diff, src1_ncols, stride01, ne11, nrows_dst, @@ -304,7 +306,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { return false; } - if (new_mma_available(cc)) { + if (new_mma_available(cc) || amd_mfma_available(cc)) { return true; } diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index 9696a320462..36e84be154e 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -90,7 +90,7 @@ struct tile_x_sizes { }; static int get_mmq_x_max_host(const int cc) { - return new_mma_available(cc) ? 128 : + return (amd_mfma_available(cc) || new_mma_available(cc)) ? 128 : GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 
#ifdef GGML_CUDA_FORCE_MMQ 128 : 64; @@ -100,12 +100,12 @@ static int get_mmq_x_max_host(const int cc) { } static constexpr __device__ int get_mmq_x_max_device() { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) return 128; -#else // NEW_MMA_AVAILABLE +#else // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) - return 128; + return 64; #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA @@ -115,12 +115,11 @@ static constexpr __device__ int get_mmq_x_max_device() { return MMQ_DP4A_MAX_BATCH_SIZE; #endif // GGML_CUDA_FORCE_MMQ #else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA - return 64; #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } static int get_mmq_y_host(const int cc) { @@ -144,16 +143,25 @@ static constexpr __device__ int get_mmq_y_device() { #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) } -#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_0 + mmq_y/QI4_0, 0} -#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_1 + mmq_y/QI4_1, 0} -#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_0 + mmq_y/(QI8_0/2), 0} -#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*4/QI8_0 + mmq_y/(QI8_0/4), 0} -#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE*2/QI8_1 + mmq_y/(QI8_1/2), 0} -#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE + mmq_y, 0} -#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*WARP_SIZE + mmq_y, mmq_y*WARP_SIZE/QI4_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI5_K + mmq_y/QI5_K, mmq_y*WARP_SIZE/8 + mmq_y/8} -#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*WARP_SIZE*2 + mmq_y, mmq_y*WARP_SIZE/QI6_K + mmq_y/QI6_K, mmq_y*WARP_SIZE/8 + mmq_y/8} +// Decouple shared memory tile sizes from WARP_SIZE to allow for different warp sizes. +// The K dimension of the tiles has either, +// 1*MMQ_TILE_NE_K==32 (always for TILE_Y_K) or 2*MMQ_TILE_NE_K==64 (typically for TILE_X_K), +// 32 bit elements for the quantized data (does not include scales). +// In other words, the size of the quantized data in the K dimension is a multiple of MMQ_TILE_NE_K. +// The final tile size in K direction is padded to avoid shared memory bank conflicts, +// in terms of 32 bit elements that means K % 2 == 1 for dp4a or K % 8 == 4 for mma. 
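The bank-conflict padding rule stated in the comment above can be checked against the tile strides defined next; the snippet below is a standalone illustration (constant names are local to the example, not ggml symbols) of why the mma strides land on K % 8 == 4 and the dp4a strides on K % 2 == 1:

// Standalone illustration of the padding rule; compiles on its own, names are example-only.
constexpr int example_tile_ne_k = 32;   // MMQ_TILE_NE_K
constexpr int example_qi8_0     = 8;    // 32-bit ints per q8_0 block (32 int8 values / 4)

// mma layout: 2*MMQ_TILE_NE_K quantized ints + 2*MMQ_TILE_NE_K/QI8_0 scale ints + 4 ints of padding.
constexpr int example_mma_stride  = 2*example_tile_ne_k + 2*example_tile_ne_k/example_qi8_0 + 4;  // = 76
static_assert(example_mma_stride % 8 == 4, "mma row stride is padded to K % 8 == 4");

// dp4a layout: MMQ_TILE_NE_K quantized ints per row plus one padding int per row.
constexpr int example_dp4a_stride = example_tile_ne_k + 1;                                        // = 33
static_assert(example_dp4a_stride % 2 == 1, "dp4a row stride is padded to K % 2 == 1");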
+#define MMQ_TILE_NE_K 32 + +#define MMQ_DP4A_TXS_Q4_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_0 + mmq_y/QI4_0, 0} +#define MMQ_DP4A_TXS_Q4_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_1 + mmq_y/QI4_1, 0} +#define MMQ_DP4A_TXS_Q8_0 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_0 + mmq_y/(QI8_0/2), 0} +#define MMQ_DP4A_TXS_Q8_0_16 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*4/QI8_0 + mmq_y/(QI8_0/4), 0} +#define MMQ_DP4A_TXS_Q8_1 tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K*2/QI8_1 + mmq_y/(QI8_1/2), 0} +#define MMQ_DP4A_TXS_Q2_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K + mmq_y, 0} +#define MMQ_DP4A_TXS_Q3_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q4_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K + mmq_y, mmq_y*MMQ_TILE_NE_K/QI4_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q5_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI5_K + mmq_y/QI5_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} +#define MMQ_DP4A_TXS_Q6_K tile_x_sizes{mmq_y*MMQ_TILE_NE_K*2 + mmq_y, mmq_y*MMQ_TILE_NE_K/QI6_K + mmq_y/QI6_K, mmq_y*MMQ_TILE_NE_K/8 + mmq_y/8} static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) { switch (type) { @@ -179,11 +187,11 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml } } -#define MMQ_MMA_TILE_X_K_Q8_0 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q8_1 (2*WARP_SIZE + 2*WARP_SIZE/QI8_0 + 4) -#define MMQ_MMA_TILE_X_K_Q2_K (2*WARP_SIZE + WARP_SIZE + 4) -#define MMQ_MMA_TILE_X_K_Q3_K (2*WARP_SIZE + WARP_SIZE/2 + 4) -#define MMQ_MMA_TILE_X_K_Q6_K (2*WARP_SIZE + WARP_SIZE/QI6_K + WARP_SIZE/8 + 7) +#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0 + 4) +#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K + 4) +#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2 + 4) +#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K + MMQ_TILE_NE_K/8 + 7) static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding."); static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding."); @@ -215,42 +223,80 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) { } } -#define MMQ_TILE_Y_K (WARP_SIZE + WARP_SIZE/QI8_1) +// block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales) +#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1) static int mmq_get_granularity_host(const int mmq_x, const int cc) { - return new_mma_available(cc) && mmq_x >= 48 ? 16 : 8; + if (amd_mfma_available(cc)) { + return mmq_x >= 128 ? 32 : 16; + } else if (new_mma_available(cc) && mmq_x >= 48) { + return 16; + } else { + return 8; + } } -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) +static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { + return mmq_x >= 128 ? 32 : 16; +} +#elif defined(NEW_MMA_AVAILABLE) static constexpr __device__ int mmq_get_granularity_device(const int mmq_x) { return mmq_x >= 48 ? 
16 : 8; } #else -static constexpr __device__ int mmq_get_granularity_device(const int /* mmq_x */) { +static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/) { return 8; } -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE + +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +static int mmq_get_nwarps_host(const int cc) { + return amd_mfma_available(cc) ? 8 : 4; +} +#else +static int mmq_get_nwarps_host(const int /*cc*/) { + return 8; +} +#endif // (GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) + +static constexpr __device__ int mmq_get_nwarps_device() { +#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +#if defined(AMD_MFMA_AVAILABLE) + return 8; +#else + return 4; +#endif // AMD_MFMA_AVAILABLE +#else + return 8; +#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) +} // ------------------------------------------------------------ -template static __device__ __forceinline__ void load_tiles_q4_0( +template static __device__ __forceinline__ void load_tiles_q4_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + 2*WARP_SIZE); + float * x_df = (float *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI4_0; - const int kqsx = threadIdx.x % QI4_0; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_0; + const int kqsx = txi % QI4_0; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -259,20 +305,21 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b2(bxi->qs, kqsx); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + 0] = __vsubss4((qs0 >> 0) & 0x0F0F0F0F, 0x08080808); x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI4_0) + kqsx + QI4_0] = __vsubss4((qs0 >> 4) & 0x0F0F0F0F, 0x08080808); #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_0) { - int i = i0 + threadIdx.y * QI4_0 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -280,17 +327,19 @@ template static __device__ __forceinlin const block_q4_0 * bxi = (const block_q4_0 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else - x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_0, mmq_y); const int * x_qs = (const int *) x; @@ -299,7 +348,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_0*VDR_Q4_0_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -307,7 +356,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); @@ -320,32 +369,37 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_dp4a( u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_0)]; } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_0_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_0], u, - x_df[i*(WARP_SIZE/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_0_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_0], u, + x_df[i*(MMQ_TILE_NE_K/QI4_0) + i/QI4_0 + k0/(QR4_0*QI4_0)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template 
static __device__ __forceinline__ void load_tiles_q4_1( +template static __device__ __forceinline__ void load_tiles_q4_1( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI4_1; - const int kqsx = threadIdx.x % QI4_1; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_1; + const int kqsx = txi % QI4_1; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -354,20 +408,21 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbx; const int qs0 = get_int_b4(bxi->qs, kqsx); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + 0] = (qs0 >> 0) & 0x0F0F0F0F; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI4_1) + kqsx + QI4_1] = (qs0 >> 4) & 0x0F0F0F0F; #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_1) { - int i = i0 + threadIdx.y * QI4_1 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -375,17 +430,19 @@ template static __device__ __forceinlin const block_q4_1 * bxi = (const block_q4_1 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else - x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; -#endif // NEW_MMA_AVAILABLE + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = 
mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_1, mmq_y); const int * x_qs = (const int *) x; @@ -394,7 +451,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_1*VDR_Q4_1_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -402,7 +459,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; const int kyqs = QI8_1 * ((k01/2) / (QI8_1/2)) + (k01/2) % (QI8_1/2); @@ -415,32 +472,37 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_dp4a( u[2*l+1] = y_qs[j*MMQ_TILE_Y_K + kyqs + (l + QI4_1)]; } - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_1_q8_1_impl - (&x_qs[i*(WARP_SIZE + 1) + k0/QR4_1], u, - x_dm[i*(WARP_SIZE/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_1_q8_1_impl + (&x_qs[i*(MMQ_TILE_NE_K + 1) + k0/QR4_1], u, + x_dm[i*(MMQ_TILE_NE_K/QI4_1) + i/QI4_1 + k0/(QR4_1*QI4_1)], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template static __device__ __forceinline__ void load_tiles_q5_0( +template static __device__ __forceinline__ void load_tiles_q5_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI5_0; - const int kqsx = threadIdx.x % QI5_0; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_0); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_0; + const int kqsx = txi % QI5_0; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -449,7 +511,7 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbx; const int ql = get_int_b2(bxi->qs, kqsx); - const int qh = get_int_b2(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_0)); + const int qh = get_int_b2(bxi->qh, 0) >> (4 * kqsx); int qs0 = (ql >> 0) & 0x0F0F0F0F; qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 @@ -465,21 +527,22 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; #else - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_0) + kqsx + QI5_0] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_0) { - int i = i0 + threadIdx.y * QI5_0 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -487,32 +550,37 @@ template static __device__ __forceinlin const block_q5_0 * bxi = (const block_q5_0 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else - x_df[i*(WARP_SIZE/QI5_0) + i/QI5_0 + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI5_0) + i/QI5_0 + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_q5_1( +template static __device__ __forceinline__ void load_tiles_q5_1( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI5_1; - const int kqsx = threadIdx.x % QI5_1; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_1); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI5_1; + const int kqsx = txi % QI5_1; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -521,7 +589,7 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbx; const int ql = get_int_b4(bxi->qs, kqsx); - const int qh = get_int_b4(bxi->qh, 0) >> (4 * (threadIdx.x % QI5_1)); + const int qh = get_int_b4(bxi->qh, 0) >> (4 * kqsx); int qs0 = (ql >> 0) & 0x0F0F0F0F; qs0 |= (qh << 4) & 0x00000010; // 0 -> 4 @@ -535,21 +603,22 @@ template static __device__ __forceinlin qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + 0] = qs0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; #else - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; - x_qs[i*(2*WARP_SIZE + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + 0] = qs0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kbx*(2*QI5_1) + kqsx + QI5_1] = qs1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI5_1; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI5_1) { - int i = i0 + threadIdx.y * QI5_1 + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -557,32 +626,38 @@ template static __device__ __forceinlin const block_q5_1 * bxi = (const block_q5_1 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + kbxd] = bxi->dm; #else - x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; -#endif // NEW_MMA_AVAILABLE + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + kbxd] = bxi->dm; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_q8_0( +template static __device__ __forceinline__ void load_tiles_q8_0( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_tile + 2*WARP_SIZE); + float * x_df = (float *) (x_tile + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI8_0; - const int kqsx = threadIdx.x % QI8_0; + // MMQ_ITER_K / (4 * QR8_0) 
== 64 required. but NV has only 32 threads per warp + constexpr int threads_per_row = 32; + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI8_0; + const int kqsx = txi % QI8_0; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -590,21 +665,22 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbx; -#ifdef NEW_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); #else - x_qs[i*(2*WARP_SIZE + 1) + 0 + threadIdx.x] = get_int_b2(bxi[0].qs, kqsx); - x_qs[i*(2*WARP_SIZE + 1) + WARP_SIZE + threadIdx.x] = get_int_b2(bxi[WARP_SIZE/QI8_0].qs, kqsx); -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 0 + txi] = get_int_b2(bxi[0].qs, kqsx); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + MMQ_TILE_NE_K + txi] = get_int_b2(bxi[MMQ_TILE_NE_K/QI8_0].qs, kqsx); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = 2*WARP_SIZE / QI8_0; + constexpr int blocks_per_tile_x_row = 2*MMQ_TILE_NE_K / QI8_0; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI8_0/2) { - int i = i0 + threadIdx.y * (QI8_0/2) + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -612,17 +688,19 @@ template static __device__ __forceinlin const block_q8_0 * bxi = (const block_q8_0 *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = bxi->d; #else - x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + kbxd] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q8_0, mmq_y); const int * x_qs = (const int *) x; @@ -631,7 +709,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -639,21 +717,76 @@ static __device__ __forceinline__ void 
vec_dot_q8_0_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % WARP_SIZE], - x_df[i*(2*WARP_SIZE/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (WARP_SIZE/QI8_1)]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k0 % MMQ_TILE_NE_K], + x_df[i*(2*MMQ_TILE_NE_K/QI8_0) + i/(QI8_0/2) + k0/QI8_0], y_df[j*MMQ_TILE_Y_K + (k0/QI8_1) % (MMQ_TILE_NE_K/QI8_1)]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + float dB; + const int j = j0 + tile_C::get_j(0); + if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) { + dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1]; + } else { + dB = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + const float dA = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA*dB; + } + } + } + } +#else typedef tile<16, 8, int> tile_A; typedef tile< 8, 8, int> tile_B; typedef tile<16, 8, int> tile_C; @@ -662,23 +795,23 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
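Each accumulator element produced by the MFMA path above is, conceptually, an int8 dot product over one 32-value block scaled by the per-block factors dA and dB; a minimal scalar sketch of that contraction (the function name is illustrative, not a ggml symbol):

#include <cstdint>

// Scalar model of one accumulated element in vec_dot_q8_0_q8_1_mma:
// integer dot product over a 32-wide block, then scaled by the x and y block scales.
static float q8_q8_block_dot_ref(const int8_t * a, const int8_t * b, float dA, float dB) {
    int sumi = 0;
    for (int i = 0; i < 32; ++i) {
        sumi += int(a[i]) * int(b[i]);
    }
    return dA * dB * float(sumi);
}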
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + 2*WARP_SIZE; + const float * x_df = (const float *) x_qs + 2*MMQ_TILE_NE_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; const half2 * y_ds = (const half2 *) y; - tile_A A[ntx][WARP_SIZE/QI8_0]; - float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0]; + tile_A A[ntx][MMQ_TILE_NE_K/QI8_0]; + float dA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_0]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { const int k0 = k00 + k01; load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0); @@ -689,7 +822,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { const int k0 = k00 + k01; dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0]; @@ -700,7 +833,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { tile_B B; float dB[tile_C::ne/2]; @@ -729,11 +862,14 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( } } } +#endif // defined(AMD_MFMA_AVAILABLE) } -template +template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_1, mmq_y); const int * x_qs = (const int *) x; @@ -742,7 +878,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += VDR_Q8_0_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += VDR_Q8_0_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -750,45 +886,95 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_1_q8_1_impl - (&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - x_dm[i*(WARP_SIZE/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_1_q8_1_impl + (&x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + x_dm[i*(MMQ_TILE_NE_K/QI5_1) + i/QI5_1 + k0/QI8_1], y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; - typedef tile<16, 8, int> tile_A; - typedef tile< 8, 8, int> tile_B; - 
typedef tile<16, 8, int> tile_C; + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; + const int * y_qs = (const int *) y + 4; + const half2 * y_dm = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B; + load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float2 dsB = __half22float2(y_dm[j*MMQ_TILE_Y_K + k01/QI8_1]); + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_A::I + tile_C::get_i(l); + float2 dmA = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.x*dsB.x*C.x[l]; + sum[(j0/tile_C::J + n)*tile_C::ne + l] += dmA.y*dsB.y; + } + } + } + } +#else + typedef tile<16, 8, int> tile_A; + typedef tile< 8, 8, int> tile_B; + typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (tile_B::J*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + 2*WARP_SIZE; + const half2 * x_dm = (const half2 *) x_qs + 2*MMQ_TILE_NE_K; const int * y_qs = (const int *) y + 4; const half2 * y_dm = (const half2 *) y; - tile_A A[ntx][WARP_SIZE/QI8_1]; - float2 dmA[ntx][tile_C::ne/2][WARP_SIZE/QI8_1]; + tile_A A[ntx][MMQ_TILE_NE_K/QI8_1]; + float2 dmA[ntx][tile_C::ne/2][MMQ_TILE_NE_K/QI8_1]; const int i0 = (threadIdx.y/ntx)*rows_per_warp; #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { const int k0 = k00 + k01; load_ldmatrix(A[n][k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_1 + k0, MMQ_MMA_TILE_X_K_Q8_1); @@ -799,7 +985,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( const int i = i0 + n*tile_A::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { const int k0 = k00 + k01; dmA[n][l][k01/QI8_1] = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + k0/QI8_1]); @@ -810,7 +996,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { tile_B B; float2 dsB[tile_C::ne/2]; @@ -836,11 +1022,15 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma( } } } +#endif // defined(AMD_MFMA_AVAILABLE) } -template +// Used for Q3_K, IQ2_S, and IQ2_XS +template static __device__ __forceinline__ void 
vec_dot_q8_0_16_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; const int * x_qs = (const int *) x; @@ -849,7 +1039,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_0) { const int k0 = k00 + k01; #pragma unroll @@ -857,23 +1047,73 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q8_0_16_q8_1_impl( - &x_qs[i*(2*WARP_SIZE + 1) + k0], + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q8_0_16_q8_1_impl( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], + &x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + k0/(QI8_0/2)], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template +// Used for Q3_K, IQ2_S, and IQ2_XS: +template static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4] * dB; + } + } + } + } +#elif defined(NEW_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile<16, 8, int> tile_A_8; @@ -884,10 +1124,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
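The "_16" variants differ from plain q8_0 in that the x tile carries one float scale per 16 quantized values rather than per 32, which is why x_df is indexed at k0/(QI8_0/2), i.e. every 4 ints; a simplified scalar sketch of that layout for a single 32-value y block (names illustrative, a single y scale is assumed):

#include <cstdint>

// x: 32 int8 values with one scale per 16 values; y: 32 int8 values with a single block scale.
static float q8_0_16_dot_ref(const int8_t * x, const float x_scales[2], const int8_t * y, float y_scale) {
    float acc = 0.0f;
    for (int s = 0; s < 2; ++s) {          // two 16-value sub-blocks, each with its own x scale
        int sumi = 0;
        for (int j = 0; j < 16; ++j) {
            sumi += int(x[16*s + j]) * int(y[16*s + j]);
        }
        acc += x_scales[s] * float(sumi);
    }
    return acc * y_scale;
}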
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE*2; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; @@ -899,7 +1139,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { const int k0 = k00 + k01; load_ldmatrix(((tile_A_8 *) A[n])[k01/8], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q3_K + k0, MMQ_MMA_TILE_X_K_Q3_K); @@ -910,7 +1150,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 4) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { const int k0 = k00 + k01; dA[n][l][k01/4] = x_df[i*MMQ_MMA_TILE_X_K_Q3_K + k0/4]; @@ -921,7 +1161,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { tile_B B[2]; float dB[tile_C::ne/2]; @@ -952,26 +1192,29 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE } -template static __device__ __forceinline__ void load_tiles_q2_K( +template static __device__ __forceinline__ void load_tiles_q2_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % QI2_K; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR2_K); + constexpr int nrows = ggml_cuda_get_physical_warp_size() / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI2_K) { - int i = i0 + threadIdx.y*(WARP_SIZE/QI2_K) + threadIdx.x/QI2_K; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -987,11 +1230,11 @@ template static __device__ __forceinlin const int x_qs_k = (x_ql_0 >> (2*l)) & 0x03030303; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q2_K + k] = x_qs_k; #else - x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int sc_m = bxi->scales[kqsx]; @@ -1002,17 +1245,19 @@ template static __device__ __forceinlin const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m 
& 0x0F), bxi_dmf.y*(sc_m >> 4)); #endif // FAST_FP16_AVAILABLE -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + kqsx] = x_dm_ik; #else - x_dm[i*(WARP_SIZE + 1) + kqsx] = x_dm_ik; -#endif // NEW_MMA_AVAILABLE + x_dm[i*(MMQ_TILE_NE_K + 1) + kqsx] = x_dm_ik; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q2_K, mmq_y); const int * x_qs = (const int *) x; @@ -1029,7 +1274,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1037,13 +1282,13 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; constexpr int ns = 2; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); } } @@ -1052,7 +1297,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( // Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop. // As a workaround 2 separate loops are used instead. #pragma unroll - for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { + for (int k01 = MMQ_TILE_NE_K/2; k01 < MMQ_TILE_NE_K; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1060,23 +1305,89 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; constexpr int ns = 1; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], - &x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q2_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], + &x_dm[i*(MMQ_TILE_NE_K + 1) + k0/4], k01 < MMQ_TILE_NE_K/2 ? 
y_df[j0/nwarps].x : y_df[j0/nwarps].y, &y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. + + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; + const int * y_qs = (const int *) y + 4; + const half2 * y_ds = (const half2 *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = (k01 < MMQ_TILE_NE_K/2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K]).x/2 : __half22float2(y_ds[j*MMQ_TILE_Y_K]).y/2; + const float sB = (k01 >= MMQ_TILE_NE_K * 3/4) ? 0 + : (((k01/4)%2) ? __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).y + : __half22float2(y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]).x); + + tile_C Cm; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tile_A A1; + A1.x[0] = 0x01010101; + A1.x[1] = 0x01010101; + mma(Cm, A1, B[0]); + } + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C Cd; + mma(Cd, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/4]); + float tmp = Cd.x[l]*dm.x; + if (k01 >= MMQ_TILE_NE_K * 3/4) { + tmp -= Cm.x[l]*dm.y; + } + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*dB; + sum[(j0/tile_C::J + n)*tile_C::ne + l] -= dm.y*sB; + } + } + } + } +#elif defined(NEW_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile<16, 8, int> tile_A_8; @@ -1087,10 +1398,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
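The all-ones tile A1 (bytes 0x01010101) above is how the kernel recovers sum(y) for the Q2_K minimum term: each sub-block value dequantizes to dm.x*q - dm.y, so the dot product splits into dm.x*sum(q*y) - dm.y*sum(y), and multiplying B by a tile of ones yields the second sum where no precomputed block sum is available. A scalar sketch of that identity (function name illustrative, not a ggml symbol):

#include <cstdint>

// Per Q2_K sub-block: value = dm.x*q - dm.y, hence dot(x, y) = dm.x*sum(q*y) - dm.y*sum(y).
static float q2_K_subblock_dot_ref(const uint8_t * q, const int8_t * y, int n, float d_scaled, float m_scaled) {
    int sum_qy = 0;  // corresponds to Cd = mma(A, B)
    int sum_y  = 0;  // corresponds to Cm = mma(ones, B) or the stored q8_1 block sum
    for (int i = 0; i < n; ++i) {
        sum_qy += int(q[i]) * int(y[i]);
        sum_y  += int(y[i]);
    }
    return d_scaled*float(sum_qy) - m_scaled*float(sum_y);
}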
- y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const half2 * x_dm = (const half2 *) x_qs + WARP_SIZE*2; + const half2 * x_dm = (const half2 *) x_qs + MMQ_TILE_NE_K*2; const int * y_qs = (const int *) y + 4; const half2 * y_ds = (const half2 *) y; @@ -1103,7 +1414,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { const int k0 = k00 + k01; load_ldmatrix(((tile_A_8 *) A[n])[k01/QI8_1], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q2_K + k0, MMQ_MMA_TILE_X_K_Q2_K); @@ -1117,7 +1428,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int i = i0 + n*tile_C::I + tile_C::get_i(2*l); #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1/2) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1/2) { const int k0 = k00 + k01; const float2 dm = __half22float2(x_dm[i*MMQ_MMA_TILE_X_K_Q2_K + k0/(QI8_1/2)]); @@ -1140,7 +1451,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QI8_1) { tile_B B[2]; // Here load_generic is faster than load_ldmatrix. @@ -1148,7 +1459,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( load_generic(B[1], y_qs + j0*MMQ_TILE_Y_K + (k01 + tile_B::J), MMQ_TILE_Y_K); tile_C Cm[2]; - if (k01 >= WARP_SIZE * 3/4) { + if (k01 >= MMQ_TILE_NE_K * 3/4) { tile_A A1; A1.x[0] = 0x01010101; A1.x[1] = 0x01010101; @@ -1166,16 +1477,16 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #pragma unroll for (int l = 0; l < tile_C::ne; ++l) { float tmp = Cd[0].x[l]*dA[n][l/2][k01/4 + 0] + Cd[1].x[l]*dA[n][l/2][k01/4 + 1]; - if (k01 >= WARP_SIZE * 3/4) { + if (k01 >= MMQ_TILE_NE_K * 3/4) { tmp -= Cm[0].x[l]*mA[n][l/2][k01/4 + 0] + Cm[1].x[l]*mA[n][l/2][k01/4 + 1]; } - sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < WARP_SIZE/2 ? dB[l%2].x : dB[l%2].y); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += tmp*(k01 < MMQ_TILE_NE_K/2 ? 
dB[l%2].x : dB[l%2].y); } } } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE * 3/4; k01 += QI8_1) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K * 3/4; k01 += QI8_1) { float2 sB[tile_C::ne/2]; #pragma unroll @@ -1198,27 +1509,31 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE } -template static __device__ __forceinline__ void load_tiles_q3_K( +template static __device__ __forceinline__ void load_tiles_q3_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % QI3_K; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR3_K); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI3_K) { - int i = i0 + threadIdx.y * (WARP_SIZE/QI3_K) + threadIdx.x / QI3_K; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -1238,17 +1553,18 @@ template static __device__ __forceinlin const int x_qs_k = __vsubss4(x_ql_k | x_qh_k, 0x04040404); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + k] = x_qs_k; #else - x_qs[i*(2*WARP_SIZE + 1) + k] = x_qs_k; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k] = x_qs_k; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } + constexpr int rows_per_warp = warp_size / 4; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { - int i = i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/4; if (need_check) { i = min(i, i_max); @@ -1256,7 +1572,7 @@ template static __device__ __forceinlin const block_q3_K * bxi = (const block_q3_K *) x + kbx0 + i*stride; - const int ksc = threadIdx.x % (WARP_SIZE/8); + const int ksc = threadIdx.x % 4; const int ksc_low = ksc % (QI3_K/8); const int shift_low = 4 * (ksc / (QI3_K/8)); @@ -1268,23 +1584,23 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) const int8_t * sc8 = (const int8_t *) &sc; const float d = bxi->d; #pragma unroll for (int l = 0; l < int(sizeof(int)); ++l) { - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l]; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*ksc + l] = d*sc8[l]; } #else - x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = sc; -#endif // NEW_MMA_AVAILABLE + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = sc; +#endif // 
defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -#ifndef NEW_MMA_AVAILABLE +#if !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*WARP_SIZE) { - int i = (i0 + threadIdx.y*WARP_SIZE + threadIdx.x) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1294,12 +1610,14 @@ template static __device__ __forceinlin x_df[i] = bxi->d; } -#endif // NEW_MMA_AVAILABLE +#endif // !(defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE)) } -template +template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q3_K, mmq_y); const int * x_qs = (const int *) x; @@ -1309,7 +1627,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR3_K*VDR_Q3_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1317,13 +1635,13 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int8_t * scales = ((const int8_t *) (x_sc + i*(WARP_SIZE/8) + i/8)) + k0/4; + const int8_t * scales = ((const int8_t *) (x_sc + i*(MMQ_TILE_NE_K/8) + i/8)) + k0/4; - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q3_K_q8_1_impl_mmq( - &x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q3_K_q8_1_impl_mmq( + &x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], scales, x_df[i], y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); } } @@ -1340,72 +1658,85 @@ static __device__ __forceinline__ int unpack_scales_q45_K(const int * scales, co ((scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030); // upper 2 bits } -template static __device__ __forceinline__ void load_tiles_q4_K( +template static __device__ __forceinline__ void load_tiles_q4_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + 2*WARP_SIZE); + half2 * x_dm = (half2 *) (x_qs + 2*MMQ_TILE_NE_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; - const int qs0 = get_int_b4(bxi->qs, threadIdx.x); + const int qs0 = get_int_b4(bxi->qs, txi); -#ifdef NEW_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(threadIdx.x/8) + threadIdx.x % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 0] = (qs0 >> 0) & 0x0F0F0F0F; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 16*(txi/8) + txi % 8 + 8] = (qs0 >> 4) & 0x0F0F0F0F; #else - x_qs[i*(WARP_SIZE + 1) + threadIdx.x] = qs0; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(MMQ_TILE_NE_K + 1) + txi] = qs0; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -#ifdef NEW_MMA_AVAILABLE - +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { - int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride; - const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/16); + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; - const int sc32 = unpack_scales_q45_K(scales, ksc + 0); - const int m32 = unpack_scales_q45_K(scales, ksc + 2); + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); - const uint8_t * sc8 = (const uint8_t *) &sc32; - const uint8_t * m8 = (const uint8_t *) &m32; + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; - const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); -#pragma unroll - for (int l = 0; l < int(sizeof(int)); ++l) { - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + #pragma unroll + for (int l = 0; l < sizeof(int); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } } } - #else - #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI4_K) { - int i = (i0 + threadIdx.y*QI4_K + threadIdx.x) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1415,30 +1746,32 @@ template static __device__ __forceinlin x_dm[i] = bxi->dm; } - + constexpr int rows_per_warp = warp_size / 4; #pragma 
unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / (QI4_K/8); + const block_q4_K * bxi = (const block_q4_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / (QI4_K/8); const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/8); + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); const int scales8 = unpack_scales_q45_K(scales, ksc); - x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -template +template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q4_K, mmq_y); const int * x_qs = (const int *) x; @@ -1448,7 +1781,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR4_K*VDR_Q4_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1456,97 +1789,110 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const uint8_t * sc = (const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/32] + 2*(k01/16); + const uint8_t * sc = (const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/32] + 2*(k01/16); - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q4_K_q8_1_impl_mmq( - &x_qs[i*(WARP_SIZE + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q4_K_q8_1_impl_mmq( + &x_qs[i*(MMQ_TILE_NE_K + 1) + k0/2], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template static __device__ __forceinline__ void load_tiles_q5_K( +template static __device__ __forceinline__ void load_tiles_q5_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_dm = (half2 *) (x_qs + WARP_SIZE*2); + half2 * x_dm = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); int * x_qs = (int *) x_tile; half2 * x_dm = (half2 *) (x_qs + txs.qs); int * x_sc = (int *) (x_dm + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR5_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? 
threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; - const int ky = QR5_K*threadIdx.x; + const int ky = QR5_K*txi; - const int ql = get_int_b4(bxi->qs, threadIdx.x); + const int ql = get_int_b4(bxi->qs, txi); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_b4(bxi->qh, threadIdx.x % (QI5_K/4)); - const int qh0 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 0)) << 4) & 0x10101010; - const int qh1 = ((qh >> (2 * (threadIdx.x / (QI5_K/4)) + 1)) << 4) & 0x10101010; + const int qh = get_int_b4(bxi->qh, txi % (QI5_K/4)); + const int qh0 = ((qh >> (2 * (txi / (QI5_K/4)) + 0)) << 4) & 0x10101010; + const int qh1 = ((qh >> (2 * (txi / (QI5_K/4)) + 1)) << 4) & 0x10101010; - const int kq0 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + 0; - const int kq1 = ky - ky % (QI5_K/2) + threadIdx.x % (QI5_K/4) + QI5_K/4; + const int kq0 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + 0; + const int kq1 = ky - ky % (QI5_K/2) + txi % (QI5_K/4) + QI5_K/4; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq0] = ql0 | qh0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + kq1] = ql1 | qh1; #else - x_qs[i*(2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_qs[i*(2*WARP_SIZE + 1) + kq1] = ql1 | qh1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = ql0 | qh0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = ql1 | qh1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -#ifdef NEW_MMA_AVAILABLE - +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + constexpr int rows_per_warp = warp_size / 2; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*16) { - int i = (i0 + threadIdx.y*16 + threadIdx.x/(WARP_SIZE/16)) % mmq_y; - - if (need_check) { - i = min(i, i_max); - } + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { +#if defined(AMD_MFMA_AVAILABLE) + // Need if on AMD instead of % because warp_size == 64 + // This causes double work and throughput loss (MI300X) + // H100 loses about 100 t/s with 'if' condition over '%' + int i = i0 + threadIdx.y*rows_per_warp + threadIdx.x/2; + if (i < mmq_y) { +#else + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y; + { +#endif // defined(AMD_MFMA_AVAILABLE) + if (need_check) { + i = min(i, i_max); + } - const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; + const block_q5_K * bxi = (const block_q5_K *) x + kbx0 + i*stride; - const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/16); + const int * scales = (const int *) bxi->scales; + const int ksc = threadIdx.x % 2; - const int sc32 = unpack_scales_q45_K(scales, ksc + 0); - const int m32 = unpack_scales_q45_K(scales, ksc + 2); + const int sc32 = unpack_scales_q45_K(scales, ksc + 0); + const int m32 = unpack_scales_q45_K(scales, ksc + 2); - const uint8_t * sc8 = (const uint8_t *) &sc32; - const uint8_t * m8 = (const uint8_t *) &m32; + const uint8_t * sc8 = (const uint8_t *) &sc32; + const uint8_t * m8 = (const uint8_t *) &m32; - const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); + const half2 dm = bxi->dm * make_half2(1.0f, -1.0f); #pragma unroll - for (int l = 0; l < int(sizeof(int)); ++l) 
{ - x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + for (int l = 0; l < int(sizeof(int)); ++l) { + x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]); + } } } - #else - #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*QI5_K) { - int i = (i0 + threadIdx.y*QI5_K + threadIdx.x) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1557,9 +1903,10 @@ template static __device__ __forceinlin x_dm[i] = bxi->dm; } + constexpr int rows_per_warp = warp_size / 4; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps*8) { - int i = (i0 + threadIdx.y*8 + threadIdx.x/(WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; if (need_check) { i = min(i, i_max); @@ -1569,17 +1916,19 @@ template static __device__ __forceinlin const int * scales = (const int *) bxi->scales; - const int ksc = threadIdx.x % (WARP_SIZE/8); + const int ksc = threadIdx.x % (MMQ_TILE_NE_K/8); const int scales8 = unpack_scales_q45_K(scales, ksc); - x_sc[i*(WARP_SIZE/8) + i/8 + ksc] = scales8; + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + ksc] = scales8; } -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } -template +template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q5_K, mmq_y); const int * x_qs = (const int *) x; @@ -1589,7 +1938,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const half2 * y_ds = (const half2 *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR5_K*VDR_Q5_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1597,36 +1946,42 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k00/32]) + 2*(k01/16); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k00/32]) + 2*(k01/16); - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q5_K_q8_1_impl_mmq( - &x_qs[i*(QR5_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q5_K_q8_1_impl_mmq( + &x_qs[i*(QR5_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, sc+8, x_dm[i], &y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template static __device__ __forceinline__ void load_tiles_q6_K( +template static __device__ __forceinline__ void load_tiles_q6_K( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); - int * x_sc = (int *) 
(x_df + WARP_SIZE/QI6_K); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); + int * x_sc = (int *) (x_df + MMQ_TILE_NE_K/QI6_K); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); int * x_sc = (int *) (x_df + txs.dm); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR6_K); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -1634,67 +1989,67 @@ template static __device__ __forceinlin const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; - const int ql = get_int_b2(bxi->ql, threadIdx.x); + const int ql = get_int_b2(bxi->ql, txi); const int ql0 = (ql >> 0) & 0x0F0F0F0F; const int ql1 = (ql >> 4) & 0x0F0F0F0F; - const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (threadIdx.x / (QI6_K/2)) + threadIdx.x % (QI6_K/4)); - const int qh0 = ((qh >> ((threadIdx.x & 0x08) >> 2)) << 4) & 0x30303030; - const int qh1 = (qh >> ((threadIdx.x & 0x08) >> 2)) & 0x30303030; + const int qh = get_int_b2(bxi->qh, (QI6_K/4) * (txi / (QI6_K/2)) + txi % (QI6_K/4)); + const int qh0 = ((qh >> ((txi & 0x08) >> 2)) << 4) & 0x30303030; + const int qh1 = (qh >> ((txi & 0x08) >> 2)) & 0x30303030; - const int kq0 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + 0; - const int kq1 = 2*threadIdx.x - threadIdx.x % (QI6_K/2) + QI6_K/2; + const int kq0 = 2*txi - txi % (QI6_K/2) + 0; + const int kq1 = 2*txi - txi % (QI6_K/2) + QI6_K/2; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq0] = __vsubss4(ql0 | qh0, 0x20202020); x_qs[i*MMQ_MMA_TILE_X_K_Q6_K + kq1] = __vsubss4(ql1 | qh1, 0x20202020); #else - x_qs[i*(2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_qs[i*(2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_qs[i*(2*MMQ_TILE_NE_K + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 - const int kbxd = threadIdx.x % blocks_per_tile_x_row; // == 0 if QK_K == 256 - #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI6_K) { - int i = (i0 + threadIdx.y * QI6_K + threadIdx.x / blocks_per_tile_x_row) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*warp_size) { + int i = (i0 + threadIdx.y*warp_size + threadIdx.x) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + kbxd; + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q6_K + kbxd] = bxi->d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q6_K] = bxi->d; #else - x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K + kbxd] = bxi->d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K] = bxi->d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } + constexpr int 
rows_per_warp = warp_size / 4; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + threadIdx.y * 8 + threadIdx.x / (WARP_SIZE/8)) % mmq_y; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps*rows_per_warp) { + int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/(MMQ_TILE_NE_K/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (WARP_SIZE/8)) / 4; + const block_q6_K * bxi = (const block_q6_K *) x + kbx0 + i*stride + (threadIdx.x % (MMQ_TILE_NE_K/8)) / 4; -#ifdef NEW_MMA_AVAILABLE - x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_sc[i*MMQ_MMA_TILE_X_K_Q6_K + threadIdx.x%4] = get_int_b2(bxi->scales, threadIdx.x % (MMQ_TILE_NE_K/8)); #else - x_sc[i*(WARP_SIZE/8) + i/8 + threadIdx.x % (WARP_SIZE/8)] = get_int_b2(bxi->scales, threadIdx.x % (QI6_K/8)); -#endif // NEW_MMA_AVAILABLE + x_sc[i*(MMQ_TILE_NE_K/8) + i/8 + threadIdx.x%(MMQ_TILE_NE_K/8)] = get_int_b2(bxi->scales, threadIdx.x%(QI6_K/8)); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_Q6_K, mmq_y); const int * x_qs = (const int *) x; @@ -1704,7 +2059,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const float * y_df = (const float *) y; // #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += QR6_K*VDR_Q6_K_Q8_1_MMQ) { const int k0 = k00 + k01; #pragma unroll @@ -1712,23 +2067,74 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_dp4a( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k0/16]); + const int8_t * sc = ((const int8_t *) &x_sc[i * (MMQ_TILE_NE_K/8) + i/8 + k0/16]); - sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q6_K_q8_1_impl_mmq( - &x_qs[i*(QR6_K*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, - x_df[i*(WARP_SIZE/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); + sum[j0/nwarps*mmq_y/warp_size + i0/warp_size] += vec_dot_q6_K_q8_1_impl_mmq( + &x_qs[i*(QR6_K*MMQ_TILE_NE_K + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01], sc, + x_df[i*(MMQ_TILE_NE_K/QI6_K) + i/QI6_K], &y_df[j*MMQ_TILE_Y_K + k01/QI8_1]); } } } } -template +template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) { -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) + typedef tile<16, 8, int> tile_A; + typedef tile<16, 8, int> tile_B; + typedef tile<16, 16, int> tile_C; + typedef tile<64, 2, int> tile_load; + + constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int rows_per_warp = granularity; + constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
+ + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); + + const int * x_qs = (const int *) x; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; + const int * y_qs = (const int *) y + 4; + const float * y_df = (const float *) y; + + const int i0 = (threadIdx.y / ntx) * rows_per_warp; + + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 4) { + const int k0 = k00 + k01; + + tile_A A[ntx]; +#pragma unroll + for (int n = 0; n < ntx; ++n) { + load_generic(((tile_load *) A)[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + k0, MMQ_MMA_TILE_X_K_Q6_K); + } + +#pragma unroll + for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { + tile_B B[1]; + load_generic(((tile_load *) B)[0], y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); + + const int j = j0 + tile_C::get_j(0); + const float dB = y_df[j*MMQ_TILE_Y_K + k01/QI8_1] / 2; + +#pragma unroll + for (int n = 0; n < ntx; ++n) { + tile_C C; + mma(C, A[n], B[0]); + +#pragma unroll + for (int l = 0; l < tile_C::ne; ++l) { + const int i = i0 + n*tile_C::I + tile_C::get_i(l); + const int8_t * sc = (const int8_t *) (x_sc + i*MMQ_MMA_TILE_X_K_Q6_K + k00/16); + sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l] * sc[k01/4] * x_df[i*MMQ_MMA_TILE_X_K_Q6_K] * dB; + } + } + } + } +#elif defined(NEW_MMA_AVAILABLE) typedef tile<16, 4, int> tile_A; typedef tile< 8, 4, int> tile_B; @@ -1738,11 +2144,11 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( constexpr int rows_per_warp = 2 * granularity; constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. - y += (threadIdx.y % ntx) * (tile_B::I*MMQ_TILE_Y_K); + y += (threadIdx.y % ntx) * (tile_C::J*MMQ_TILE_Y_K); const int * x_qs = (const int *) x; - const float * x_df = (const float *) x_qs + WARP_SIZE*2; - const int * x_sc = (const int *) x_df + WARP_SIZE/QI6_K; + const float * x_df = (const float *) x_qs + MMQ_TILE_NE_K*2; + const int * x_sc = (const int *) x_df + MMQ_TILE_NE_K/QI6_K; const int * y_qs = (const int *) y + 4; const float * y_df = (const float *) y; @@ -1755,7 +2161,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #pragma unroll for (int n = 0; n < ntx; ++n) { #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { const int k0 = k00 + k01; load_ldmatrix(A[n][k01/4 + 0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q6_K + (k0 + 0), MMQ_MMA_TILE_X_K_Q6_K); @@ -1763,7 +2169,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( } #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 16) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 16) { const int k0 = k00 + k01; #pragma unroll @@ -1793,7 +2199,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( float tmp[ntx][tile_C::ne] = {{0.0f}}; #pragma unroll - for (int k01 = 0; k01 < WARP_SIZE; k01 += 8) { + for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 8) { tile_B B[2]; float dB[tile_C::ne/2]; @@ -1832,27 +2238,32 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( #else GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00); NO_DEVICE_CODE; -#endif // NEW_MMA_AVAILABLE +#endif // AMD_MFMA_AVAILABLE } -template static __device__ __forceinline__ void load_tiles_iq4_nl( +template static __device__ __forceinline__ void load_tiles_iq4_nl( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = 
ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_NL, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = threadIdx.x / QI4_NL; - const int kqsx = threadIdx.x % QI4_NL; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_NL); + constexpr int nrows = warp_size / threads_per_row; + const int txi = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; + const int kbx = txi / QI4_NL; + const int kqsx = txi % QI4_NL; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); @@ -1862,22 +2273,24 @@ template static __device__ __forceinlin const int aux_q4 = get_int_b2(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); - const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef NEW_MMA_AVAILABLE - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; - x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; + const int k0 = kbx * (2 * QI4_NL) + kqsx; + +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; + x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + QI4_NL] = v.y; #else - x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; - x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + QI4_NL] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL; + constexpr int blocks_per_tile_x_row = MMQ_TILE_NE_K / QI4_NL; + constexpr int rows_per_warp = warp_size / blocks_per_tile_x_row; const int kbxd = threadIdx.x % blocks_per_tile_x_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) { - int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / blocks_per_tile_x_row; if (need_check) { i = min(i, i_max); @@ -1885,31 +2298,35 @@ template static __device__ __forceinlin const block_iq4_nl * bxi = (const block_iq4_nl *) x + kbx0 + i*stride + kbxd; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = __half2float(bxi->d); #else - x_df[i*(WARP_SIZE/4) + i/4 + kbxd] = __half2float(bxi->d); -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/QI4_NL) + i/QI4_NL + kbxd] = __half2float(bxi->d); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq2_xxs( +template static __device__ __forceinline__ void load_tiles_iq2_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if 
defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI2_XXS/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = warp_size > threads_per_row ? threadIdx.x % threads_per_row : threadIdx.x; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XXS/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XXS) + threadIdx.x/(QI2_XXS/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -1932,42 +2349,46 @@ template static __device__ __forceinlin const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid1; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid0; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; #else - x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/4; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq2_xs( +template static __device__ __forceinline__ void load_tiles_iq2_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = MMQ_DP4A_TXS_Q8_0_16; int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI2_XS/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_XS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_XS/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_XS) + threadIdx.x/(QI2_XS/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + 
threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -1986,44 +2407,48 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // NEW_MMA_AVAILABLE + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq2_s( +template static __device__ __forceinline__ void load_tiles_iq2_s( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ2_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI2_S/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR2_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI2_S/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI2_S) + threadIdx.x/(QI2_S/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2049,44 +2474,48 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos[0] ^ signs0, signs0); const int grid_h = __vsub4(grid_pos[1] ^ signs1, signs1); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = 
grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = bxi->scales[kqsx]; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q3_K + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; #else - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; - x_df[i*(2*WARP_SIZE*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; -#endif // NEW_MMA_AVAILABLE + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+0] = ((ls & 0x0F)*d + d/2)/4; + x_df[i*(2*MMQ_TILE_NE_K*2/QI8_0) + i/(QI8_0/4) + 2*kqsx+1] = ((ls >> 4)*d + d/2)/4; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq3_xxs( +template static __device__ __forceinline__ void load_tiles_iq3_xxs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_XXS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI3_XXS/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_XXS)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_XXS/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_XXS) + threadIdx.x/(QI3_XXS/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2107,42 +2536,46 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l + 1)] = grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l + 1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = aux32 >> 28; const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/2; #else - x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = (ls*d + d/2)/2; -#endif // 
NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/2; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq3_s( +template static __device__ __forceinline__ void load_tiles_iq3_s( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % (QI3_S/2); + constexpr int threads_per_row = (MMQ_ITER_K / (4 * QR3_S)) / 2; + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/(QI3_S/2)) { - int i = i0 + threadIdx.y*(2*WARP_SIZE/QI3_S) + threadIdx.x/(QI3_S/2); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2170,42 +2603,46 @@ template static __device__ __forceinlin const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+0)] = grid_l; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l+1)] = grid_h; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid_l; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid_h; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid_l; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid_h; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const int ls = 1 + 2*((bxi->scales[kqsx/2] >> (((2*kqsx) << 1) & 0x04)) & 0x0F); const float d = bxi->d; -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = ls*d; #else - x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = ls*d; -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = ls*d; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq1_s( +template static __device__ __forceinline__ void load_tiles_iq1_s( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - half2 * x_ds = (half2 *) (x_qs + WARP_SIZE*2); + half2 * x_ds = (half2 *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ3_S, mmq_y); int * x_qs = (int *) x_tile; half2 * x_ds = (half2 *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || 
defined(NEW_MMA_AVAILABLE) - const int kqsx = threadIdx.x % QI1_S; + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR1_S); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * WARP_SIZE/QI1_S) { - int i = i0 + threadIdx.y*(WARP_SIZE/QI1_S) + threadIdx.x/QI1_S; + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * nrows) { + int i = i0 + threadIdx.y*nrows + threadIdx.x/threads_per_row; if (need_check) { i = min(i, i_max); @@ -2225,66 +2662,71 @@ template static __device__ __forceinlin const int grid0 = (grid >> 0) & 0x0F0F0F0F; const int grid1 = (grid >> 4) & 0x0F0F0F0F; -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+0)] = grid0; x_qs[i*MMQ_MMA_TILE_X_K_Q8_1 + 8*kqsx + (2*l+1)] = grid1; #else - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+0)] = grid0; - x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + (2*l+1)] = grid1; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+0)] = grid0; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + 8*kqsx + (2*l+1)] = grid1; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } const float d1q = __half2float(bxi->d) * (((qh >> 11) & 0x0E) + 1); const float delta = -1.0f + IQ1S_DELTA - (qh & 0x8000) * (2.0f*IQ1S_DELTA/0x8000); -#ifdef NEW_MMA_AVAILABLE - x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_ds[i*MMQ_MMA_TILE_X_K_Q8_1 + kqsx] = make_half2(d1q, d1q*delta); #else - x_ds[i*(WARP_SIZE/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); -#endif // NEW_MMA_AVAILABLE + x_ds[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = make_half2(d1q, d1q*delta); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template static __device__ __forceinline__ void load_tiles_iq4_xs( +template static __device__ __forceinline__ void load_tiles_iq4_xs( const char * __restrict__ x, int * __restrict__ x_tile, const int kbx0, const int i_max, const int stride) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); -#ifdef NEW_MMA_AVAILABLE +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) int * x_qs = (int *) x_tile; - float * x_df = (float *) (x_qs + WARP_SIZE*2); + float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2); #else constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y); int * x_qs = (int *) x_tile; float * x_df = (float *) (x_qs + txs.qs); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) - const int kbx = 0; // threadIdx.x / QI4_XS - const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS + constexpr int threads_per_row = MMQ_ITER_K / (4 * QR4_XS); + constexpr int nrows = warp_size / threads_per_row; + const int kqsx = threadIdx.x % threads_per_row; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps) { - int i = i0 + threadIdx.y; + for (int i0 = 0; i0 < mmq_y; i0 += nrows*nwarps) { + int i = i0 + (nrows == 1 ? 
threadIdx.y : threadIdx.y*nrows + threadIdx.x/threads_per_row); if (need_check) { i = min(i, i_max); } - const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride + kbx; + const block_iq4_xs * bxi = (const block_iq4_xs *) x + kbx0 + i*stride; const int aux_q4 = get_int_b4(bxi->qs, kqsx); const int2 v = get_int_from_table_16(aux_q4); - const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4; -#ifdef NEW_MMA_AVAILABLE + const int k0 = 8 * (kqsx / 4) + kqsx % 4; + +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x; x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y; #else - x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x; - x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y; -#endif // NEW_MMA_AVAILABLE + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 0] = v.x; + x_qs[i*(2*MMQ_TILE_NE_K + 1) + k0 + 4] = v.y; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } + constexpr int rows_per_warp = warp_size / 8; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4); + for (int i0 = 0; i0 < mmq_y; i0 += nwarps * rows_per_warp) { + int i = i0 + threadIdx.y * rows_per_warp + threadIdx.x / (MMQ_TILE_NE_K/4); if (need_check) { i = min(i, i_max); @@ -2297,18 +2739,21 @@ template static __device__ __forceinlin const int ls = ((bxi->scales_l[(threadIdx.x % 8)/2] >> (4*(threadIdx.x % 2))) & 0x0F) | (((bxi->scales_h >> (2*(threadIdx.x % 8))) & 0x03) << 4); -#ifdef NEW_MMA_AVAILABLE - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = d * (ls - 32); #else - x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); -#endif // NEW_MMA_AVAILABLE + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + threadIdx.x % 8] = d * (ls - 32); +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) } } -template +template static __device__ __forceinline__ void mmq_write_back_dp4a( const float * __restrict__ sum, const int32_t * __restrict__ ids_dst, float * __restrict__ dst, const int stride, const int i_max, const int j_max) { + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -2318,32 +2763,40 @@ static __device__ __forceinline__ void mmq_write_back_dp4a( } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; if (need_check && i > i_max) { continue; } - dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[ids_dst[j]*stride + i] = sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; } } } -template +template static __device__ __forceinline__ void mmq_write_back_mma( const float * __restrict__ sum, const int * __restrict__ ids_dst, float * __restrict__ dst, const int stride, const int i_max, const int j_max) { - typedef tile<16, 8, int> tile_C; constexpr int granularity = mmq_get_granularity_device(mmq_x); + constexpr int nwarps = mmq_get_nwarps_device(); + +#if defined(AMD_MFMA_AVAILABLE) + constexpr int tileC_IJ = mmq_get_granularity_device(0); + typedef tile tile_C; + constexpr int rows_per_warp = granularity; +#else + typedef tile<16, 8, int> tile_C; constexpr int rows_per_warp = 2 * granularity; +#endif constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp. 
const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I); -#ifdef NEW_MMA_AVAILABLE +#if defined(NEW_MMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) static_assert(nwarps*tile_C::I == mmq_y, "nwarps*tile_C::I != mmq_y"); -#endif // NEW_MMA_AVAILABLE +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) { @@ -2371,179 +2824,181 @@ static __device__ __forceinline__ void mmq_write_back_mma( // ------------------------------------------------------------------------------------------------------------------------------------- -template +template struct mmq_type_traits; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q4_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q4_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_1_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q5_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q5_1_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_1; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q8_0_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q8_0; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q2_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; - static constexpr vec_dot_mmq_t 
vec_dot_mma = vec_dot_q2_K_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q2_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q2_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q2_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q3_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q3_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q3_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q4_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q4_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q4_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q5_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q5_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q5_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_Q6_K_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_q6_K; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q6_K_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q6_K_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ2_XXS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ2_XS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = 
vec_dot_q8_0_16_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ2_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq2_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_16_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ3_XXS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_xxs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ3_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq3_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ1_S_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq1_s; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_1_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_1_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_NL_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_nl; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template -struct mmq_type_traits { +template +struct mmq_type_traits { static constexpr int vdr = VDR_IQ4_XS_Q8_1_MMQ; - static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; - static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; - static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; + static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs; + static constexpr vec_dot_mmq_t vec_dot_mma = vec_dot_q8_0_q8_1_mma; + static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a; }; -template +template static __device__ __forceinline__ void mul_mat_q_process_tile( const char * __restrict__ x, const int offset_x, const int * __restrict__ y, const int * __restrict__ ids_dst, float * __restrict__ dst, float * __restrict__ tmp_fixup, const 
int stride_row_x, const int ncols_y, const int stride_col_dst, const int tile_x_max_i, const int tile_y_max_j, const int kb0_start, const int kb0_stop) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int nwarps = mmq_get_nwarps_device(); constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); - constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; + constexpr load_tiles_mmq_t load_tiles = mmq_type_traits::load_tiles; extern __shared__ int data_mul_mat_q[]; int * tile_y = data_mul_mat_q + mmq_x; - int * tile_x = tile_y + GGML_PAD(mmq_x*(WARP_SIZE + WARP_SIZE/QI8_1), nwarps*WARP_SIZE); + int * tile_x = tile_y + GGML_PAD(mmq_x*MMQ_TILE_Y_K, nwarps*warp_size); -#ifdef NEW_MMA_AVAILABLE - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; - constexpr mmq_write_back_t write_back = mmq_write_back_mma; +#if defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_mma; + constexpr mmq_write_back_t write_back = mmq_write_back_mma; #else - constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; - constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; -#endif // NEW_MMA_AVAILABLE + constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a; + constexpr mmq_write_back_t write_back = mmq_write_back_dp4a; +#endif // defined(AMD_MFMA_AVAILABLE) || defined(NEW_MMA_AVAILABLE) constexpr int blocks_per_iter = MMQ_ITER_K / qk; - float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) { load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x); @@ -2551,8 +3006,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( { const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { - int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; tile_y[l] = by0[l]; } @@ -2567,8 +3022,8 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( { const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int)); #pragma unroll - for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*WARP_SIZE) { - int l = l0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) { + int l = l0 + threadIdx.y*warp_size + threadIdx.x; tile_y[l] = by0[l]; } @@ -2576,7 +3031,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( __syncthreads(); - vec_dot(tile_x, tile_y, sum, WARP_SIZE); + vec_dot(tile_x, tile_y, sum, MMQ_TILE_NE_K); __syncthreads(); } @@ -2591,16 +3046,16 @@ static __device__ __forceinline__ void mul_mat_q_process_tile( // The mul_mat_q kernel implements "stream-k" work partitioning as described in https://arxiv.org/abs/2301.03598 -template +template #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) #if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) - __launch_bounds__(WARP_SIZE*nwarps, 2) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) #endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN) #else #if 
__CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA - __launch_bounds__(WARP_SIZE*nwarps, 1) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 1) #else - __launch_bounds__(WARP_SIZE*nwarps, 2) + __launch_bounds__(ggml_cuda_get_physical_warp_size()*mmq_get_nwarps_device(), 2) #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) static __global__ void mul_mat_q( @@ -2616,6 +3071,9 @@ static __global__ void mul_mat_q( return; } + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + constexpr int qk = ggml_cuda_type_traits::qk; constexpr int mmq_y = get_mmq_y_device(); @@ -2627,10 +3085,10 @@ static __global__ void mul_mat_q( // For MoE the correct indices are loaded from ids_dst. extern __shared__ int ids_dst_shared[]; // Stored at beginning of shared memory. #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2639,7 +3097,7 @@ static __global__ void mul_mat_q( __syncthreads(); // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead: -#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA +#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA { const int wt = blockIdx.z / nchannels_y; const int zt = blockIdx.z - wt*nchannels_y; @@ -2667,10 +3125,10 @@ static __global__ void mul_mat_q( // __syncthreads(); // There is no previous tile that could cause a race condition. 
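The branch above is the conventional-tiling fallback: each thread block owns exactly one output tile and walks the full K range itself, so nothing has to be reconciled between blocks afterwards. A minimal sketch of that scheme, assuming illustrative tile sizes and plain row-major float matrices rather than the MMQ data layout:

// Sketch of conventional tiling: one thread block per output tile, covering all of K,
// so results are stored directly to C with no fixup pass.
#include <cuda_runtime.h>

template <int TILE_M, int TILE_N>
__global__ void tile_per_block_gemm(const float * A, const float * B, float * C,
                                    int M, int N, int K) {
    const int i = blockIdx.y*TILE_M + threadIdx.y; // row of C owned by this thread
    const int j = blockIdx.x*TILE_N + threadIdx.x; // column of C owned by this thread
    if (i >= M || j >= N) {
        return;
    }
    float sum = 0.0f;
    for (int k = 0; k < K; ++k) { // the full K range is handled by this single block
        sum += A[i*K + k] * B[k*N + j];
    }
    C[i*N + j] = sum; // direct store, no inter-block reconciliation needed
}
// launch example:
// tile_per_block_gemm<16, 16><<<dim3((N+15)/16, (M+15)/16), dim3(16, 16)>>>(A, B, C, M, N, K);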
#pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2688,12 +3146,12 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; - mul_mat_q_process_tile + mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, 0, ncols_x/qk); return; } -#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA +#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA const int64_t blocks_per_ne00 = ncols_x / qk; constexpr int blocks_per_iter = MMQ_ITER_K / qk; @@ -2745,10 +3203,10 @@ static __global__ void mul_mat_q( __syncthreads(); #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2766,7 +3224,7 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. - mul_mat_q_process_tile + mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); @@ -2812,10 +3270,10 @@ static __global__ void mul_mat_q( // The memory layout for the fixup buffer is always contiguous, therefore reset ids: __syncthreads(); #pragma unroll - for (int j0 = 0; j0 < mmq_x; j0 += nwarps*WARP_SIZE) { - const int j = j0 + threadIdx.y*WARP_SIZE + threadIdx.x; + for (int j0 = 0; j0 < mmq_x; j0 += nwarps*warp_size) { + const int j = j0 + threadIdx.y*warp_size + threadIdx.x; - if (j0 + nwarps*WARP_SIZE > mmq_x && j >= mmq_x) { + if (j0 + nwarps*warp_size > mmq_x && j >= mmq_x) { break; } @@ -2833,13 +3291,13 @@ static __global__ void mul_mat_q( const int offset_x = (wt/sample_ratio)*stride_sample_x + (zt/channel_ratio)*stride_channel_x + it*mmq_y*stride_row_x; constexpr bool fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. 
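By contrast, in the stream-k path a block may finish in the middle of a tile's K range; as noted above, such a block writes its partial sums to a scratch buffer and a separate fixup kernel folds them into dst. A simplified sketch of that two-kernel pattern (buffer names and layout are assumptions, not the patch's actual tmp_fixup format):

// Hedged sketch of the stream-k fixup idea: partial tiles go to a private scratch slot,
// and a second kernel accumulates them into dst so no two blocks ever race on dst.
__global__ void write_partial(const float * partial_tile, float * tmp_fixup,
                              int tile_id, int tile_elems) {
    for (int e = threadIdx.x; e < tile_elems; e += blockDim.x) {
        tmp_fixup[tile_id*tile_elems + e] = partial_tile[e]; // private slot per tile, no race
    }
}

__global__ void fixup(const float * tmp_fixup, float * dst,
                      const int * tile_offsets, int ntiles, int tile_elems) {
    const int t = blockIdx.x;
    if (t >= ntiles) {
        return;
    }
    for (int e = threadIdx.x; e < tile_elems; e += blockDim.x) {
        dst[tile_offsets[t] + e] += tmp_fixup[t*tile_elems + e]; // single writer per element
    }
}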
- mul_mat_q_process_tile + mul_mat_q_process_tile (x, offset_x, y + offset_y, ids_dst_shared, dst + offset_dst, tmp_fixup, stride_row_x, ncols_y, stride_col_dst, tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); } -template +template static __global__ void mul_mat_q_stream_k_fixup( const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, @@ -2849,7 +3307,10 @@ static __global__ void mul_mat_q_stream_k_fixup( constexpr int blocks_per_iter = MMQ_ITER_K / qk; const int64_t blocks_per_ne00 = ncols_x / qk; - float sum[mmq_x*mmq_y / (nwarps*WARP_SIZE)] = {0.0f}; + constexpr int nwarps = mmq_get_nwarps_device(); + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); + + float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f}; const int ntx = (ncols_dst + mmq_x - 1) / mmq_x; const int nty = (nrows_x + mmq_y - 1) / mmq_y; @@ -2893,10 +3354,10 @@ static __global__ void mul_mat_q_stream_k_fixup( const int j = j0 + threadIdx.y; #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; - sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; + sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size] += tmp_last_tile[bidx*(mmq_x*mmq_y) + j*mmq_y + i]; } } @@ -2937,14 +3398,14 @@ static __global__ void mul_mat_q_stream_k_fixup( } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; if (need_check && i > i_max) { continue; } - dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[j*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; } } return; @@ -2955,7 +3416,7 @@ static __global__ void mul_mat_q_stream_k_fixup( const int col_high = expert_bounds[zt + 1]; const int col_diff = col_high - col_low; - for (int j = threadIdx.y*WARP_SIZE + threadIdx.x; j < mmq_x; j += nwarps*WARP_SIZE) { + for (int j = threadIdx.y*warp_size + threadIdx.x; j < mmq_x; j += nwarps*warp_size) { ids_dst_shared[j] = ids_dst[col_low + j]; } __syncthreads(); @@ -2975,14 +3436,14 @@ static __global__ void mul_mat_q_stream_k_fixup( } #pragma unroll - for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < mmq_y; i0 += warp_size) { const int i = i0 + threadIdx.x; if (need_check && i > i_max) { continue; } - dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/WARP_SIZE) + i0/WARP_SIZE]; + dst[ids_dst_shared[j]*stride_col_dst + i] += sum[(j0/nwarps) * (mmq_y/warp_size) + i0/warp_size]; } } } @@ -2996,13 +3457,13 @@ struct mmq_args { }; template -static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc) { +static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int cc, const int warp_size, const int nwarps) { const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); const size_t nbs_ids = mmq_x*sizeof(int); - const size_t nbs_x = new_mma_available(cc) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); + const size_t nbs_x = (new_mma_available(cc) || amd_mfma_available(cc)) ? 
mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); - return nbs_ids + nbs_x + GGML_PAD(nbs_y, MMQ_NWARPS*WARP_SIZE*sizeof(int)); + return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); } template @@ -3010,14 +3471,16 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a const int id = ggml_cuda_get_device(); const int cc = ggml_cuda_info().devices[id].cc; const int nsm = ggml_cuda_info().devices[id].nsm; + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc); const int mmq_y = get_mmq_y_host(cc); - const dim3 block_dims(WARP_SIZE, MMQ_NWARPS, 1); + const dim3 block_dims(warp_size, nwarps, 1); - const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc); + const int nbytes_shared = mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps); - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); - CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); + CUDA_SET_SHARED_MEMORY_LIMIT((mul_mat_q), nbytes_shared); const int nty = (args.nrows_x + mmq_y - 1) / mmq_y; const int ntx = (args.ncols_dst + mmq_x - 1) / mmq_x; @@ -3032,14 +3495,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a if (!args.use_stream_k) { if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, sample_ratio, args.nsamples_y, args.stride_sample_x, args.stride_sample_y, args.stride_sample_dst); } else { constexpr bool need_check = true; - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, nullptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3059,8 +3522,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a if (args.nrows_x % mmq_y == 0) { constexpr bool need_check = false; - - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3070,13 +3532,12 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a return; } - mul_mat_q_stream_k_fixup<<>> + mul_mat_q_stream_k_fixup<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); } else { constexpr bool need_check = true; - - mul_mat_q<<>> + mul_mat_q<<>> (args.x, args.y, args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.stride_row_x, args.ncols_y, args.nrows_dst, channel_ratio, args.nchannels_y, args.stride_channel_x, args.stride_channel_y, args.stride_channel_dst, @@ -3086,7 +3547,7 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a 
return; } - mul_mat_q_stream_k_fixup<<>> + mul_mat_q_stream_k_fixup<<>> (args.ids_dst, args.expert_bounds, args.dst, tmp_fixup.ptr, args.ncols_x, args.nrows_x, args.ncols_dst, args.nrows_dst, args.nchannels_y, args.stride_channel_dst, args.nsamples_y, args.stride_sample_dst); } @@ -3094,9 +3555,11 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a template void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) { - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const size_t smpbo = ggml_cuda_info().devices[id].smpbo; + const int warp_size = ggml_cuda_info().devices[id].warp_size; + const int nwarps = mmq_get_nwarps_host(cc); const int mmq_x_max = get_mmq_x_max_host(cc); const int mmq_y = get_mmq_y_host(cc); @@ -3107,7 +3570,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda for (int mmq_x = 8; mmq_x <= mmq_x_max && ntiles_x_best > 1; mmq_x += 8) { const int granularity = mmq_get_granularity_host(mmq_x, cc); - if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc) > smpbo) { + if (mmq_x % granularity != 0 || mmq_get_nbytes_shared(mmq_x, mmq_y, cc, warp_size, nwarps) > smpbo) { continue; } diff --git a/ggml/src/ggml-cuda/norm.cu b/ggml/src/ggml-cuda/norm.cu index 0020dbcec5f..bddcca51b7b 100644 --- a/ggml/src/ggml-cuda/norm.cu +++ b/ggml/src/ggml-cuda/norm.cu @@ -104,10 +104,12 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr } } -template +template static __global__ void rms_norm_f32( const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel, - const int64_t stride_sample, const float eps) { + const int64_t stride_sample, const float eps, const float * mul = nullptr, const int64_t mul_stride_row = 0, + const int64_t mul_stride_channel = 0, const int64_t mul_stride_sample = 0, const int mul_ncols = 0, + const int mul_nrows = 0, const int mul_nchannels = 0, const int mul_nsamples = 0) { const int nrows = gridDim.x; const int nchannels = gridDim.y; @@ -119,6 +121,13 @@ static __global__ void rms_norm_f32( x += sample*stride_sample + channel*stride_channel + row*stride_row; dst += ((sample*nchannels + channel)*nrows + row)*ncols; + if constexpr (do_multiply) { + const int mul_row = row % mul_nrows; + const int mul_channel = channel % mul_nchannels; + const int mul_sample = sample % mul_nsamples; + mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row; + } + float tmp = 0.0f; // partial sum for thread in warp for (int col = tid; col < ncols; col += block_size) { @@ -145,7 +154,12 @@ static __global__ void rms_norm_f32( const float scale = rsqrtf(mean + eps); for (int col = tid; col < ncols; col += block_size) { - dst[col] = scale * x[col]; + if constexpr (do_multiply) { + const int mul_col = col % mul_ncols; + dst[col] = scale * x[col] * mul[mul_col]; + } else { + dst[col] = scale * x[col]; + } } } @@ -310,10 +324,30 @@ static void rms_norm_f32_cuda( const dim3 blocks_num(nrows, nchannels, nsamples); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } else { + const dim3 block_dims(1024, 1, 
1); + rms_norm_f32<1024, false><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + } +} + +static void rms_norm_mul_f32_cuda( + const float * x, const float * mul, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples, + const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, + const int64_t mul_stride_row, const int64_t mul_stride_channel, const int64_t mul_stride_sample, + const int mul_ncols, const int mul_nrows, const int mul_nchannels, const int mul_nsamples, + const float eps, cudaStream_t stream) { + const dim3 blocks_num(nrows, nchannels, nsamples); + if (mul == nullptr) { + rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream); + return; + } + if (ncols < 1024) { + const dim3 block_dims(WARP_SIZE, 1, 1); + rms_norm_f32<<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); } else { const dim3 block_dims(1024, 1, 1); - rms_norm_f32<1024><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); + rms_norm_f32<1024, true><<>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples); } } @@ -407,6 +441,59 @@ void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream); } +void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor) { + const ggml_tensor * rms_norm_src = (ggml_tensor *) dst->src[0]; + float eps = 0.0f; + + memcpy(&eps, dst->op_params, sizeof(float)); + + const float * src0_d = (const float *) rms_norm_src->data; + const float * mul_d = nullptr; + const ggml_tensor * mul_src = nullptr; + + if (mul_tensor->src[0] == dst) { + mul_d = (float *) mul_tensor->src[1]->data; + mul_src = mul_tensor->src[1]; + } else if(mul_tensor->src[1] == dst) { + mul_d = (float *) mul_tensor->src[0]->data; + mul_src = mul_tensor->src[0]; + } else { + GGML_ASSERT(false); + } + + float * dst_d = (float *) mul_tensor->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(rms_norm_src->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(mul_tensor->type == GGML_TYPE_F32); + GGML_ASSERT(eps >= 0.0f); + + const int64_t ne00 = rms_norm_src->ne[0]; + const int64_t ne01 = rms_norm_src->ne[1]; + const int64_t ne02 = rms_norm_src->ne[2]; + const int64_t ne03 = rms_norm_src->ne[3]; + + const size_t ts0 = ggml_type_size(rms_norm_src->type); + GGML_ASSERT(rms_norm_src->nb[0] == ts0); + const int64_t s01 = rms_norm_src->nb[1] / ts0; + const int64_t s02 = rms_norm_src->nb[2] / ts0; + const int64_t s03 = rms_norm_src->nb[3] / ts0; + + const size_t ts_mul = ggml_type_size(mul_src->type); + GGML_ASSERT(mul_src->nb[0] == ts_mul); + const int64_t mul_s01 = mul_src->nb[1] / ts_mul; + const int64_t mul_s02 = mul_src->nb[2] / ts_mul; + const int64_t mul_s03 = mul_src->nb[3] / ts_mul; + + const int mul_ncols = mul_src->ne[0]; + const int mul_nrows = mul_src->ne[1]; + const int mul_nchannels = mul_src->ne[2]; + const int mul_nsamples = mul_src->ne[3]; + + rms_norm_mul_f32_cuda(src0_d, mul_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, mul_s01, mul_s02, mul_s03, mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, eps, stream); +} + void 
ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * grad = dst->src[0]; // gradients const ggml_tensor * src0f = dst->src[1]; // src0 from forward pass diff --git a/ggml/src/ggml-cuda/norm.cuh b/ggml/src/ggml-cuda/norm.cuh index 706a5660a68..7ea7bd4df3c 100644 --- a/ggml/src/ggml-cuda/norm.cuh +++ b/ggml/src/ggml-cuda/norm.cuh @@ -6,6 +6,8 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_rms_norm_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * mul_tensor); + void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 560604d095f..07983436459 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -4,24 +4,8 @@ typedef void (*set_rows_kernel_t)(const char * src, char * dst); template -__device__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { - GGML_UNUSED(src_f); - GGML_UNUSED(dst_f); -} - -template<> -__device__ __forceinline__ void set_rows_1(const float * src_f, half * dst_h) { - convert_f32_f16(src_f, dst_h); -} - -template<> -__device__ __forceinline__ void set_rows_1(const float * src_f, nv_bfloat16 * dst_b) { - convert_f32_bf16(src_f, dst_b); -} - -template<> -__device__ __forceinline__ void set_rows_1(const float * src_f, float * dst_f) { - convert_f32_f32(src_f, dst_f); +__device__ __forceinline__ void set_rows_1(const src_t * src_f, dst_t * dst_f) { + convert_flt(src_f, dst_f); } // Generic quantized set_rows kernel template @@ -60,6 +44,9 @@ static __global__ void k_set_rows_quant( block_type * dst_block = dst_row_ptr + i00 / qk; quantize_func(src_block, dst_block); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); } // Template dispatch function for quantized set_rows diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 184d445f5c0..56e59a058f9 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -160,7 +160,19 @@ #endif #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) -#define CDNA +#define CDNA // For the entire family +#endif + +#if defined(__gfx942__) +#define CDNA3 +#endif + +#if defined(__gfx90a__) +#define CDNA2 +#endif + +#if defined(__gfx908__) +#define CDNA1 #endif #if defined(__GFX12__) diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 937779a90af..19896320244 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -13,7 +13,7 @@ #define CUBLAS_OP_N MUBLAS_OP_N #define CUBLAS_OP_T MUBLAS_OP_T #define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS -#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_MATH_MODE_DEFAULT +#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH #define CUDA_R_16F MUSA_R_16F #define CUDA_R_16BF MUSA_R_16BF #define CUDA_R_32F MUSA_R_32F @@ -29,7 +29,7 @@ #define cublasSgemm mublasSgemm #define cublasStatus_t mublasStatus_t #define cublasOperation_t mublasOperation_t -#define cublasGetStatusString mublasStatus_to_string +#define cublasGetStatusString mublasGetStatusString #define cudaDataType_t musaDataType_t #define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h 
b/ggml/src/ggml-metal/ggml-metal-impl.h index b7b3fc49af3..8424464d8ca 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -528,6 +528,7 @@ typedef struct { int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; + int64_t s_off; uint64_t nb01; uint64_t nb02; uint64_t nb03; diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index dc391a0d4d5..337f7985bad 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -1955,6 +1955,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex static int ggml_metal_encode_node( ggml_backend_t backend, int idx, + int idx_end, id encoder, struct ggml_metal_mem_pool * mem_pool) { struct ggml_backend_metal_context * ctx = backend->context; @@ -2181,7 +2182,9 @@ static int ggml_metal_encode_node( size_t offs_fuse; id id_fuse; - for (n_fuse = 0; n_fuse <= 6; ++n_fuse) { + // note: in metal, we sometimes encode the graph in parallel so we have to avoid fusing nodes + // across splits. idx_end indicates the last node in the current split + for (n_fuse = 0; n_fuse <= 6 && idx + n_fuse + 1 < idx_end; ++n_fuse) { if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { break; } @@ -3138,6 +3141,7 @@ static int ggml_metal_encode_node( /*.n_group =*/ n_group, /*.n_seq_tokens =*/ n_seq_tokens, /*.n_seqs =*/ n_seqs, + /*.s_off =*/ ggml_nelements(src1) * sizeof(float), /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, @@ -3166,12 +3170,22 @@ static int ggml_metal_encode_node( [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; [encoder setBytes:&args length:sizeof(args) atIndex:8]; + // One shared memory bucket for each simd group in the threadgroup + // NOTE: Metal kernels require the buffer size to be multiple of 16 bytes + // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength + if (d_state >= 32) { + GGML_ASSERT((int64_t)(d_state / 32) <= 32); + const int64_t shmem_size = 32; + GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup); + [encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0]; + } + if (ne30 == 1) { // Mamba-2 - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } else { GGML_ASSERT(d_inner == 1); - [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(d_state, 1, 1)]; } } break; case GGML_OP_RWKV_WKV6: @@ -4288,7 +4302,7 @@ static int ggml_metal_encode_node( ops[1] = GGML_OP_MUL; ops[2] = GGML_OP_ADD; - for (n_fuse = 0; n_fuse <= 1; ++n_fuse) { + for (n_fuse = 0; n_fuse <= 1 && idx + n_fuse + 1 < idx_end; ++n_fuse) { if (!ggml_can_fuse(gf, idx + n_fuse, ops + n_fuse, 2)) { break; } @@ -6271,7 +6285,11 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]]; } - const int res = ggml_metal_encode_node(backend, idx, encoder, mem_pool); + const int res = ggml_metal_encode_node(backend, idx, node_end, encoder, mem_pool); + if (idx + res > node_end) { + GGML_ABORT("fusion error: nodes spanning multiple encoders have been fused. 
this indicates a bug in the fusion logic %s", + "https://github.com/ggml-org/llama.cpp/pull/14849"); + } if (should_capture) { [encoder popDebugGroup]; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index f62b9ad548e..99a453090f6 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1823,10 +1823,16 @@ kernel void kernel_ssm_scan_f32( device const void * src5, device const void * src6, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + + const int64_t i0 = tpitg.x; const int64_t i1 = 0; const int64_t ir = tgpig.x; // current head const int64_t i3 = tgpig.y; // current seq @@ -1841,41 +1847,88 @@ kernel void kernel_ssm_scan_f32( const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + const int64_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + const int64_t i = i0 + i1*nc; + float s0 = s0_buff[i]; + float s = s_buff[i]; + + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); + device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); + device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); + device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh} - device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, 
nt, ns} + device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} + device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; - float sumf = 0.0f; - for (int64_t i0 = 0; i0 < nc; ++i0) { - const int64_t i = i0 + i1*nc; - const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); - sumf += state * C[i0]; - s[i] = state; - } + const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); + s = state; + + // Parallel sum: This relies on the fact that this kernel will be + // dispatched with each threadgroup having (d_state, 1, 1) threads which + // are subdivided into SIMD groups of size `sgptg`. The goal is to + // compute y = sum({state * C[i] for i in range(d_state)}). + // To parallelize this effectively, we first use simd_sum over each SIMD + // group to compute the sum of each SIMD group, then place the result in + // the SIMD group's indexed bucket in the shared memory. We then sum + // over the individual group sums to compute the final sum. + + // Computed for each thread + float sumf = state * C[i0]; - y[0] = sumf; + // Sum the threads in the simd group => simd sum + sumf = simd_sum(sumf); + + if (sgptg > 1) { + + // Once per simd group, place the group sum into the shared buffer + if (tiisg == 0) { + shared[sgitg] = sumf; + } + + // Wait for all threads in the threadgroup to reach this point. This + // ensures that all elements of the shared buffer are populated with the + // sum of the individual simd groups. + threadgroup_barrier(mem_flags::mem_threadgroup); + + // For simd group 0 at indices < num simd groups, extract the shared + // simd sum + sumf = 0.0f; + if (sgitg == 0) { + if (tiisg < sgptg) { + sumf = shared[tiisg]; + } + sumf = simd_sum(sumf); + if (tiisg == 0) { + y[0] = sumf; + } + } + } else if (tiisg == 0) { + y[0] = sumf; + } // recurse s0 = s; } + + // Assign the final state to the output buffer + s_buff[i] = s; } // ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -// TODO: optimize (e.g. 
by parallelizing over d_state) kernel void kernel_ssm_scan_f32_group( device const void * src0, device const void * src1, @@ -1885,10 +1938,16 @@ kernel void kernel_ssm_scan_f32_group( device const void * src5, device const void * src6, device float * dst, + threadgroup float * shared [[threadgroup(0)]], constant ggml_metal_kargs_ssm_scan & args, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + + const int64_t i0 = tpitg.x; const int64_t i1 = tgpig.x; const int64_t ir = tgpig.y; // current head const int64_t i3 = tgpig.z; // current seq @@ -1903,38 +1962,81 @@ kernel void kernel_ssm_scan_f32_group( const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + const int64_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + const int64_t i = i0 + i1*nc; + float s0 = s0_buff[i]; + float s = s_buff[i]; + + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} + device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); + device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); + device const float * B_block = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i3*args.nb43); + device const float * C_block = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i3*args.nb53); + device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} - device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) dt_block + 
i2*args.nb21); // {nh, nt, ns} + device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; const float x_dt = x[0] * dt_soft_plus; const float dA = exp(dt_soft_plus * A[0]); - float sumf = 0.0f; - for (int64_t i0 = 0; i0 < nc; ++i0) { - const int64_t i = i0 + i1*nc; - const float state = (s0[i] * dA) + (B[i0] * x_dt); - sumf += state * C[i0]; - s[i] = state; + const float state = (s0 * dA) + (B[i0] * x_dt); + s = state; + + // Parallel sum: This relies on the fact that this kernel will be + // dispatched with each threadgroup having (d_state, 1, 1) threads which + // are subdivided into SIMD groups of size `sgptg`. The goal is to + // compute y = sum({state * C[i] for i in range(d_state)}). + // To parallelize this effectively, we first use simd_sum over each SIMD + // group to compute the sum of each SIMD group, then place the result in + // the SIMD group's indexed bucket in the shared memory. We then sum + // over the individual group sums to compute the final sum. + + // Computed for each thread + float sumf = state * C[i0]; + + // Sum the threads in the simd group => simd sum + sumf = simd_sum(sumf); + + // Once per simd group, place the group sum into the shared buffer + if (tiisg == 0) { + shared[sgitg] = sumf; } - y[0] = sumf; + // Wait for all threads in the threadgroup to reach this point. This + // ensures that all elements of the shared buffer are populated with the + // sum of the individual simd groups. 
+ threadgroup_barrier(mem_flags::mem_threadgroup); + + // For simd group 0 at indices < num simd groups, extract the shared + // simd sum + sumf = 0.0f; + if (sgitg == 0) { + if (tiisg < sgptg) { + sumf = shared[tiisg]; + } + sumf = simd_sum(sumf); + if (tiisg == 0) { + y[0] = sumf; + } + } // recurse s0 = s; } + + // Assign the final state to the output buffer + s_buff[i] = s; } kernel void kernel_rwkv_wkv6_f32( diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index 971314debc7..02904526ade 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -34,8 +34,12 @@ if (MUSAToolkit_FOUND) list(APPEND GGML_SOURCES_MUSA ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") list(APPEND GGML_SOURCES_MUSA ${SRCS}) - file(GLOB SRCS "../ggml-musa/*.cu") - list(APPEND GGML_SOURCES_MUSA ${SRCS}) + + if (GGML_MUSA_MUDNN_COPY) + file(GLOB SRCS "../ggml-musa/*.cu") + list(APPEND GGML_SOURCES_MUSA ${SRCS}) + add_compile_definitions(GGML_MUSA_MUDNN_COPY) + endif() if (GGML_CUDA_FA_ALL_QUANTS) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-vec*.cu") @@ -72,6 +76,10 @@ if (MUSAToolkit_FOUND) add_compile_definitions(GGML_USE_MUSA) add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) + if (GGML_MUSA_GRAPHS) + add_compile_definitions(GGML_MUSA_GRAPHS) + endif() + if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() @@ -97,10 +105,16 @@ if (MUSAToolkit_FOUND) endif() if (GGML_STATIC) - # TODO: mudnn has not provided static libraries yet target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static) + # TODO: mudnn has not provided static libraries yet + # if (GGML_MUSA_MUDNN_COPY) + # target_link_libraries(ggml-musa PRIVATE mudnn_static) + # endif() else() - target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas mudnn) + target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas) + if (GGML_MUSA_MUDNN_COPY) + target_link_libraries(ggml-musa PRIVATE mudnn) + endif() endif() if (GGML_CUDA_NO_VMM) diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index ec5d8cf5955..015fa8f0682 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -105,6 +105,8 @@ set(GGML_OPENCL_KERNELS pad repeat mul_mat_f16_f32 + conv2d + conv2d_f16_f32 ) foreach (K ${GGML_OPENCL_KERNELS}) diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 3388259152b..c87a32383c8 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -333,6 +333,7 @@ struct ggml_backend_opencl_context { size_t max_alloc_size; bool fp16_support; bool has_vector_subgroup_broadcast; + bool disable_fusion; ggml_cl_compiler_version adreno_cl_compiler_version; int adreno_wave_size; @@ -390,6 +391,9 @@ struct ggml_backend_opencl_context { cl_program program_tanh; cl_program program_upscale; cl_program program_concat; + cl_program program_conv_2d_f16; + cl_program program_conv_2d_f32; + cl_program program_conv_2d_f16_f32; cl_program program_tsembd; cl_program program_mul_mv_id_q4_0_f32_8x_flat; @@ -408,7 +412,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick, kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16; cl_kernel kernel_norm; - cl_kernel kernel_rms_norm; + cl_kernel kernel_rms_norm, kernel_rms_norm_mul; cl_kernel 
kernel_group_norm; cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8; cl_kernel kernel_soft_max, kernel_soft_max_4; @@ -441,6 +445,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_upscale_bilinear; cl_kernel kernel_concat_f32_contiguous; cl_kernel kernel_concat_f32_non_contiguous; + cl_kernel kernel_conv_2d_f16; + cl_kernel kernel_conv_2d_f32; + cl_kernel kernel_conv_2d_f16_f32; cl_kernel kernel_timestep_embedding; cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat; @@ -1094,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve backend_ctx->program_rms_norm = build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err)); + CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err)); GGML_LOG_CONT("."); } @@ -1478,6 +1486,47 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // conv2d + { + #ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "conv2d.cl.h" + }; + const std::string kernel_src_f16_f32 { + #include "conv2d_f16_f32.cl.h" + }; + #else + const std::string kernel_src = read_file("conv2d.cl"); + const std::string kernel_src_f16_f32 = read_file("conv2d_f16_f32.cl"); + #endif + if (!kernel_src.empty()) { + backend_ctx->program_conv_2d_f16 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), (std::string(compile_opts) + " -DUSE_FP16=1").c_str()); + CL_CHECK((backend_ctx->kernel_conv_2d_f16 = clCreateKernel(backend_ctx->program_conv_2d_f16, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + backend_ctx->program_conv_2d_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_conv_2d_f32 = clCreateKernel(backend_ctx->program_conv_2d_f32, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: conv2d kernel source not found or empty. This op will not be available.\n"); + backend_ctx->program_conv_2d_f16 = nullptr; + backend_ctx->kernel_conv_2d_f16 = nullptr; + backend_ctx->program_conv_2d_f32 = nullptr; + backend_ctx->kernel_conv_2d_f32 = nullptr; + } + if (!kernel_src_f16_f32.empty()) { + backend_ctx->program_conv_2d_f16_f32 = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src_f16_f32.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_conv_2d_f16_f32 = clCreateKernel(backend_ctx->program_conv_2d_f16_f32, "kernel_conv_2d", &err), err)); + GGML_LOG_CONT("."); + } else { + GGML_LOG_WARN("ggml_opencl: conv2d_f16_f32 kernel source not found or empty. 
This op will not be available.\n"); + backend_ctx->program_conv_2d_f16_f32 = nullptr; + backend_ctx->kernel_conv_2d_f16_f32 = nullptr; + } + } + // mul_mv_id_q4_0_f32_8x_flat { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2063,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) { CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err)); #endif // GGML_OPENCL_USE_ADRENO_KERNELS + backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr; + dev_ctx->backend_ctx = backend_ctx.release(); return dev_ctx->backend_ctx; } @@ -2232,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) { sync_with_other_backends(backend_ctx); } +static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx+1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && + !ggml_are_same_shape(mul->src[0], rms_norm->src[1])) { + return false; + } + + // rms_norm assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } + + return true; +} + +static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor); + static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2245,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm continue; } + if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]); + i++; + continue; + } + bool ok = ggml_cl_compute_forward(backend, node); if (!ok) { GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); @@ -2361,6 +2456,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te op->src[0]->ne[3] == 1 && op->ne[3] == 1; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_CONV_2D: + return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) || + (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || + (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32); case GGML_OP_CONCAT: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_TIMESTEP_EMBEDDING: @@ -4404,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c backend_ctx->enqueue_ndrange_kernel(kernel, 3, 
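Note: the graph-compute hook above follows a common fusion pattern: validate the current node together with its successor (ggml_opencl_can_fuse), launch one fused kernel, and skip the node the fusion already produced. A minimal host-side sketch of that control flow, with placeholder node/helper names standing in for the real ggml types and the fused-kernel entry point added in this patch:

#include <cstddef>
#include <vector>

// Minimal stand-ins; in the backend these are ggml graph nodes and the
// ggml_opencl_can_fuse / fused-kernel entry points added above.
struct node_t { int op = 0; };

static bool can_fuse_rms_norm_mul(const std::vector<node_t> & nodes, size_t i) {
    (void) nodes; (void) i;
    return false; // placeholder predicate
}
static void run_fused(const node_t &, const node_t &) {}
static void run_single(const node_t &) {}

// Shape of the compute loop: try the two-node fusion first, and on success skip
// the MUL node whose result the fused kernel already wrote.
void compute_graph(const std::vector<node_t> & nodes, bool disable_fusion) {
    for (size_t i = 0; i < nodes.size(); ++i) {
        if (!disable_fusion && i + 1 < nodes.size() &&
            can_fuse_rms_norm_mul(nodes, i)) {
            run_fused(nodes[i], nodes[i + 1]);
            ++i;      // consume the fused MUL node
            continue;
        }
        run_single(nodes[i]);
    }
}
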
global_work_size, local_work_size, dst); } +static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) { + GGML_ASSERT(mul_tensor); + GGML_ASSERT(rms_norm_tensor); + + // src0 is the src of rms_norm, src1 is the other src of mul (one being rms_norm) + const ggml_tensor * src0 = rms_norm_tensor->src[0]; + const ggml_tensor * src1; + if (mul_tensor->src[0] == rms_norm_tensor) { + src1 = mul_tensor->src[1]; + } else if (mul_tensor->src[1] == rms_norm_tensor) { + src1 = mul_tensor->src[0]; + } else { + GGML_ASSERT(false && "Invalid args for rms_norm and mul"); + } + const ggml_tensor * dst = mul_tensor; + + GGML_ASSERT(src0); + GGML_ASSERT(src0->extra); + GGML_ASSERT(src1); + GGML_ASSERT(src1->extra); + GGML_ASSERT(dst); + GGML_ASSERT(dst->extra); + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + float eps; + memcpy(&eps, rms_norm_tensor->op_params, sizeof(float)); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb13 = src1->nb[3]; + + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; + + GGML_ASSERT(ne00 % 4 == 0); + + size_t sgs; + if (backend_ctx->gpu_family == ADRENO) { + sgs = 64; + } else if (backend_ctx->gpu_family == INTEL) { + sgs = 32; + } else { + GGML_ASSERT(false && "Unsupported GPU"); + } + + cl_kernel kernel = backend_ctx->kernel_rms_norm_mul; + + int nth = sgs; + int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel); + while (nth < ne00 && nth < max_workgroup_size) { + nth *= 2; + } + nth = MIN(nth, max_workgroup_size); + nth = MIN(nth, ne00); + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 
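The fused RMS-norm kernel above picks its work-group size by starting at one subgroup and doubling until it either covers a full row (ne00) or hits the device's work-group limit, then clamping to both. The same selection as a standalone helper (variable names mirror the host code):

#include <algorithm>

// Mirrors the nth selection above: sgs is the subgroup size, ne00 the row length,
// max_wg the kernel's maximum work-group size reported by the device.
int pick_work_group_size(int sgs, int ne00, int max_wg) {
    int nth = sgs;
    while (nth < ne00 && nth < max_wg) {
        nth *= 2;
    }
    nth = std::min(nth, max_wg);
    nth = std::min(nth, ne00);
    return nth;
}

// e.g. sgs=32, ne00=4096, max_wg=256 -> 256; sgs=64, ne00=96, max_wg=1024 -> 96.
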
14, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne13)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL)); + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); +} + static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -4998,6 +5208,82 @@ static void ggml_cl_mul_mat_f16_f32_tiled(ggml_backend_t backend, const ggml_ten backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); } +static void ggml_cl_conv_2d(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_TENSOR_BINARY_OP_LOCALS; + ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; + + ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; + ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra; + ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; + + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offset1 = extra1->offset + src1->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; + + const cl_uint Cout = ne03; const cl_uint Cin = ne02; const cl_uint N = ne13; + const cl_uint KW = ne00; const cl_uint KH = ne01; const cl_uint W = ne10; const cl_uint H = ne11; const cl_uint OW = ne0; const cl_uint OH = ne1; + + const cl_uint s0 = dst->op_params[0]; const cl_uint s1 = dst->op_params[1]; + const cl_uint p0 = dst->op_params[2]; const cl_uint p1 = dst->op_params[3]; + const cl_uint d0 = dst->op_params[4]; const cl_uint d1 = dst->op_params[5]; + + const cl_uint cl_nb01 = nb01/ggml_type_size(src0->type); const cl_uint cl_nb02 = nb02/ggml_type_size(src0->type); const cl_uint cl_nb03 = nb03/ggml_type_size(src0->type); + const cl_uint cl_nb11 = nb11/ggml_type_size(src1->type); const cl_uint cl_nb12 = nb12/ggml_type_size(src1->type); const cl_uint cl_nb13 = nb13/ggml_type_size(src1->type); + const cl_uint cl_nb1 = nb1/ggml_type_size(dst->type); const cl_uint cl_nb2 = nb2/ggml_type_size(dst->type); const cl_uint cl_nb3 = nb3/ggml_type_size(dst->type); + + const int64_t NPQ = (int64_t)N * OW * OH; + + const uint32_t BS_K = 64; + const uint32_t BS_NPQ = 64; + const uint32_t BS_CRS = 16; + const uint32_t VEC_SIZE = 4; + + const uint32_t TS_K = 4; + const uint32_t TS_NPQ = 8; + + const uint32_t WG_K = BS_K / TS_K; + const uint32_t WG_NPQ = BS_NPQ / TS_NPQ; + + auto splitWork = [](uint32_t work_size, uint32_t block_size) { return (block_size + work_size - 1) / block_size; }; + const uint32_t NB_K = splitWork(Cout, BS_K); + const uint32_t NB_NPQ = splitWork(NPQ, BS_NPQ); + + cl_kernel kernel; + size_t shmem_size; + + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) { + kernel = backend_ctx->kernel_conv_2d_f16; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_half4)); + } else if (src0->type == 
GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_conv_2d_f32; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_float) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4)); + } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_conv_2d_f16_f32; + shmem_size = (size_t)(BS_K * BS_CRS * sizeof(cl_half) + BS_CRS * (BS_NPQ / VEC_SIZE) * sizeof(cl_float4)); + } else { + GGML_ASSERT(false && "Unsupported data type combination for conv2d"); + } + + cl_uint idx = 0; + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extra1->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, idx++, shmem_size, NULL)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cout)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &Cin)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &N)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &KH)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &W)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &H)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OW)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &OH)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &s1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &p1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d0)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &d1)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb01)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb02)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb03)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb11)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb12)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb13)); + CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb1)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb2)); CL_CHECK(clSetKernelArg(kernel, idx++, sizeof(cl_uint), &cl_nb3)); + + size_t global_work_size[] = { (size_t)NB_K * WG_K, (size_t)NB_NPQ * WG_NPQ, 1 }; + size_t local_work_size[] = { (size_t)WG_K, (size_t)WG_NPQ, 1 }; + + backend_ctx->enqueue_ndrange_kernel(kernel, 2, global_work_size, local_work_size, dst); +} + static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { GGML_ASSERT(src0); GGML_ASSERT(src0->extra); @@ -6752,6 +7038,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor } ggml_cl_upscale(backend, tensor->src[0], tensor); return true; + case GGML_OP_CONV_2D: + if (!any_on_device) { + return false; + } + func = ggml_cl_conv_2d; + break; case GGML_OP_CONCAT: if (!any_on_device) { return false; diff --git a/ggml/src/ggml-opencl/kernels/conv2d.cl b/ggml/src/ggml-opencl/kernels/conv2d.cl new file mode 100644 index 00000000000..e339c90cff5 --- /dev/null +++ 
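The dispatch above treats conv2d as an implicit GEMM of shape K x NPQ with K = Cout and NPQ = N*OW*OH, tiled into BS_K x BS_NPQ block tiles that are each handled by WG_K x WG_NPQ work-items. A small sketch of the ceil-division and NDRange arithmetic, using the block/thread-tile constants from the host code above and a hypothetical problem size:

#include <cstdint>
#include <cstdio>

static uint32_t ceil_div(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

int main() {
    // Block-tile and thread-tile sizes from the OpenCL host code above.
    const uint32_t BS_K = 64, BS_NPQ = 64, TS_K = 4, TS_NPQ = 8;
    const uint32_t WG_K   = BS_K / TS_K;     // 16 work-items along K
    const uint32_t WG_NPQ = BS_NPQ / TS_NPQ; //  8 work-items along NPQ

    // Hypothetical problem size: Cout = 128, N = 1, OW = OH = 56.
    const uint32_t Cout = 128, NPQ = 1 * 56 * 56;

    const uint32_t NB_K   = ceil_div(Cout, BS_K);  // 2 block tiles along K
    const uint32_t NB_NPQ = ceil_div(NPQ, BS_NPQ); // 49 block tiles along NPQ

    printf("global = {%u, %u}, local = {%u, %u}\n",
           NB_K * WG_K, NB_NPQ * WG_NPQ, WG_K, WG_NPQ);
    return 0;
}
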
b/ggml/src/ggml-opencl/kernels/conv2d.cl @@ -0,0 +1,185 @@ +#ifdef USE_FP16 +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#define T_FLOAT half +#define T_FLOAT4 half4 +#define VSTORE_T_FLOAT4(data, offset, p) vstore_half4_rte(data, offset, p) +#else +#define T_FLOAT float +#define T_FLOAT4 float4 +#define VSTORE_T_FLOAT4(data, offset, p) vstore4(data, offset, p) +#endif + +#if defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +#define REQD_SUBGROUP_SIZE_128 +#endif + +#define T_ACCUM float4 +#define VEC_SIZE 4 + +#define BS_K 64 +#define BS_NPQ 64 +#define BS_CRS 16 + +#define TS_K 4 +#define TS_NPQ 8 + +#define WG_K (BS_K / TS_K) +#define WG_NPQ (BS_NPQ / TS_NPQ) + +#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE) +#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE) + +static inline uint splitWork(uint work_size, uint block_size){ + return (work_size + block_size - 1) / block_size; +} + +REQD_SUBGROUP_SIZE_128 +kernel void kernel_conv_2d( + global void* p_knl, + ulong off_knl, + global void* p_src, + ulong off_src, + global void* p_dst, + ulong off_dst, + local void* shared, + uint Cout, uint Cin, uint N, + uint KW, uint KH, uint W, uint H, uint OW, uint OH, + uint s0, uint s1, uint p0, uint p1, uint d0, uint d1, + uint nb01, uint nb02, uint nb03, + uint nb11, uint nb12, uint nb13, + uint nb1, uint nb2, uint nb3 +) { + global T_FLOAT* knl_data = (global T_FLOAT*) ((global char*)p_knl + off_knl); + global T_FLOAT* src_data = (global T_FLOAT*) ((global char*)p_src + off_src); + global T_FLOAT* dst_data = (global T_FLOAT*) ((global char*)p_dst + off_dst); + + const uint K = Cout; + const uint CRS = Cin*KH*KW; + const uint NPQ = N*OH*OW; + + const uint lid_k = get_local_id(0); + const uint lid_npq = get_local_id(1); + const uint tid = lid_npq * WG_K + lid_k; + + const uint B_idx_K = get_group_id(0); + const uint B_idx_NPQ = get_group_id(1); + + const uint offset_k = B_idx_K * BS_K; + const uint offset_npq = B_idx_NPQ * BS_NPQ; + + local T_FLOAT* Ash = (local T_FLOAT*)shared; + local T_FLOAT4* Bsh = (local T_FLOAT4*) &Ash[BS_K * BS_CRS]; + + T_ACCUM regC[TS_K][TS_NPQ_VEC]; + for (int i = 0; i < TS_K; ++i) { + for (int j = 0; j < TS_NPQ_VEC; ++j) { + regC[i][j] = (T_ACCUM)(0.0f); + } + } + + const uint NB_CRS = splitWork(CRS, BS_CRS); + + for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) { + const uint offset_crs = B_idx_CRS * BS_CRS; + + for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) { + const uint k_l = i / BS_CRS; + const uint crs_l = i % BS_CRS; + const uint k_g = offset_k + k_l; + const uint crs_g = offset_crs + crs_l; + + if (k_g < K && crs_g < CRS) { + const uint Cin_idx = crs_g / (KW*KH); + const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW; + const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW; + const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03; + Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx]; + } else { + Ash[k_l * BS_CRS + crs_l] = (T_FLOAT)0.0f; + } + } + + for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) { + const uint crs_l = i / BS_NPQ_VEC; + const uint npq_l_vec = i % BS_NPQ_VEC; + const uint crs_g = offset_crs + crs_l; + + T_FLOAT4 val = (T_FLOAT4)(0.0f); + if (crs_g < CRS) { + const uint Cin_idx = crs_g / (KW * KH); + const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW; + const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW; + for (int v = 0; v < VEC_SIZE; ++v) { + const uint npq_g = offset_npq + 
npq_l_vec * VEC_SIZE + v; + if (npq_g < NPQ) { + const uint N_idx = npq_g / (OH * OW); + const uint pq_idx = npq_g % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; + const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1); + const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0); + + if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) { + const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13; + ((T_FLOAT*)&val)[v] = src_data[src_idx]; + } + } + } + } + Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + #pragma unroll + for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) { + T_FLOAT regA[TS_K]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l]; + } + + for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + T_FLOAT4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), convert_float4(regB), regC[k_l_reg][npq_l_vec_reg]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + const uint k_g = offset_k + lid_k * TS_K + k_l_reg; + if (k_g >= K) continue; + + for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE; + + const uint N_idx = npq_g_base / (OH * OW); + const uint pq_idx = npq_g_base % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; + + if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) { + const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3; + VSTORE_T_FLOAT4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]); + } else { + T_ACCUM res = regC[k_l_reg][npq_l_vec_reg]; + for (int v = 0; v < VEC_SIZE; ++v) { + const uint npq_g = npq_g_base + v; + if (npq_g < NPQ) { + const uint N_idx_s = npq_g / (OH*OW); + const uint pq_idx_s = npq_g % (OH*OW); + const uint OH_idx_s = pq_idx_s / OW; + const uint OW_idx_s = pq_idx_s % OW; + const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3; + dst_data[dst_idx_s] = (T_FLOAT)(((float*)&res)[v]); + } + } + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl b/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl new file mode 100644 index 00000000000..cb05637f33a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/conv2d_f16_f32.cl @@ -0,0 +1,176 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#if defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#else +#define REQD_SUBGROUP_SIZE_128 +#endif + +#define T_ACCUM float4 +#define VEC_SIZE 4 + +#define BS_K 64 +#define BS_NPQ 64 +#define BS_CRS 16 + +#define TS_K 4 +#define TS_NPQ 8 + +#define WG_K (BS_K / TS_K) +#define WG_NPQ (BS_NPQ / TS_NPQ) + +#define BS_NPQ_VEC (BS_NPQ / VEC_SIZE) +#define TS_NPQ_VEC (TS_NPQ / VEC_SIZE) + +static inline uint splitWork(uint work_size, uint block_size){ + return (work_size + block_size - 1) / block_size; +} + +REQD_SUBGROUP_SIZE_128 +kernel void kernel_conv_2d( + global void* p_knl, + ulong off_knl, + global void* p_src, + ulong off_src, + global void* p_dst, + ulong off_dst, + local void* shared, + uint Cout, uint Cin, uint N, + uint KW, uint KH, uint W, uint H, uint OW, 
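The write-back above takes the vectorized path only when a whole 4-wide slice stays inside a single contiguous output row (nb1 == OW, OW_idx + 4 <= OW, npq_g_base + 4 <= NPQ) and otherwise stores element by element. A simplified C++ analogue of that guard (the real kernel additionally recomputes the full (N, OH, OW) coordinate per element in the fallback, which is omitted here):

#include <cstdint>
#include <cstring>

// Illustrative only: vectorized store when the slice fits inside one row and
// inside the tensor, scalar bounds-checked stores otherwise.
void store_slice(float * dst, const float (&acc)[4],
                 uint32_t ow_idx, uint32_t OW, uint32_t base, uint32_t total) {
    if (ow_idx + 4 <= OW && base + 4 <= total) {
        std::memcpy(dst + base, acc, 4 * sizeof(float));   // fast contiguous path
    } else {
        for (uint32_t v = 0; v < 4 && base + v < total; ++v) {
            dst[base + v] = acc[v];                         // boundary fallback
        }
    }
}
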
uint OH, + uint s0, uint s1, uint p0, uint p1, uint d0, uint d1, + uint nb01, uint nb02, uint nb03, + uint nb11, uint nb12, uint nb13, + uint nb1, uint nb2, uint nb3 +) { + global half* knl_data = (global half*) ((global char*)p_knl + off_knl); + global float* src_data = (global float*) ((global char*)p_src + off_src); + global float* dst_data = (global float*) ((global char*)p_dst + off_dst); + + const uint K = Cout; + const uint CRS = Cin*KH*KW; + const uint NPQ = N*OH*OW; + + const uint lid_k = get_local_id(0); + const uint lid_npq = get_local_id(1); + const uint tid = lid_npq * WG_K + lid_k; + + const uint B_idx_K = get_group_id(0); + const uint B_idx_NPQ = get_group_id(1); + + const uint offset_k = B_idx_K * BS_K; + const uint offset_npq = B_idx_NPQ * BS_NPQ; + + local half* Ash = (local half*)shared; + local float4* Bsh = (local float4*) &Ash[BS_K * BS_CRS]; + + T_ACCUM regC[TS_K][TS_NPQ_VEC]; + for (int i = 0; i < TS_K; ++i) { + for (int j = 0; j < TS_NPQ_VEC; ++j) { + regC[i][j] = (T_ACCUM)(0.0f); + } + } + + const uint NB_CRS = splitWork(CRS, BS_CRS); + + for (uint B_idx_CRS = 0; B_idx_CRS < NB_CRS; ++B_idx_CRS) { + const uint offset_crs = B_idx_CRS * BS_CRS; + + for (int i = tid; i < BS_K * BS_CRS; i += (WG_K * WG_NPQ)) { + const uint k_l = i / BS_CRS; + const uint crs_l = i % BS_CRS; + const uint k_g = offset_k + k_l; + const uint crs_g = offset_crs + crs_l; + + if (k_g < K && crs_g < CRS) { + const uint Cin_idx = crs_g / (KW*KH); + const uint KH_idx = (crs_g - Cin_idx*KW*KH) / KW; + const uint KW_idx = crs_g - Cin_idx*KW*KH - KH_idx*KW; + const uint knl_idx = KW_idx + KH_idx*nb01 + Cin_idx*nb02 + k_g*nb03; + Ash[k_l * BS_CRS + crs_l] = knl_data[knl_idx]; + } else { + Ash[k_l * BS_CRS + crs_l] = (half)0.0f; + } + } + + for (int i = tid; i < BS_CRS * BS_NPQ_VEC; i += (WG_K * WG_NPQ)) { + const uint crs_l = i / BS_NPQ_VEC; + const uint npq_l_vec = i % BS_NPQ_VEC; + const uint crs_g = offset_crs + crs_l; + + float4 val = (float4)(0.0f); + if (crs_g < CRS) { + const uint Cin_idx = crs_g / (KW * KH); + const uint KH_idx = (crs_g - Cin_idx * KW * KH) / KW; + const uint KW_idx = crs_g - Cin_idx * KW * KH - KH_idx * KW; + for (int v = 0; v < VEC_SIZE; ++v) { + const uint npq_g = offset_npq + npq_l_vec * VEC_SIZE + v; + if (npq_g < NPQ) { + const uint N_idx = npq_g / (OH * OW); + const uint pq_idx = npq_g % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; + const int H_idx = (int)(OH_idx * s1 + KH_idx * d1 - p1); + const int W_idx = (int)(OW_idx * s0 + KW_idx * d0 - p0); + + if (H_idx >= 0 && H_idx < H && W_idx >= 0 && W_idx < W) { + const uint src_idx = W_idx + H_idx * nb11 + Cin_idx * nb12 + N_idx * nb13; + ((float*)&val)[v] = src_data[src_idx]; + } + } + } + } + Bsh[crs_l * BS_NPQ_VEC + npq_l_vec] = val; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + #pragma unroll + for (uint crs_l = 0; crs_l < BS_CRS; ++crs_l) { + half regA[TS_K]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regA[k_l_reg] = Ash[(lid_k * TS_K + k_l_reg) * BS_CRS + crs_l]; + } + + for (uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + float4 regB = Bsh[crs_l * BS_NPQ_VEC + lid_npq * TS_NPQ_VEC + npq_l_vec_reg]; + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + regC[k_l_reg][npq_l_vec_reg] = mad(convert_float(regA[k_l_reg]), regB, regC[k_l_reg][npq_l_vec_reg]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + for (uint k_l_reg = 0; k_l_reg < TS_K; ++k_l_reg) { + const uint k_g = offset_k + lid_k * TS_K + k_l_reg; + if (k_g >= K) continue; + + for 
(uint npq_l_vec_reg = 0; npq_l_vec_reg < TS_NPQ_VEC; ++npq_l_vec_reg) { + const uint npq_g_base = offset_npq + (lid_npq * TS_NPQ_VEC + npq_l_vec_reg) * VEC_SIZE; + + const uint N_idx = npq_g_base / (OH * OW); + const uint pq_idx = npq_g_base % (OH * OW); + const uint OH_idx = pq_idx / OW; + const uint OW_idx = pq_idx % OW; + + if (nb1 == OW && OW_idx + VEC_SIZE <= OW && npq_g_base + VEC_SIZE <= NPQ) { + const uint dst_idx = OW_idx + OH_idx*nb1 + k_g*nb2 + N_idx*nb3; + vstore4(regC[k_l_reg][npq_l_vec_reg], 0, &dst_data[dst_idx]); + } else { + T_ACCUM res = regC[k_l_reg][npq_l_vec_reg]; + for (int v = 0; v < VEC_SIZE; ++v) { + const uint npq_g = npq_g_base + v; + if (npq_g < NPQ) { + const uint N_idx_s = npq_g / (OH*OW); + const uint pq_idx_s = npq_g % (OH*OW); + const uint OH_idx_s = pq_idx_s / OW; + const uint OW_idx_s = pq_idx_s % OW; + const uint dst_idx_s = OW_idx_s + OH_idx_s*nb1 + k_g*nb2 + N_idx_s*nb3; + dst_data[dst_idx_s] = ((float*)&res)[v]; + } + } + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/im2col_f16.cl b/ggml/src/ggml-opencl/kernels/im2col_f16.cl index b84c8984653..cf6cdaa4ce5 100644 --- a/ggml/src/ggml-opencl/kernels/im2col_f16.cl +++ b/ggml/src/ggml-opencl/kernels/im2col_f16.cl @@ -31,7 +31,7 @@ kernel void kernel_im2col_f16( src1 = (global float*)((global char*)src1 + offset1); dst = (global half*)((global char*)dst + offsetd); - long ksize = OW * (KH > 1 ? KW : 1); + long ksize = OW * KH; long kx = i / ksize; long kd = kx * ksize; long ky = (i - kd) / OW; diff --git a/ggml/src/ggml-opencl/kernels/im2col_f32.cl b/ggml/src/ggml-opencl/kernels/im2col_f32.cl index 4bf65e4eaaf..1ecdb2344ad 100644 --- a/ggml/src/ggml-opencl/kernels/im2col_f32.cl +++ b/ggml/src/ggml-opencl/kernels/im2col_f32.cl @@ -31,7 +31,7 @@ kernel void kernel_im2col_f32( src1 = (global float*)((global char*)src1 + offset1); dst = (global float*)((global char*)dst + offsetd); - long ksize = OW * (KH > 1 ? 
KW : 1); + long ksize = OW * KH; long kx = i / ksize; long kd = kx * ksize; long ky = (i - kd) / OW; diff --git a/ggml/src/ggml-opencl/kernels/rms_norm.cl b/ggml/src/ggml-opencl/kernels/rms_norm.cl index 9d21f3398ec..ecd053cb4c1 100644 --- a/ggml/src/ggml-opencl/kernels/rms_norm.cl +++ b/ggml/src/ggml-opencl/kernels/rms_norm.cl @@ -94,3 +94,82 @@ kernel void kernel_rms_norm( } } } + +//------------------------------------------------------------------------------ +// rms_norm_mul +//------------------------------------------------------------------------------ +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_32 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_rms_norm_mul( + global char * src0, + ulong offset0, + global char * src1, + ulong offset1, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + int ne10, + int ne11, + int ne12, + int ne13, + ulong nb11, + ulong nb12, + ulong nb13, + ulong nb1, + ulong nb2, + ulong nb3, + float eps, + local float * sum +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + int i03 = get_group_id(2); + int i02 = get_group_id(1); + int i01 = get_group_id(0); + + global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01); + global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11); + + float sumf = 0; + + // parallel sum + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + sumf += dot(x[i00], x[i00]); + } + sumf = sub_group_reduce_add(sumf); + if (get_sub_group_local_id() == 0) { + sum[get_sub_group_id()] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) { + if (get_local_id(0) < i) { + sum[get_local_id(0)] += sum[get_local_id(0) + i]; + } + } + if (get_local_id(0) == 0) { + sum[0] /= ne00; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + float mean = sum[0]; + float scale = 1.0f/sqrt(mean + eps); + + global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1); + for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) { + y[i00] = (x[i00] * scale) * f[i00%(ne10/4)]; + } +} diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index f468f796d57..29bc421d58f 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1055,7 +1055,7 @@ bool rpc_server::set_tensor(const std::vector & input) { GGML_ASSERT(ctx_ptr != nullptr); ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor); - if (tensor == nullptr) { + if (tensor == nullptr || tensor->buffer == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); return false; } @@ -1124,7 +1124,7 @@ bool rpc_server::set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rp GGML_ASSERT(ctx_ptr != nullptr); ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); - if (tensor == nullptr) { + if (tensor == nullptr || tensor->buffer == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensor\n", __func__); return false; } @@ -1192,7 +1192,7 @@ bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector< GGML_ASSERT(ctx_ptr != nullptr); ggml_context * ctx = ctx_ptr.get(); ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor); - if (tensor == nullptr) { + if (tensor == nullptr || tensor->buffer == nullptr) { GGML_LOG_ERROR("[%s] 
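The kernel_rms_norm_mul added above computes, per row, y = (x * 1/sqrt(mean(x^2) + eps)) * f, where f is broadcast (modulo indexing) when it has fewer elements than the row. A scalar reference the parallel, float4-based kernel should agree with (function and variable names here are illustrative):

#include <cmath>
#include <cstddef>

// Scalar reference for the fused RMS-norm + multiply: one row of x (length ne00)
// is scaled by rsqrt(mean(x^2) + eps) and multiplied element-wise by f, which is
// broadcast with modulo indexing when ne10 < ne00.
void rms_norm_mul_row(const float * x, const float * f, float * y,
                      size_t ne00, size_t ne10, float eps) {
    float sum = 0.0f;
    for (size_t i = 0; i < ne00; ++i) {
        sum += x[i] * x[i];
    }
    const float scale = 1.0f / std::sqrt(sum / (float) ne00 + eps);
    for (size_t i = 0; i < ne00; ++i) {
        y[i] = x[i] * scale * f[i % ne10];
    }
}
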
error deserializing tensor\n", __func__); return false; } @@ -1229,7 +1229,7 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co ggml_tensor * src = deserialize_tensor(ctx, &request.src); ggml_tensor * dst = deserialize_tensor(ctx, &request.dst); - if (src == nullptr || dst == nullptr) { + if (src == nullptr || dst == nullptr || src->buffer == nullptr || dst->buffer == nullptr) { GGML_LOG_ERROR("[%s] error deserializing tensors\n", __func__); return false; } diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 872eb4b052d..a023d6fb452 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -3531,7 +3531,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; - assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + assert(max_work_group_size % (WARP_SIZE * WARP_SIZE) == 0); { sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size)); diff --git a/ggml/src/ggml-sycl/im2col.cpp b/ggml/src/ggml-sycl/im2col.cpp index 52737cc746d..7adcb3d9d9c 100644 --- a/ggml/src/ggml-sycl/im2col.cpp +++ b/ggml/src/ggml-sycl/im2col.cpp @@ -26,7 +26,7 @@ static void im2col_kernel(const float * x, T * dst, int64_t batch_offset, int64_ // make each work-item deal with more elements since sycl global range can not exceed max int for (int64_t i = global_id; i < pelements; i += (work_group_size * item_ct1.get_group_range(2))) { - const int64_t ksize = OW * (KH > 1 ? KW : 1); + const int64_t ksize = OW * KH; const int64_t kx = i / ksize; const int64_t kd = kx * ksize; const int64_t ky = (i - kd) / OW; diff --git a/ggml/src/ggml-sycl/quants.hpp b/ggml/src/ggml-sycl/quants.hpp index 8b952db43bf..d0d5ac9a4e8 100644 --- a/ggml/src/ggml-sycl/quants.hpp +++ b/ggml/src/ggml-sycl/quants.hpp @@ -48,11 +48,11 @@ template <> struct block_q_t { }; static constexpr std::pair get_block_offset(const int block_index, const int /* nblocks */) { - return { block_index * (traits::qk / traits::qr), 0 }; + return { block_index * (QK4_0 / QR4_0), 0 }; } static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { - return { (ncols / traits::qr * nrows) + block_index * sizeof(ggml_half), 0 }; + return { (ncols / QR4_0 * nrows) + block_index * sizeof(ggml_half), 0 }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } @@ -71,14 +71,12 @@ template <> struct block_q_t { } static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { - auto nblocks = (nrows * (ncols / traits::qk)); - return { nblocks * (QK_K / 2), + auto nblocks = (nrows * (ncols / QK_K)); + return { nblocks * (QK_K / 2) + (block_index * K_SCALE_SIZE), (nblocks * QK_K / 2) + (nblocks * K_SCALE_SIZE) + (block_index * sizeof(ggml_half2)) }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } - - constexpr size_t get_total_qs_bytes(int nblocks) { return nblocks * QK_K / 2; } }; template <> struct block_q_t { @@ -90,22 +88,23 @@ template <> struct block_q_t { }; static constexpr std::pair get_block_offset(const int block_index, const int n_blocks) { - auto low_bits_index = block_index * (traits::qk / traits::qr); + auto low_bits_index = block_index * (QK_K / QR6_K); // the index of high bits it's after all low bits auto high_bits_index = n_blocks * (QK_K / 2) + (block_index * (QK_K / 4)); return { 
low_bits_index, high_bits_index }; } static constexpr std::pair get_d_offset(int nrows, int ncols, const int block_index) { - auto nblocks = (nrows * (ncols / traits::qk)); + auto nblocks = (nrows * (ncols / QK_K)); auto total_qs_bytes = nblocks * (QK_K / 2) + nblocks * (QK_K / 4); auto block_scales = total_qs_bytes + block_index * (QK_K / 16); - auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16); + auto sb_scale = total_qs_bytes + nblocks * (QK_K / 16) + block_index * sizeof(ggml_half); return { block_scales, sb_scale }; } static constexpr int block_to_q8_1_ratio() { return traits::qk / QK8_1; } }; + } // namespace ggml_sycl_reordered #endif // GGML_SYCL_QUANTS_HPP diff --git a/ggml/src/ggml-sycl/vecdotq.hpp b/ggml/src/ggml-sycl/vecdotq.hpp index 0a5d4999419..4088ddb54f0 100644 --- a/ggml/src/ggml-sycl/vecdotq.hpp +++ b/ggml/src/ggml-sycl/vecdotq.hpp @@ -350,11 +350,9 @@ template <> struct reorder_vec_dot_q_sycl { __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, const std::pair d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds, const int & iqs) { - const int ib = ibx_offset.first / (QK_K / 2); - const uint8_t * base = static_cast(vbq); const uint8_t * qs = base + ibx_offset.first; - const uint8_t * scs = base + d_offset.first + ib * K_SCALE_SIZE; + const uint8_t * scs = base + d_offset.first; const ggml_half2 * dms = reinterpret_cast(base + d_offset.second); const int bq8_offset = QR4_K * ((iqs / 2) / (QI8_1 / 2)); @@ -427,13 +425,11 @@ template <> struct reorder_vec_dot_q_sycl { __dpct_inline__ float operator()(const void * __restrict__ vbq, const std::pair ibx_offset, const std::pair d_offset, const int8_t * q8_1_quant_ptr, const sycl::half2 * q8_1_ds, const int iqs) { - const int ib = ibx_offset.first / (QK_K / 2); - const uint8_t * base = static_cast(vbq); const uint8_t * ql = base + ibx_offset.first; const uint8_t * qh = base + ibx_offset.second; const int8_t * scales = reinterpret_cast(base + d_offset.first); - const ggml_half * d = (const ggml_half *) (base + d_offset.second) + ib; + const ggml_half * d = (const ggml_half *) (base + d_offset.second); const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 4); const int scale_offset = (QI6_K / 4) * (iqs / (QI6_K / 2)) + (iqs % (QI6_K / 2)) / (QI6_K / 8); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 3019a545d58..a99b1c73130 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -328,6 +328,7 @@ struct vk_device_struct { uint64_t max_memory_allocation_size; uint64_t suballocation_block_size; bool fp16; + bool bf16; bool pipeline_robustness; vk::Device device; uint32_t vendor_id; @@ -482,6 +483,8 @@ struct vk_device_struct { vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_conv2d_f32; + vk_pipeline pipeline_conv2d_f16_f32; vk_pipeline pipeline_conv2d_dw_whcn_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32; @@ -875,6 +878,38 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t H; }; +struct vk_op_conv2d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + 
uint32_t nb2; + uint32_t nb3; +}; + struct vk_op_conv2d_dw_push_constants { uint32_t ne; uint32_t batches; @@ -974,18 +1009,45 @@ class vk_memory_logger { #endif // GGML_VULKAN_MEMORY_DEBUG class vk_perf_logger { -public: + public: void print_timings() { + if (timings.empty()) { + return; + } + uint64_t total_all_op_times = 0; std::cerr << "----------------\nVulkan Timings:" << std::endl; - for (const auto& t : timings) { - uint64_t total = 0; - for (const auto& time : t.second) { - total += time; + for (const auto & t : timings) { + uint64_t total_op_times = 0; + for (const auto & time : t.second) { + total_op_times += time; + } + std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0) + << " us"; + + // If we have as many flops entries as timing entries for the op, then compute and log the flops/S. + auto it = flops.find(t.first); + if (it != flops.end() && (it->second).size() == t.second.size()) { + uint64_t total_op_flops = 0; + for (const auto & elem : it->second) { + total_op_flops += elem; + } + std::cerr << " (" + << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) / + (double(total_op_times) / (1000.0 * 1000.0 * 1000.0)) + << " GFLOPS/s)"; } - std::cerr << t.first << ": " << t.second.size() << " x " << (total / t.second.size() / 1000.0) << " us" << std::endl; + + total_all_op_times += total_op_times; + + std::cerr << std::endl; + } + + if (timings.size() > 0) { + std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl; } timings.clear(); + flops.clear(); } void log_timing(const ggml_tensor * node, uint64_t time) { @@ -994,22 +1056,45 @@ class vk_perf_logger { return; } if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { - const uint64_t m = node->src[0]->ne[1]; - const uint64_t n = node->src[1]->ne[1]; - const uint64_t k = node->src[1]->ne[0]; - std::string name = ggml_op_name(node->op); + const uint64_t m = node->src[0]->ne[1]; + const uint64_t n = node->src[1]->ne[1]; + const uint64_t k = node->src[1]->ne[0]; + std::string name = ggml_op_name(node->op); if (n == 1) { name += "_VEC m=" + std::to_string(m) + " k=" + std::to_string(k); } else { name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); } timings[name].push_back(time); + flops[name].push_back(m * n * (k + (k - 1))); + return; + } + if (node->op == GGML_OP_CONV_2D) { + std::string name = ggml_op_name(node->op); + ggml_tensor * knl = node->src[0]; + uint64_t OW = node->ne[0]; + uint64_t OH = node->ne[1]; + uint64_t N = node->ne[3]; + uint64_t Cout = node->ne[2]; + uint64_t KW = knl->ne[0]; + uint64_t KH = knl->ne[1]; + uint64_t Cin = knl->ne[2]; + // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ + uint64_t size_M = Cout; + uint64_t size_K = Cin * KW * KH; + uint64_t size_N = N * OW * OH; + uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1)); + name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) + + ", N=N*OW*OH=" + std::to_string(size_N); + flops[name].push_back(n_flops); + timings[name].push_back(time); return; } timings[ggml_op_name(node->op)].push_back(time); } -private: + private: std::map> timings; + std::map> flops; }; struct ggml_backend_vk_context { @@ -2112,6 +2197,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } compile_count++; } + compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint, parameter_count, wg_denoms, specialization_constants, 
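The perf-logger change above counts a GEMM-shaped op as M*N*(2K-1) flops (K multiplies plus K-1 adds per output element) and reports flops divided by the measured time. A tiny sketch of that bookkeeping for a conv2d viewed as an implicit GEMM; the sizes and the timing value below are made up:

#include <cstdint>
#include <cstdio>

int main() {
    // Implicit-GEMM view of a conv2d: M = Cout, K = Cin*KW*KH, N = N*OW*OH.
    const uint64_t M = 256, K = 128 * 3 * 3, N = 1 * 56 * 56;
    const uint64_t n_flops = M * N * (K + (K - 1)); // K mults + (K-1) adds per output

    const uint64_t time_ns = 1200000;               // hypothetical kernel time: 1.2 ms
    const double gflops = ((double) n_flops / 1e9) / ((double) time_ns / 1e9);
    printf("%.1f GFLOPS/s\n", gflops);
    return 0;
}
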
disable_robustness, require_full_subgroups, required_subgroup_size)); }; @@ -2961,6 +3047,51 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + // conv2d + uint32_t conv2d_WG_SIZE = 256; + uint32_t conv2d_BS_K = 128; + uint32_t conv2d_BS_CRS = 16; + uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. + if (device->subgroup_shuffle && + device->vendor_id != VK_VENDOR_ID_INTEL) { // Do not enable collectives on Intel, see PR 14316 + use_collectives = 1; + conv2d_BS_CRS = std::min( + device->subgroup_size, + conv2d_BS_CRS); // CRS block size should be capped at sugroup size for correctness when shuffle is used. + } + uint32_t conv2d_BS_NPQ = 128; + uint32_t conv2d_TS_K = 8; + uint32_t conv2d_shmem_req = + (conv2d_BS_K * (conv2d_BS_CRS + 1) + conv2d_BS_CRS * (conv2d_BS_NPQ + 1)) * sizeof(float); + if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { + conv2d_BS_CRS = 8; + if (use_collectives) { + conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); + } + } + + if (use_collectives) { + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 }, + { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true); + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 }, + { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, true); + } else { + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f32, "conv2d_f32", conv2d_f32_len, conv2d_f32_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 }, + { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, + false); + ggml_vk_create_pipeline( + device, device->pipeline_conv2d_f16_f32, "conv2d_f16_f32", conv2d_f16_f32_len, conv2d_f16_f32_data, "main", 3, + sizeof(vk_op_conv2d_push_constants), { conv2d_BS_K, conv2d_BS_NPQ, 1 }, + { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives }, 1, true, + false); + } + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -3273,6 +3404,12 @@ static vk_device ggml_vk_get_device(size_t idx) { device->fp16 = device->fp16 && vk12_features.shaderFloat16; +#if defined(VK_KHR_shader_bfloat16) + device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; +#else + device->bf16 = false; +#endif + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; if (device->subgroup_size_control) { @@ -3615,6 +3752,7 @@ static void ggml_vk_print_gpu_info(size_t idx) { bool coopmat_support = false; bool coopmat2_support = false; bool 
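The Vulkan pipeline setup above budgets its shared-memory tiles as (BS_K*(BS_CRS+1) + BS_CRS*(BS_NPQ+1)) floats (the +1 leaves room for row padding) and drops BS_CRS from 16 to 8 when that exceeds maxComputeSharedMemorySize. The arithmetic spelled out, with an example device limit rather than a queried one:

#include <cstdint>
#include <cstdio>

int main() {
    uint32_t BS_K = 128, BS_NPQ = 128, BS_CRS = 16;
    const size_t max_shared = 32768;   // example limit (32 KiB); the real value comes from the device limits

    // Two padded tiles: A is BS_K x (BS_CRS+1) floats, B is BS_CRS x (BS_NPQ+1) floats.
    size_t shmem = (BS_K * (BS_CRS + 1) + BS_CRS * (BS_NPQ + 1)) * sizeof(float);
    if (shmem > max_shared) {
        BS_CRS = 8;                    // shrink the CRS block so both tiles fit
        shmem  = (BS_K * (BS_CRS + 1) + BS_CRS * (BS_NPQ + 1)) * sizeof(float);
    }
    printf("BS_CRS=%u, shared memory=%zu bytes\n", BS_CRS, shmem);
    return 0;
}
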
integer_dot_product = false; + bool bfloat16_support = false; for (auto properties : ext_props) { if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { @@ -3635,6 +3773,11 @@ static void ggml_vk_print_gpu_info(size_t idx) { } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { integer_dot_product = true; +#endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_BFLOAT16")) { + bfloat16_support = true; #endif } } @@ -3701,10 +3844,25 @@ static void ggml_vk_print_gpu_info(size_t idx) { last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; } +#if defined(VK_KHR_shader_bfloat16) + VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; + bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; + if (bfloat16_support) { + last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; + last_struct = (VkBaseOutStructure *)&bfloat16_features; + } +#endif + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); fp16 = fp16 && vk12_features.shaderFloat16; +#if defined(VK_KHR_shader_bfloat16) + bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; +#else + bool bf16 = false; +#endif + uint32_t default_subgroup_size = get_subgroup_size("", device_architecture); const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize; const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; @@ -3722,8 +3880,8 @@ static void ggml_vk_print_gpu_info(size_t idx) { std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? 
"KHR_coopmat" : "none"; std::string device_name = props2.properties.deviceName.data(); - GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", - idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size, props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { @@ -6809,6 +6967,16 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_leaky_relu_f32; } return nullptr; + case GGML_OP_CONV_2D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + if (src0->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv2d_f32; + } else if (src0->type == GGML_TYPE_F16) { + return ctx->device->pipeline_conv2d_f16_f32; + } + } + return nullptr; case GGML_OP_CONV_2D_DW: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (ggml_is_contiguous(src1)) { @@ -7131,6 +7299,31 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co const uint32_t OW = dst->ne[0]; elements = { N * OC * OH * OW, 1, 1}; } break; + case GGML_OP_CONV_2D: + { + // src0 - kernel: [KW, KH, Cin, Cout] + // src1 - input: [W, H, Cin, N] + // dst - result: [OW, OH, Cout, N] + + // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; + }; + // parallelize in {OW/BS_K, OH/BS_NPQ, 1} + int64_t W = src1->ne[0]; + int64_t H = src1->ne[1]; + int64_t KW = src0->ne[0]; + int64_t KH = src0->ne[1]; + int64_t Cout = src0->ne[3]; + int64_t N = src1->ne[3]; + int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); + int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); + int64_t NPQ = N * OW * OH; + + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups + elements = { static_cast(Cout), static_cast(NPQ), 1 }; + } + break; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_DIV: @@ -7703,6 +7896,13 @@ static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); + // Skip empty skip_rows operations. For most ops the empty check at the start + // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst + // with empty srcs. 
+ if (ggml_is_empty(src0) || ggml_is_empty(src1)) { + return; + } + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, { (uint32_t)ggml_nelements(src0), (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, @@ -7997,6 +8197,55 @@ static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, c }, dryrun); } +static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv2d_push_constants p{}; + p.Cout = static_cast(ne03); + p.Cin = static_cast(ne02); + p.N = static_cast(ne13); + + p.KW = static_cast(ne00); + p.KH = static_cast(ne01); + p.W = static_cast(ne10); + p.H = static_cast(ne11); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + + p.s0 = static_cast(dst->op_params[0]); + p.s1 = static_cast(dst->op_params[1]); + p.p0 = static_cast(dst->op_params[2]); + p.p1 = static_cast(dst->op_params[3]); + p.d0 = static_cast(dst->op_params[4]); + p.d1 = static_cast(dst->op_params[5]); + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + GGML_ASSERT(ne03 == ne2); + GGML_ASSERT(ne02 == ne12); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); +} + static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { vk_op_conv2d_dw_push_constants p{}; p.ne = ggml_nelements(dst); @@ -9059,6 +9308,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: @@ -9126,6 +9376,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_LEAKY_RELU: { @@ -9332,6 +9583,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_POOL_2D: ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + break; + case GGML_OP_CONV_2D: + ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun); + break; case GGML_OP_CONV_2D_DW: ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); @@ -9462,6 +9717,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: case GGML_OP_CONV_2D_DW: case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: @@ 
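The calc_conv_output_size lambda and the OW/OH push constants above use the standard convolution output-size formula OW = (W + 2*p0 - d0*(KW-1) - 1)/s0 + 1 (and analogously for OH). A one-file check with arbitrary sample sizes:

#include <cstdint>
#include <cstdio>

// Same formula as the calc_conv_output_size lambda above.
int64_t conv_out(int64_t in, int64_t k, int s, int p, int d) {
    return (in + 2 * p - d * (k - 1) - 1) / s + 1;
}

int main() {
    printf("%lld %lld\n",
           (long long) conv_out(224, 3, /*s=*/1, /*p=*/1, /*d=*/1),   // 224 ("same"-style padding)
           (long long) conv_out(224, 3, /*s=*/2, /*p=*/0, /*d=*/1));  // 111
    return 0;
}
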
-10013,7 +10269,7 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st } // if rms_norm is the B operand, then we don't handle broadcast if (rms_norm == mul->src[1] && - mul->src[0]->ne[1] != rms_norm->ne[1]) { + !ggml_are_same_shape(mul->src[0], rms_norm)) { return false; } // rms_norm shader assumes contiguous rows @@ -10043,6 +10299,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D) { + // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode. + auto CRS_size = + cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[0]->ne[2]; + auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3]; + total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type); } i += ctx->num_additional_fused_ops; ctx->num_additional_fused_ops = 0; @@ -10619,6 +10881,20 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_CONV_TRANSPOSE_1D: return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + case GGML_OP_CONV_2D: + { + // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + bool is_Apple = ggml_vk_get_device(ctx->device)->vendor_id == VK_VENDOR_ID_APPLE; + // Channel-contiguous format is not supported yet. 
+ return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + ggml_is_contiguous(op)) && !is_Apple; + } default: return false; } @@ -11177,6 +11453,14 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const int32_t p1 = tensor->op_params[6]; tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); + } else if (tensor->op == GGML_OP_CONV_2D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); } else if (tensor->op == GGML_OP_LEAKY_RELU) { const float * op_params = (const float *)tensor->op_params; tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp new file mode 100644 index 00000000000..481940a52b3 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -0,0 +1,265 @@ +#version 450 + +#ifdef USE_COLLECTIVES +# extension GL_KHR_shader_subgroup_shuffle : enable +#endif + +#include "types.comp" + +// Make spec constant +#define SHMEM_PAD 0 + +// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j +layout(binding = 0) readonly buffer A { + A_TYPE knl_data[]; +}; // src0 - kernel: [KW, KH, Cin, Cout] + +layout(binding = 1) readonly buffer B { + B_TYPE src_data[]; +}; // src1 - input: [W, H, Cin, N] -- channel_first format + +layout(binding = 2) writeonly buffer D { + D_TYPE dst_data[]; +}; // dst - result: [OW, OH, Cout, N] + +layout(push_constant) uniform parameter { + // I/O channels, batch size + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + // Tensor spatial sizes: kernel, input, output + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + // Parameters: stride, padding, dilation - 0=y, 1=x + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + // Strides in elements + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; +} + +p; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +// Blocktile sizes +layout(constant_id = 1) const uint BS_K = 128; +layout(constant_id = 2) const uint BS_CRS = 16; +layout(constant_id = 3) const uint BS_NPQ = 128; +// Thread-tile sizes +layout(constant_id = 4) const uint TS_K = 8; +layout(constant_id = 5) const uint use_collectives = 1; + +uint32_t tid = gl_LocalInvocationID.x; +const uint32_t WG_SIZE = gl_WorkGroupSize.x; + +uint splitWork(uint work_size, uint block_size) { + return (block_size + work_size - 1) / block_size; +} + +uint32_t K = p.Cout; +uint32_t CRS = p.Cin * p.KH * p.KW; +uint32_t NPQ = p.N * p.OH * p.OW; + +uint32_t n_elems_out = K * NPQ; + +// Number of blocktiles per input +uint32_t NB_CRS = splitWork(CRS, BS_CRS); + +const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; +const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; + +const uint32_t Ash_numel = BS_K * BS_CRS; +const uint32_t Bsh_numel = BS_CRS * BS_NPQ; + +const uint32_t Ash_len = BS_K * 
+void main() {
+    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+            regC[T_ly][T_lx] = 0.0;
+        }
+    }
+    /* Advance block in CRS dim */
+    for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
+        uint32_t CRS_idx_a;
+        uint32_t Cin_idx_a;
+        uint32_t KH_idx_a;
+        uint32_t KW_idx_a;
+
+#ifdef USE_COLLECTIVES
+        uint32_t cached_CRS_idx;
+        uint32_t cached_Cin_idx;
+        uint32_t cached_KH_idx;
+        uint32_t cached_KW_idx;
+        if (use_collectives == 1) {
+            cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID;
+            cached_Cin_idx = cached_CRS_idx / (p.KW * p.KH);
+            uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH);
+            cached_KH_idx = cached_CRS_remainder / p.KW;
+            cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW;
+
+            CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac);
+            Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac);
+            KH_idx_a = subgroupShuffle(cached_KH_idx, Ac);
+            KW_idx_a = subgroupShuffle(cached_KW_idx, Ac);
+        } else {
+            CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
+            Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
+            uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
+            KH_idx_a = CRS_remainder / p.KW;
+            KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
+        }
+#else
+        CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A)
+        Cin_idx_a = CRS_idx_a / (p.KW * p.KH);
+        CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH;
+        KH_idx_a = CRS_remainder / p.KW;
+        KW_idx_a = CRS_remainder - KH_idx_a * p.KW;
+#endif
+
+        /* Load kernel to A_block: (BS_K x BS_CRS)*/
+        for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) {
+            uint32_t B_ly = r_offset + Ar;
+            uint32_t B_lx = Ac;
+            uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
+            uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
+            float val = knl_data[knl_idx];
+            if (K_idx >= K || CRS_idx_a >= CRS) {
+                val = 0.0;
+            }
+            Ash[B_ly * Ash_stride + B_lx] = val;
+        }
+        /* Load input to B_block: (BS_CRS x BS_NPQ) */
+        for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
+            uint32_t B_ly = r_offset + Br; /* Row index of B block */
+            uint32_t B_lx = Bc;
+            uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */
+            uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
+            uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW;
+            uint32_t OH_idx = NPQ_remainder / p.OW;
+            uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW;
+
+            uint32_t CRS_idx_b;
+            uint32_t Cin_idx_b;
+            uint32_t KH_idx_b;
+            uint32_t KW_idx_b;
+#ifdef USE_COLLECTIVES
+            if (use_collectives == 1) {
+                CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br);
+                Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br);
+                KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br);
+                KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br);
+            } else {
+                CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
+                Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
+                uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
+                KH_idx_b = CRS_remainder / p.KW;
+                KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
+            }
+#else
+            CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */
+            Cin_idx_b = CRS_idx_b / (p.KW * p.KH);
+            uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH;
+            KH_idx_b = CRS_remainder / p.KW;
+            KW_idx_b = CRS_remainder - KH_idx_b * p.KW;
+#endif
+
+            uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
+            uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
+            uint32_t src_idx =
+                min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
+            float val = src_data[src_idx];
+            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
+                val = 0.0;
+            }
+            Bsh[B_ly * Bsh_stride + B_lx] = val;
+        }
+        barrier();
+        for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
+            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+                regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
+            }
+            for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+                regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx];
+            }
+            for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+                for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+                    regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]);
+                }
+            }
+        }
+        barrier();
+    }
+    /* Save C* */
+    for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
+        for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
+            uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly;
+            uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx;
+            uint32_t N_idx = NPQ_idx / (p.OH * p.OW);
+            uint32_t OH_idx = (NPQ_idx - N_idx * p.OH * p.OW) / p.OW;
+            uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
+            uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
+            if (K_idx < K && NPQ_idx < NPQ) {
+                dst_data[dst_idx] = regC[T_ly][T_lx];
+            }
+        }
+    }
+}
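The new shader computes the convolution as one GEMM: A is the kernel viewed as K x CRS (K = Cout, CRS = Cin*KH*KW) and B is the implicitly im2col'd input viewed as CRS x NPQ (NPQ = N*OH*OW), as its header comment states. A small C sketch of the same dimension bookkeeping, using the standard convolution output-size formula; the helper and its name are illustrative and not part of the diff:

    #include <stdint.h>

    // GEMM view used by conv2d_mm.comp: K = Cout, CRS = Cin*KH*KW, NPQ = N*OH*OW.
    static void conv2d_gemm_dims(uint32_t Cout, uint32_t Cin, uint32_t N,
                                 uint32_t KW, uint32_t KH, uint32_t W, uint32_t H,
                                 uint32_t s0, uint32_t s1, uint32_t p0, uint32_t p1,
                                 uint32_t d0, uint32_t d1,
                                 uint32_t * K, uint32_t * CRS, uint32_t * NPQ) {
        const uint32_t OW = (W + 2 * p0 - d0 * (KW - 1) - 1) / s0 + 1;  // standard conv output width
        const uint32_t OH = (H + 2 * p1 - d1 * (KH - 1) - 1) / s1 + 1;  // standard conv output height
        *K   = Cout;           // rows of A (one row per output channel)
        *CRS = Cin * KH * KW;  // contracted dimension
        *NPQ = N * OH * OW;    // columns of B (one column per output pixel)
    }

The same K, CRS and NPQ products are what the ggml_backend_vk_graph_compute hunk above uses to account CONV_2D work toward total_mat_mul_bytes.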
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
index 17c7ccb90d0..fdbcf7eba0f 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp
@@ -40,12 +40,10 @@ void main() {
     const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
     const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
     const int oh_s1 = int(oh) * p.s1;
-    const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1);
+    const uint ksize = p.OW * p.KH;

     const uint base_linear_idx = gidx * NUM_ITER;

-    const uint max_ky = ksize / p.OW;
-
     uint current_kx = base_linear_idx / ksize;
     const uint rem = base_linear_idx - (current_kx * ksize);
     uint current_ky = rem / p.OW;
@@ -76,7 +74,7 @@ void main() {

         if (++current_ix == p.OW) {
             current_ix = 0;
-            if (++current_ky == max_ky) {
+            if (++current_ky == p.KH) {
                 current_ky = 0;
                 current_kx++;
             }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
index 6428ca7ba33..bdd7db2d698 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp
@@ -50,8 +50,14 @@ void main() {
     const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));

     if (do_multiply) {
-        [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
-            data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+        if (ncols > p.ne10) {
+            [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+                data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
+            }
+        } else {
+            [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+                data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+            }
         }
     } else {
         [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index 809c0bd9bd3..f9f0c95b8b2 100644
--- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
+++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -655,6 +655,9 @@ void process_shaders() {

     string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

+    string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
+    string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}});
+
     string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
     string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));

@@ -765,8 +768,8 @@ void write_output_files() {
                 len += "};\n";
             }
         }
-        fprintf(src, data.c_str());
-        fprintf(src, len.c_str());
+        fputs(data.c_str(), src);
+        fputs(len.c_str(), src);
     }
     fclose(hdr);
     fclose(src);
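The write_output_files() change above is format-string hardening: the generated data/len strings are not format strings, and passing them as the first argument of fprintf() makes any literal '%' they contain be parsed as a conversion specifier, which is undefined behavior once it expects an argument that was never passed; fputs() writes the bytes verbatim. A standalone illustration of the difference, not part of the diff:

    #include <stdio.h>

    int main(void) {
        const char * generated = "// progress: 100% of shaders written\n";
        // fprintf(stdout, generated);  // '%' starts a conversion spec here and may consume a missing argument (UB)
        fputs(generated, stdout);       // no format parsing, the string is printed as-is
        return 0;
    }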
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 5ae1c527df6..124cf3e8b60 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6640,20 +6640,18 @@ static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgr
 static void ggml_graph_dump_dot_node_edge(FILE * fp, const struct ggml_cgraph * gb, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
     struct ggml_tensor * gparent = ggml_graph_get_parent(gb, node);
     struct ggml_tensor * gparent0 = ggml_graph_get_parent(gb, parent);

-    fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
+    fprintf(fp, "  \"%p\" -> \"%p\" [ arrowhead = %s; style = %s; label = \"%s\"; ]\n",
             gparent0 ? (void *) gparent0 : (void *) parent,
-            gparent0 ? "g" : "x",
             gparent ? (void *) gparent : (void *) node,
-            gparent ? "g" : "x",
             gparent ? "empty" : "vee",
             gparent ? "dashed" : "solid",
             label);
 }

 static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node, struct ggml_tensor * parent, const char * label) {
-    fprintf(fp, "  \"%p\":%s -> \"%p\":%s [ label = \"%s\"; ]\n",
-            (void *) parent, "x",
-            (void *) node, "x",
+    fprintf(fp, "  \"%p\" -> \"%p\" [ label = \"%s\"; ]\n",
+            (void *) parent,
+            (void *) node,
             label);
 }
diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last
index 9b223827afb..5be1dce60aa 100644
--- a/scripts/sync-ggml.last
+++ b/scripts/sync-ggml.last
@@ -1 +1 @@
-a0361ace408ba2c30820deb39e793ad9ed787a85
+b96890f3ab5ffbdbe56bc126df5366c34bd08d39
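For context on the ggml.c hunk above: it only changes the text emitted by the existing Graphviz dumper, writing edges as plain "node -> node" identifiers instead of the previous ":x"/":g" port-suffixed form. The dumper is still driven through the public API; an illustrative call, not taken from the diff:

    // After building a computation graph gf, write it to DOT and render offline,
    // e.g. with: dot -Tpng graph.dot -o graph.png
    ggml_graph_dump_dot(gf, NULL, "graph.dot");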