diff --git a/examples/talk-llama/llama-sampling.cpp b/examples/talk-llama/llama-sampling.cpp
index e255a8fc4fd..25536eb6c5a 100644
--- a/examples/talk-llama/llama-sampling.cpp
+++ b/examples/talk-llama/llama-sampling.cpp
@@ -63,6 +63,30 @@ static void llama_log_softmax(float * array, size_t size) {
 }
 */
 
+static void llama_sampler_temp_impl(llama_token_data_array * cur_p, float temp) {
+    if (temp <= 0.0f) {
+        // find the token with the highest logit and set the rest to -inf
+        size_t max_i = 0;
+        float  max_l = cur_p->data[0].logit;
+
+        for (size_t i = 1; i < cur_p->size; ++i) {
+            if (cur_p->data[i].logit > max_l) {
+                cur_p->data[max_i].logit = -INFINITY;
+                max_i = i;
+                max_l = cur_p->data[i].logit;
+            } else {
+                cur_p->data[i].logit = -INFINITY;
+            }
+        }
+
+        return;
+    }
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].logit /= temp;
+    }
+}
+
 static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
     GGML_ASSERT(cur_p->size > 0);
 
@@ -427,6 +451,9 @@ static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*
 static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     auto * ctx = (llama_sampler_dist *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
     cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
 }
 
@@ -912,9 +939,8 @@ static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*
 static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
     const auto * ctx = (llama_sampler_temp *) smpl->ctx;
-    for (size_t i = 0; i < cur_p->size; ++i) {
-        cur_p->data[i].logit /= ctx->temp;
-    }
+
+    llama_sampler_temp_impl(cur_p, ctx->temp);
 }
 
 static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
@@ -961,6 +987,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
     if (ctx->delta > 0) {
         const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
         const float max_temp = ctx->temp + ctx->delta;
+
         float exponent_val = ctx->exponent;
 
         // no need to do anything if there is only one (or zero) candidates
@@ -998,9 +1025,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
 #endif
 
         // Apply the dynamically calculated temperature scaling
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            cur_p->data[i].logit /= dyn_temp;
-        }
+        llama_sampler_temp_impl(cur_p, dyn_temp);
 
         // Re-compute softmax probabilities after scaling logits with dynamic temperature
         const double max_l_double = cur_p->data[0].logit;
@@ -1024,9 +1049,7 @@ static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_toke
         }
 #endif
     } else {
-        for (size_t i = 0; i < cur_p->size; ++i) {
-            cur_p->data[i].logit /= ctx->temp;
-        }
+        llama_sampler_temp_impl(cur_p, ctx->temp);
     }
 }
 
@@ -1059,6 +1082,101 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
     };
 }
 
+// xtc
+
+struct llama_sampler_xtc {
+    const float    probability;
+    const float    threshold;
+    const size_t   min_keep;
+
+    const uint32_t seed;
+    uint32_t       seed_cur;
+
+    std::mt19937   rng;
+};
+
+static const char * llama_sampler_xtc_name(const struct llama_sampler * /*smpl*/) {
+    return "xtc";
+}
+
+static void llama_sample_xtc_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_xtc *) smpl->ctx;
+
+    if (ctx->probability <= 0.0f
+            || ctx->threshold > 0.5f
+            || cur_p->size < 2) {
+        return;
+    }
+
+    std::uniform_real_distribution<float> distribution(0.0f, 1.0f);
+    float chance = distribution(ctx->rng);
+    if (chance > ctx->probability) return;
+
+    // in case it's not sorted/recalculated yet
+    llama_sampler_softmax_impl(cur_p);
+
+    int pos_last = 0;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].p >= ctx->threshold) {
+            pos_last = i;
+        } else break;
+    }
+
+    if (cur_p->size - pos_last >= ctx->min_keep && pos_last > 0) {
+        cur_p->data += pos_last;
+        cur_p->size -= pos_last;
+    }
+}
+
+static struct llama_sampler * llama_sampler_xtc_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_xtc *) smpl->ctx;
+    auto * result = llama_sampler_init_xtc(ctx->probability, ctx->threshold, ctx->min_keep, ctx->seed);
+
+    // copy the state
+    {
+        auto * result_ctx = (llama_sampler_xtc *) result->ctx;
+
+        result_ctx->rng = ctx->rng;
+    }
+
+    return result;
+}
+
+static void llama_sampler_xtc_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_xtc *) smpl->ctx;
+}
+
+static void llama_sampler_xtc_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_xtc *) smpl->ctx;
+    ctx->seed_cur = get_rng_seed(ctx->seed);
+    ctx->rng.seed(ctx->seed_cur);
+}
+
+static struct llama_sampler_i llama_sampler_xtc_i = {
+    /* .name   = */ llama_sampler_xtc_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sample_xtc_apply,
+    /* .reset  = */ llama_sampler_xtc_reset,
+    /* .clone  = */ llama_sampler_xtc_clone,
+    /* .free   = */ llama_sampler_xtc_free,
+};
+
+struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) {
+    auto seed_cur = get_rng_seed(seed);
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_xtc_i,
+        /* .ctx   = */ new llama_sampler_xtc {
+            /* .probability = */ p,
+            /* .threshold   = */ t,
+            /* .min_keep    = */ min_keep,
+            /* .seed        = */ seed,
+            /* .seed_cur    = */ seed_cur,
+            /* .rng         = */ std::mt19937(seed_cur),
+        },
+    };
+}
+
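A note on usage, since only the implementation appears in this hunk: XTC ("exclude top choices") fires with probability `probability` per sample, and when it fires it drops every candidate that clears `threshold` except the last, least likely one of them, so sampling falls through to a lower-ranked choice. A minimal sketch against the public sampler-chain API from llama.h; the parameter values are illustrative only, not recommendations:

    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(chain, llama_sampler_init_xtc(/*p=*/0.5f, /*t=*/0.1f, /*min_keep=*/1, /*seed=*/LLAMA_DEFAULT_SEED));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); // sample from what survives

The early exits are worth spelling out: at most one token can have p > 0.5, so a threshold above 0.5 can never remove anything; and `pos_last` stays 0 unless at least two candidates clear the threshold, in which case the candidate array is left untouched.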
+
 // mirostat
 
 struct llama_sampler_mirostat {
@@ -1565,6 +1683,397 @@ struct llama_sampler * llama_sampler_init_penalties(
     };
 }
 
+// DRY
+
+struct llama_sampler_dry {
+    int32_t total_context_size;
+
+    const float   dry_multiplier;
+    const float   dry_base;
+    const int32_t dry_allowed_length;
+    const int32_t dry_penalty_last_n;
+
+    std::unordered_multimap<llama_token, std::vector<llama_token>> dry_processed_breakers;
+    std::vector<int> dry_repeat_count;
+    std::unordered_map<llama_token, int> dry_max_token_repeat;
+    ring_buffer<llama_token> last_tokens;
+};
+
+// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
+static void get_overlapping_token_sequences(const llama_vocab & vocab, const std::string& str, std::unordered_multimap<llama_token, std::vector<llama_token>>& token_sequences, int max_tail_len = -1) {
+    for (llama_token token_id = 0; token_id < (llama_token)vocab.n_vocab; token_id++) {
+        std::string word = llama_detokenize(vocab, {token_id}, true);
+        if (word.find(str) != std::string::npos) {
+            token_sequences.emplace(token_id, std::vector<llama_token>());
+        } else {
+            size_t word_len = word.size(), str_len = str.size();
+            size_t pos = -1;
+            while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
+                bool match = true;
+                size_t i;
+                for (i = 1; i < str_len && i + pos < word_len; ++i) {
+                    if (word[pos + i] != str[i]) {
+                        match = false;
+                        break;
+                    }
+                }
+                if (match) {
+                    std::vector<llama_token> tokenization = llama_tokenize_internal(vocab, str.substr(i), false, false);
+                    if (max_tail_len >= 0 && tokenization.size() > (size_t)max_tail_len) {
+                        tokenization.resize(max_tail_len);
+                    }
+
+                    // Ensure we don't already have a duplicate matching tokenization
+                    auto its = token_sequences.equal_range(token_id);
+                    bool found = false;
+                    for (auto it = its.first; it != its.second; ++it) {
+                        if (tokenization == it->second) {
+                            found = true;
+                            break;
+                        }
+                    }
+                    if (!found) {
+                        token_sequences.emplace(token_id, tokenization);
+                    }
+                }
+            }
+        }
+    }
+}
+
+static const char * llama_sampler_dry_name(const struct llama_sampler * /*smpl*/) {
+    return "dry";
+}
+
+static void llama_sampler_dry_accept(struct llama_sampler * smpl, llama_token token) {
+    auto * ctx = (llama_sampler_dry *) smpl->ctx;
+    if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
+        return;
+    }
+
+    ctx->last_tokens.push_back(token);
+}
+
+// Ported from Koboldcpp, original PR: https://github.com/LostRuins/koboldcpp/pull/982 (Original author: pi6am)
+static void llama_sampler_dry_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_dry *) smpl->ctx;
+
+    if (ctx->dry_multiplier == 0.0f || ctx->dry_base < 1.0f || ctx->dry_penalty_last_n == 0) {
+        return;
+    }
+
+    int32_t effective_dry_penalty_last_n = (ctx->dry_penalty_last_n == -1) ? ctx->total_context_size : std::max(ctx->dry_penalty_last_n, 0);
+    int last_n_repeat = std::min(std::min((int)ctx->last_tokens.size(), effective_dry_penalty_last_n), ctx->total_context_size);
+
+    if (last_n_repeat <= ctx->dry_allowed_length) {
+        return;
+    }
+
+    ctx->dry_repeat_count.assign(last_n_repeat, 0);
+    ctx->dry_max_token_repeat.clear();
+
+    // Step 1: Look for restart sequences to limit the maximum repetition length.
+    // Work backwards through the context looking for any token that begins a restart sequence.
+    //
+    // The collection `restart_sequences` is a mapping from a "head" token to all "tail"
+    // sequences that together comprise a restart sequence. This allows us to quickly check
+    // whether each token is the head of a complete sequence. Most restart sequences are actually
+    // a single token, and for these the "tail" is an empty vector.
+    //
+    // If the token is a "head", test all restart sequences that begin with this token
+    // (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
+    // 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
+    // longest matching sequence (if any) is used to limit the maximum repetition length.
+    //
+    // Note that in the case of a short sequence contained in a longer one, this might fail to
+    // find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
+    // restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
+    // 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
+    //
+    // This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
+    // have already clamped the maximum tail sequence length when generating `restart_sequences`.
+    // With clamping, this scan is O(N) in the context length.
+
+    int rep_limit = last_n_repeat;
+    for (int i = 0; i < last_n_repeat; ++i) {
+        llama_token token = ctx->last_tokens.rat(i);
+        auto its = ctx->dry_processed_breakers.equal_range(token);
+        if (its.first == ctx->dry_processed_breakers.end()) {
+            continue;
+        }
+        int longest_match = -1;
+        for (auto it = its.first; it != its.second; ++it) {
+            // Note that (*it) does not contain the head character, so seq_len will be
+            // the restart sequence length minus 1.
+            // In the common case of a single-token restart sequence, (*it) will be empty
+            // and we will trivially match.
+            int seq_len = (int)it->second.size();
+            if (seq_len > longest_match && seq_len <= (int)i) {
+                bool match = true;
+                for (int offset = 0; offset < seq_len; ++offset) {
+                    // The -1 when indexing `last_tokens` is because we already matched the head.
+                    if (it->second[offset] != ctx->last_tokens.rat(i - offset - 1)) {
+                        match = false;
+                        break;
+                    }
+                }
+                if (match) {
+                    longest_match = seq_len;
+                }
+            }
+        }
+        if (longest_match >= 0) {
+            // We found a restart sequence starting `i` tokens from the end and continuing for
+            // `longest_match` tokens.
+            rep_limit = i - longest_match;
+            break;
+        }
+    }
+    if (rep_limit < ctx->dry_allowed_length) {
+        return;
+    }
+
+    // Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
+    // the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
+    // elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
+    //
+    // This algorithm is not currently documented on Wikipedia, but there is a clear description here:
+    // https://ivanyu.me/blog/2014/10/15/z-algorithm/
+    //
+    // The code below is adapted from the public domain implementation by the same author here:
+    // https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
+    //
+    // Example:
+    //   Last N tokens: a b c c b c y a b c
+    //   Repeat counts: 0 0 3 1 0 2 0 0 0 0
+    //                        ^
+    //   This `3` means that the last three tokens of the context (a b c) also appear here.
+    //
+    // This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
+    // for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
+    // repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
+    // ensure that the inner while loops only examine each token in the context once as the outer
+    // for loop iterates over the context.
+
+    {
+        const int last = last_n_repeat - 1;
+        int rt = 0, lt = 0;
+
+        for (int k = 1; k < last_n_repeat; ++k) {
+            if (k > rt) {
+                // If k is outside the current Z-box, do naive computation.
+                int n = 0;
+                while (n + k < last_n_repeat && ctx->last_tokens.rat(n) == ctx->last_tokens.rat(n+k)) {
+                    ++n;
+                }
+                ctx->dry_repeat_count[last - k] = std::min(n, rep_limit);
+                if (n > 0) {
+                    lt = k;
+                    rt = k+n-1;
+                }
+            } else {
+                // If k is inside the current Z-box, consider two cases.
+
+                int p = k - lt; // Pair index.
+                int right_part_len = rt - k + 1;
+
+                if (ctx->dry_repeat_count[last - p] < right_part_len) {
+                    int n = std::min(ctx->dry_repeat_count[last - p], rep_limit);
+                    ctx->dry_repeat_count[last - k] = n;
+                } else {
+                    int i = rt + 1;
+                    while (i < last_n_repeat && ctx->last_tokens.rat(i) == ctx->last_tokens.rat(i - k)) {
+                        i += 1;
+                    }
+
+                    int n = std::min(i - k, rep_limit);
+                    ctx->dry_repeat_count[last - k] = n;
+                    lt = k;
+                    rt = i - 1;
+                }
+            }
+        }
+    }
+
+    // Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
+    // that would be generated by emitting each new token that would extend a sequence.
+    //
+    // Following the same example as above:
+    //   Last N tokens: a b c c b c y a b c
+    //   Repeat counts: 0 0 3 1 0 2 0 0 0 0
+    //
+    //   For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
+    //   c: 3 -> 4 (from `a b c` to `a b c c`)
+    //   b: 1 -> 2 (from `c` to `c b`)
+    //   y: 2 -> 3 (from `b c` to `b c y`)
+
+    for (int i = 0; i < last_n_repeat - 1; ++i) {
+        int repeat_len = ctx->dry_repeat_count[i];
+        if (repeat_len >= ctx->dry_allowed_length) {
+            // This token ends a repeat, so the next token would continue one.
+            // By convention, the value of `repeat_len` only includes the tokens currently
+            // in the context, not the new token that would be added.
+            llama_token token = ctx->last_tokens.rat(last_n_repeat - 2 - i);
+            // Track the maximum sequence ending in this token.
+            const auto& it = ctx->dry_max_token_repeat.find(token);
+            if (it == ctx->dry_max_token_repeat.end() || it->second < repeat_len) {
+                ctx->dry_max_token_repeat[token] = repeat_len;
+            }
+        }
+    }
+
+    // Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
+
+    // Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
+    // Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
+    const float FLOAT_MAX_LOG = 88.7228391f;
+    int max_exponent = 0;
+    if (ctx->dry_base > 1.000001f) {
+        max_exponent = FLOAT_MAX_LOG / std::log(ctx->dry_base);
+    }
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        const auto& af_kvp = ctx->dry_max_token_repeat.find(cur_p->data[i].id);
+        if (af_kvp != ctx->dry_max_token_repeat.end()) {
+            // Check all sequence breakers starting with this token
+            auto range = ctx->dry_processed_breakers.equal_range(cur_p->data[i].id);
+            bool is_single_token_breaker = false;
+
+            for (auto it = range.first; it != range.second; ++it) {
+                if (it->second.empty()) {
+                    is_single_token_breaker = true;
+                    break;
+                }
+            }
+
+            // Apply penalty only if it's not a single-token sequence breaker
+            if (!is_single_token_breaker) {
+                int repeat_exp = af_kvp->second - ctx->dry_allowed_length;
+                if (max_exponent > 0 && repeat_exp > max_exponent) {
+                    repeat_exp = max_exponent;
+                }
+                float penalty = ctx->dry_multiplier * std::pow(ctx->dry_base, repeat_exp);
+                cur_p->data[i].logit -= penalty;
+            }
+        }
+    }
+
+    cur_p->sorted = false;
+}
+
+static void llama_sampler_dry_reset(struct llama_sampler * smpl) {
+    auto * ctx = (llama_sampler_dry *) smpl->ctx;
+    ctx->last_tokens.clear();
+    ctx->dry_repeat_count.clear();
+    ctx->dry_max_token_repeat.clear();
+}
+
+static struct llama_sampler * llama_sampler_dry_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (llama_sampler_dry *) smpl->ctx;
+
+    // nullptr is passed as vocab because it is only needed for raw sequence breaker processing, which we have already done and will be copying
+    auto * result = llama_sampler_init_dry(nullptr, ctx->dry_multiplier, ctx->dry_base, ctx->dry_allowed_length, ctx->dry_penalty_last_n, NULL, 0);
+
+    // Copy the state, including the processed breakers
+    {
+        auto * result_ctx = (llama_sampler_dry *) result->ctx;
+        result_ctx->dry_processed_breakers = ctx->dry_processed_breakers;
+        result_ctx->dry_repeat_count       = ctx->dry_repeat_count;
+        result_ctx->dry_max_token_repeat   = ctx->dry_max_token_repeat;
+        result_ctx->last_tokens            = ctx->last_tokens;
+    }
+
+    return result;
+}
+
+static void llama_sampler_dry_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_dry *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_dry_i = {
+    /* .name   = */ llama_sampler_dry_name,
+    /* .accept = */ llama_sampler_dry_accept,
+    /* .apply  = */ llama_sampler_dry_apply,
+    /* .reset  = */ llama_sampler_dry_reset,
+    /* .clone  = */ llama_sampler_dry_clone,
+    /* .free   = */ llama_sampler_dry_free,
+};
+
+struct llama_sampler * llama_sampler_init_dry_impl(const struct llama_vocab & vocab, int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+    int32_t effective_dry_penalty_last_n = (dry_penalty_last_n == -1) ? context_size : std::max(dry_penalty_last_n, 0);
+    std::unordered_multimap<llama_token, std::vector<llama_token>> processed_breakers;
+    const int MAX_CHAR_LEN = 40;
+    const int MAX_SEQ_LEN = 20;
+
+    const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
+
+    if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
+        // Process sequence breakers
+        for (size_t i = 0; i < num_breakers; ++i) {
+            if (seq_breakers[i] == nullptr || std::strlen(seq_breakers[i]) == 0) {
+                LLAMA_LOG_WARN("skipping null or empty DRY sequence breaker at index %zu\n", i);
+                continue;
+            }
+
+            std::string sequence_break(seq_breakers[i]);
+            if (sequence_break.empty()) {
+                LLAMA_LOG_WARN("skipping empty DRY sequence breaker\n");
+                continue;
+            }
+
+            if (sequence_break.size() > MAX_CHAR_LEN) {
+                LLAMA_LOG_WARN("truncating DRY sequence breaker to %d characters\n", MAX_CHAR_LEN);
+                sequence_break.resize(MAX_CHAR_LEN);
+            }
+
+            get_overlapping_token_sequences(vocab, sequence_break, processed_breakers, MAX_SEQ_LEN);
+        }
+    }
+
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_dry_i,
+        /* .ctx   = */ new llama_sampler_dry {
+            /* .total_context_size     = */ context_size,
+            /* .dry_multiplier         = */ dry_multiplier,
+            /* .dry_base               = */ dry_base,
+            /* .dry_allowed_length     = */ dry_allowed_length,
+            /* .dry_penalty_last_n     = */ dry_penalty_last_n,
+            /* .dry_processed_breakers = */ std::move(processed_breakers),
+            /* .dry_repeat_count       = */ dry_enabled ? std::vector<int>(effective_dry_penalty_last_n, 0) : std::vector<int>{},
+            /* .dry_max_token_repeat   = */ {},
+            /* .last_tokens            = */ dry_enabled ? ring_buffer<llama_token>(effective_dry_penalty_last_n) : ring_buffer<llama_token>(0),
+        },
+    };
}
+
+// wrapper for test-sampling.cpp
+struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const std::vector<std::vector<llama_token>>& seq_breakers) {
+    llama_vocab dummy_vocab;
+    auto * result = llama_sampler_init_dry_impl(dummy_vocab, context_size, dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, NULL, 0);
+    auto * ctx = (llama_sampler_dry *) result->ctx;
+
+    // Process the token-based sequence breakers
+    ctx->dry_processed_breakers.clear();
+    if (seq_breakers.empty()) {
+        LLAMA_LOG_WARN("empty DRY sequence breakers list in llama_sampler_init_dry_testing\n");
+    } else {
+        for (const auto& breaker : seq_breakers) {
+            if (breaker.empty()) {
+                LLAMA_LOG_WARN("skipping DRY empty sequence breaker\n");
+                continue;
+            }
+            llama_token head_token = breaker[0];
+            std::vector<llama_token> tail_tokens(breaker.begin() + 1, breaker.end());
+            ctx->dry_processed_breakers.emplace(head_token, std::move(tail_tokens));
+        }
+
+        if (ctx->dry_processed_breakers.empty()) {
+            LLAMA_LOG_WARN("no valid DRY sequence breakers processed in llama_sampler_init_dry_testing\n");
+        }
+    }
+
+    return result;
+}
+
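The Step 2 example inside llama_sampler_dry_apply can be checked independently: the repeat counts are just a Z-function computed over the reversed context. The following is my own self-contained sketch, not part of the patch; it omits the `rep_limit` clamping and the ring-buffer indexing of the real code:

    // prints "0 0 3 1 0 2 0 0 0 0" for the context "a b c c b c y a b c"
    #include <algorithm>
    #include <cstdio>
    #include <string>
    #include <vector>

    static std::vector<int> repeat_counts(const std::string & toks) {
        const std::string r(toks.rbegin(), toks.rend()); // reverse so suffixes become prefixes
        const int n = (int) r.size();
        std::vector<int> z(n, 0), out(n, 0);
        for (int k = 1, lt = 0, rt = 0; k < n; ++k) {
            if (k < rt) {
                z[k] = std::min(rt - k, z[k - lt]); // reuse the previously computed Z-box
            }
            while (k + z[k] < n && r[z[k]] == r[k + z[k]]) {
                ++z[k]; // extend the match naively
            }
            if (k + z[k] > rt) {
                lt = k; rt = k + z[k];
            }
            out[n - 1 - k] = z[k]; // map back to the unreversed order
        }
        return out;
    }

    int main() {
        for (int v : repeat_counts("abccbcyabc")) printf("%d ", v);
        printf("\n");
    }

For the penalty itself (Step 4): with, say, dry_multiplier = 0.8, dry_base = 1.75 and dry_allowed_length = 2, a token that would extend an 8-token repeat gets 0.8 * 1.75^(8-2) ≈ 23 subtracted from its logit. The max_exponent clamp only matters once base^exponent would overflow a float, i.e. exponent ≈ 88.72 / ln(base).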
 // logit-bias
 
 struct llama_sampler_logit_bias {
@@ -1644,6 +2153,229 @@ struct llama_sampler * llama_sampler_init_logit_bias(
     };
 }
 
+// infill
+
+//#define GGML_DEBUG_SAMPLER_INFILL
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+
+    std::vector<char> buf0;
+    std::vector<char> buf1;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+#if defined(GGML_DEBUG_SAMPLER_INFILL)
+#define LOG_DBG_CUR LLAMA_LOG_DEBUG
+#else
+#define LOG_DBG_CUR(...)
+#endif
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_txt_sum = 0.0f;
+    float p_eog_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        } else {
+            p_txt_sum += cur_p->data[i].p;
+        }
+    }
+
+    const float rat = p_eog_sum == 0.0 ? INFINITY : p_txt_sum / p_eog_sum; GGML_UNUSED(rat);
+
+    LOG_DBG_CUR("%s: p_txt_sum = %.2f, p_eog_sum = %.2f, rat = %.2f, n = %zu\n", __func__, p_txt_sum, p_eog_sum, rat, cur_p->size);
+
+    if (3*p_eog_sum*cur_p->size > p_txt_sum) {
+        LOG_DBG_CUR("%s: the ratio p_txt/p_eog = %.2f is too low -> sampling EOG\n", __func__, p_txt_sum/p_eog_sum);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        float p_sum = 0.0f;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                p_sum += cur_p->data[i].p;
+
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        // normalize probs
+        for (size_t i = 0; i < cur_p->size; ++i) {
+            cur_p->data[i].p /= p_sum;
+        }
+
+        return;
+    }
+
+    size_t n_combined = 0; GGML_UNUSED(n_combined);
+
+    // combine tokens with common prefix
+    for (size_t i0 = 0; i0 < cur_p->size; ++i0) {
+        for (size_t i1 = 0; i1 < cur_p->size; ++i1) {
+            if (cur_p->data[i0].logit == -INFINITY) {
+                break;
+            }
+
+            if (i0 == i1 || cur_p->data[i1].logit == -INFINITY) {
+                continue;
+            }
+
+            int len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+            if (len0 < 0) {
+                ctx->buf0.resize(-len0);
+                len0 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i0].id, ctx->buf0.data(), ctx->buf0.size(), 0, false);
+                assert(len0 > 0);
+            }
+
+            int len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+            if (len1 < 0) {
+                ctx->buf1.resize(-len1);
+                len1 = llama_token_to_piece_impl(*ctx->vocab, cur_p->data[i1].id, ctx->buf1.data(), ctx->buf1.size(), 0, false);
+                assert(len1 > 0);
+            }
+
+            // token i0 is a prefix of token i1
+            if (len0 > 0 && len0 <= len1 && memcmp(ctx->buf0.data(), ctx->buf1.data(), len0) == 0) {
+                int dst = i0;
+                int src = i1;
+
+                // merge into the token with higher probability
+                if (cur_p->data[i1].p > cur_p->data[i0].p) {
+                    std::swap(dst, src);
+                }
+
+                cur_p->data[dst].p += cur_p->data[src].p;
+                cur_p->data[src].logit = -INFINITY;
+                cur_p->data[src].p     = 0.0f;
+
+                n_combined++;
+            }
+        }
+    }
+
+    size_t n_non_eog = 0;
+
+    size_t size_org = cur_p->size;
+
+    float p_sum = 0.0f;
+    float thold = 0.2f;
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: n_combined = %zu, applying thold = %.3f\n", __func__, n_combined, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        if (!is_eog) {
+            ++n_non_eog;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        // keep this token
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    LOG_DBG_CUR("%s: n_non_eog = %zu\n", __func__, n_non_eog);
+
+    // if no non-EOG tokens are left -> reduce cur_p to single EOT token
+    if (n_non_eog == 0) {
+        cur_p->size = 1;
+        cur_p->data[0].id    = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+
+        return;
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    size_org = cur_p->size;
+    p_sum = 0.0f;
+    thold = 1.0/(n_non_eog + 1);
+
+    cur_p->size = 0;
+
+    LOG_DBG_CUR("%s: applying thold = %.3f\n", __func__, thold);
+
+    for (size_t i = 0; i < size_org; ++i) {
+        const bool is_eog = llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id);
+
+        if (cur_p->data[i].p < thold && !is_eog) {
+            continue;
+        }
+
+        p_sum += cur_p->data[i].p;
+
+        cur_p->data[cur_p->size++] = cur_p->data[i];
+    }
+
+    // normalize probs
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        cur_p->data[i].p /= p_sum;
+
+        LOG_DBG_CUR("%s: cur_p[%3zu] = { id: %6d, p: %.6f, logit: %6.3f }\n", __func__, i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+#undef LOG_DBG_CUR
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+            /* .buf0  = */ std::vector<char>(512),
+            /* .buf1  = */ std::vector<char>(512),
+        },
+    };
+}
+
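The decision rule at the top of the apply function reads more easily rearranged: `3*p_eog_sum*cur_p->size > p_txt_sum` is equivalent to `p_txt_sum / cur_p->size < 3*p_eog_sum`, i.e. end-of-generation wins as soon as the average candidate probability drops below three times the combined EOG mass. A hedged sketch of how the sampler is meant to be driven, via the public llama.h entry point that wraps the `_impl` above (the surrounding setup, `model` and `lctx`, is assumed):

    llama_sampler_chain_add(chain, llama_sampler_init_infill(model));
    const llama_token id = llama_sampler_sample(chain, lctx, /*idx=*/-1); // sample from the last decoded logits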
 // utils
 
 uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
diff --git a/examples/talk-llama/llama-sampling.h b/examples/talk-llama/llama-sampling.h
index d90b147130e..919f6fdfcef 100644
--- a/examples/talk-llama/llama-sampling.h
+++ b/examples/talk-llama/llama-sampling.h
@@ -4,8 +4,6 @@
 
 #include "llama-grammar.h"
 
-#include <unordered_map>
-
 struct llama_vocab;
 struct llama_grammar;
 
@@ -27,3 +25,24 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
         const struct llama_vocab & vocab,
         const char * grammar_str,
         const char * grammar_root);
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab);
+
+struct llama_sampler * llama_sampler_init_dry_impl(
+        const struct llama_vocab & vocab,
+        int32_t context_size,
+        float dry_multiplier,
+        float dry_base,
+        int32_t dry_allowed_length,
+        int32_t dry_penalty_last_n,
+        const char ** seq_breakers,
+        size_t num_breakers);
+
+struct llama_sampler * llama_sampler_init_dry_testing(
+        int32_t context_size,
+        float dry_multiplier,
+        float dry_base,
+        int32_t dry_allowed_length,
+        int32_t dry_penalty_last_n,
+        const std::vector<std::vector<llama_token>>& seq_breakers);
diff --git a/examples/talk-llama/llama-vocab.cpp b/examples/talk-llama/llama-vocab.cpp
index d2f34ddd6b3..d1dc96276c2 100644
--- a/examples/talk-llama/llama-vocab.cpp
+++ b/examples/talk-llama/llama-vocab.cpp
@@ -221,7 +221,7 @@ struct llm_tokenizer_spm_session {
         }
 
         // seed the work queue with all possible 2-character tokens.
-        for (size_t i = 1; i < symbols.size(); ++i) {
+        for (int i = 1; i < (int) symbols.size(); ++i) {
             try_add_bigram(i - 1, i);
         }
 
@@ -563,7 +563,7 @@ struct llm_tokenizer_bpe_session {
             index++;
             symbols.emplace_back(sym);
         }
-        for (size_t i = 1; i < symbols.size(); ++i) {
+        for (int i = 1; i < (int) symbols.size(); ++i) {
             add_new_bigram(i - 1, i);
         }
 
@@ -1663,6 +1663,14 @@ llama_token llama_token_eos_impl(const struct llama_vocab & vocab) {
     return vocab.special_eos_id;
 }
 
+llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
+    return vocab.special_eot_id;
+}
+
+llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
+    return vocab.special_eom_id;
+}
+
 llama_token llama_token_cls_impl(const struct llama_vocab & vocab) {
     return vocab.special_cls_id;
 }
@@ -1688,23 +1696,39 @@ bool llama_add_eos_token_impl(const struct llama_vocab & vocab) {
 }
 
 llama_token llama_token_prefix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_prefix_id;
+    return vocab.special_fim_pre_id;
 }
 
 llama_token llama_token_middle_impl(const struct llama_vocab & vocab) {
-    return vocab.special_middle_id;
+    return vocab.special_fim_mid_id;
 }
 
 llama_token llama_token_suffix_impl(const struct llama_vocab & vocab) {
-    return vocab.special_suffix_id;
+    return vocab.special_fim_suf_id;
 }
 
-llama_token llama_token_eot_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eot_id;
+llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_pre_id;
 }
 
-llama_token llama_token_eom_impl(const struct llama_vocab & vocab) {
-    return vocab.special_eom_id;
+llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_suf_id;
+}
+
+llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_mid_id;
+}
+
+llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_pad_id;
+}
+
+llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_rep_id;
+}
+
+llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab) {
+    return vocab.special_fim_sep_id;
 }
 
 int32_t llama_tokenize_impl(
@@ -1942,3 +1966,19 @@ int32_t llama_detokenize_impl(
 
     return total <= text_len_max ? total : -total;
 }
+
+std::string llama_detokenize(const struct llama_vocab & vocab, const std::vector<llama_token> & tokens, bool special) {
+    std::string text;
+    text.resize(std::max(text.capacity(), tokens.size()));
+    int32_t n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+    if (n_chars < 0) {
+        text.resize(-n_chars);
+        n_chars = llama_detokenize_impl(vocab, tokens.data(), (int32_t)tokens.size(), &text[0], (int32_t)text.size(), false, special);
+        GGML_ASSERT(n_chars <= (int32_t)text.size());  // whitespace trimming is performed after per-token detokenization
+    }
+
+    text.resize(n_chars);
+
+    // NOTE: the original tokenizer decodes bytes after collecting the pieces.
+    return text;
+}
diff --git a/examples/talk-llama/llama-vocab.h b/examples/talk-llama/llama-vocab.h
index 069bdc423a6..4bb16d2e429 100644
--- a/examples/talk-llama/llama-vocab.h
+++ b/examples/talk-llama/llama-vocab.h
@@ -37,20 +37,26 @@ struct llama_vocab {
     std::map<std::pair<std::string, std::string>, int> bpe_ranks;
 
     // default LLaMA special tokens
+    // TODO: should we set all of these to LLAMA_TOKEN_NULL?
     id special_bos_id  = 1;
     id special_eos_id  = 2;
+    id special_eot_id  = LLAMA_TOKEN_NULL;
+    id special_eom_id  = LLAMA_TOKEN_NULL;
     id special_unk_id  = 0;
-    id special_sep_id  = -1;
-    id special_pad_id  = -1;
-    id special_cls_id  = -1;
-    id special_mask_id = -1;
-
-    id linefeed_id       = 13;
-    id special_prefix_id = -1;
-    id special_suffix_id = -1;
-    id special_middle_id = -1;
-    id special_eot_id    = -1; // TODO: move above after "eos_id", and here add "file separator" token
-    id special_eom_id    = -1;
+    id special_sep_id  = LLAMA_TOKEN_NULL;
+    id special_pad_id  = LLAMA_TOKEN_NULL;
+    id special_cls_id  = LLAMA_TOKEN_NULL;
+    id special_mask_id = LLAMA_TOKEN_NULL;
+
+    id linefeed_id = 13;
+
+    // fim tokens
+    id special_fim_pre_id = LLAMA_TOKEN_NULL;
+    id special_fim_suf_id = LLAMA_TOKEN_NULL;
+    id special_fim_mid_id = LLAMA_TOKEN_NULL;
+    id special_fim_pad_id = LLAMA_TOKEN_NULL;
+    id special_fim_rep_id = LLAMA_TOKEN_NULL; // repo
+    id special_fim_sep_id = LLAMA_TOKEN_NULL; // file separator
 
     // set of all tokens that cause "end of generation"
     std::set<llama_token> special_eog_ids;
@@ -104,19 +110,26 @@ bool llama_token_is_control_impl(const struct llama_vocab & vocab, llama_token t
 
 llama_token llama_token_bos_impl(const struct llama_vocab & vocab);
 llama_token llama_token_eos_impl(const struct llama_vocab & vocab);
+llama_token llama_token_eot_impl(const struct llama_vocab & vocab);
+llama_token llama_token_eom_impl(const struct llama_vocab & vocab);
 llama_token llama_token_cls_impl(const struct llama_vocab & vocab);
 llama_token llama_token_sep_impl(const struct llama_vocab & vocab);
 llama_token llama_token_nl_impl (const struct llama_vocab & vocab);
 llama_token llama_token_pad_impl(const struct llama_vocab & vocab);
 
-bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
-bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
-
 llama_token llama_token_prefix_impl(const struct llama_vocab & vocab);
 llama_token llama_token_middle_impl(const struct llama_vocab & vocab);
 llama_token llama_token_suffix_impl(const struct llama_vocab & vocab);
-llama_token llama_token_eot_impl   (const struct llama_vocab & vocab);
-llama_token llama_token_eom_impl   (const struct llama_vocab & vocab);
+
+llama_token llama_token_fim_pre_impl(const struct llama_vocab & vocab);
+llama_token llama_token_fim_suf_impl(const struct llama_vocab & vocab);
+llama_token llama_token_fim_mid_impl(const struct llama_vocab & vocab);
+llama_token llama_token_fim_pad_impl(const struct llama_vocab & vocab);
+llama_token llama_token_fim_rep_impl(const struct llama_vocab & vocab);
+llama_token llama_token_fim_sep_impl(const struct llama_vocab & vocab);
+
+bool llama_add_bos_token_impl(const struct llama_vocab & vocab);
+bool llama_add_eos_token_impl(const struct llama_vocab & vocab);
 
 int32_t llama_tokenize_impl(
         const struct llama_vocab & vocab,
@@ -136,6 +149,12 @@ int32_t llama_token_to_piece_impl(
         int32_t lstrip,
         bool special);
 
+// check if token0 is contained as a prefix in token1
+bool llama_token_is_prefix_impl(
+        const struct llama_vocab & vocab,
+        llama_token token0,
+        llama_token token1);
+
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
         const llama_token * tokens,
@@ -144,3 +163,8 @@ int32_t llama_detokenize_impl(
         int32_t text_len_max,
         bool remove_special,
         bool unparse_special);
+
+std::string llama_detokenize(
+        const struct llama_vocab & vocab,
+        const std::vector<llama_token> & tokens,
+        bool special);
diff --git a/examples/talk-llama/llama.cpp b/examples/talk-llama/llama.cpp
index 3443b0689bf..53979e83f8b 100644
--- a/examples/talk-llama/llama.cpp
+++ b/examples/talk-llama/llama.cpp
@@ -8,26 +8,16 @@
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
 
-#ifdef GGML_USE_RPC
-#  include "ggml-rpc.h"
-#endif
-
-#if defined(GGML_USE_VULKAN)
-#  include "ggml-vulkan.h"
-#elif defined(GGML_USE_SYCL)
-#  include "ggml-sycl.h"
-#elif defined(GGML_USE_KOMPUTE)
+#if defined(GGML_USE_KOMPUTE)
 #  include "ggml-kompute.h"
-#elif defined(GGML_USE_CANN)
-#  include "ggml-cann.h"
 #endif
 
-#ifdef GGML_USE_BLAS
-#  include "ggml-blas.h"
+#ifndef __AMX_INT8__
+#undef GGML_USE_AMX
 #endif
 
-#ifdef GGML_USE_METAL
-#  include "ggml-metal.h"
+#ifdef GGML_USE_AMX
+#  include "ggml-amx.h"
 #endif
 
 // TODO: replace with ggml API call
@@ -357,6 +347,8 @@ enum llm_kv {
     LLM_KV_TOKENIZER_MERGES,
     LLM_KV_TOKENIZER_BOS_ID,
     LLM_KV_TOKENIZER_EOS_ID,
+    LLM_KV_TOKENIZER_EOT_ID,
+    LLM_KV_TOKENIZER_EOM_ID,
     LLM_KV_TOKENIZER_UNK_ID,
     LLM_KV_TOKENIZER_SEP_ID,
     LLM_KV_TOKENIZER_PAD_ID,
@@ -369,14 +361,20 @@ enum llm_kv {
     LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP,
     LLM_KV_TOKENIZER_HF_JSON,
     LLM_KV_TOKENIZER_RWKV,
-    LLM_KV_TOKENIZER_PREFIX_ID,
-    LLM_KV_TOKENIZER_SUFFIX_ID,
-    LLM_KV_TOKENIZER_MIDDLE_ID,
-    LLM_KV_TOKENIZER_EOT_ID,
-    LLM_KV_TOKENIZER_EOM_ID,
+    LLM_KV_TOKENIZER_FIM_PRE_ID,
+    LLM_KV_TOKENIZER_FIM_SUF_ID,
+    LLM_KV_TOKENIZER_FIM_MID_ID,
+    LLM_KV_TOKENIZER_FIM_PAD_ID,
+    LLM_KV_TOKENIZER_FIM_REP_ID,
+    LLM_KV_TOKENIZER_FIM_SEP_ID,
 
     LLM_KV_ADAPTER_TYPE,
     LLM_KV_ADAPTER_LORA_ALPHA,
+
+    // deprecated:
+    LLM_KV_TOKENIZER_PREFIX_ID,
+    LLM_KV_TOKENIZER_SUFFIX_ID,
+    LLM_KV_TOKENIZER_MIDDLE_ID,
 };
 
@@ -434,57 +432,65 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_SLIDING_WINDOW,  "%s.attention.sliding_window" },
     { LLM_KV_ATTENTION_SCALE,           "%s.attention.scale"          },
 
-    { LLM_KV_ROPE_DIMENSION_COUNT,      "%s.rope.dimension_count"                 },
-    { LLM_KV_ROPE_FREQ_BASE,            "%s.rope.freq_base"                       },
-    { LLM_KV_ROPE_SCALE_LINEAR,         "%s.rope.scale_linear"                    },
-    { LLM_KV_ROPE_SCALING_TYPE,         "%s.rope.scaling.type"                    },
-    { LLM_KV_ROPE_SCALING_FACTOR,       "%s.rope.scaling.factor"                  },
-    { LLM_KV_ROPE_SCALING_ATTN_FACTOR,  "%s.rope.scaling.attn_factor"             },
-    { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
-    { LLM_KV_ROPE_SCALING_FINETUNED,    "%s.rope.scaling.finetuned"               },
-    { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier"     },
-
-    { LLM_KV_SPLIT_NO,            "split.no"            },
-    { LLM_KV_SPLIT_COUNT,         "split.count"         },
-    { LLM_KV_SPLIT_TENSORS_COUNT, "split.tensors.count" },
-
-    { LLM_KV_SSM_CONV_KERNEL,    "%s.ssm.conv_kernel"    },
-    { LLM_KV_SSM_INNER_SIZE,     "%s.ssm.inner_size"     },
-    { LLM_KV_SSM_STATE_SIZE,     "%s.ssm.state_size"     },
-    { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" },
-    { LLM_KV_SSM_DT_B_C_RMS,     "%s.ssm.dt_b_c_rms"     },
-
-    { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" },
-
-    { LLM_KV_TOKENIZER_MODEL,                "tokenizer.ggml.model"                    },
-    { LLM_KV_TOKENIZER_PRE,                  "tokenizer.ggml.pre"                      },
-    { LLM_KV_TOKENIZER_LIST,                 "tokenizer.ggml.tokens"                   },
-    { LLM_KV_TOKENIZER_TOKEN_TYPE,           "tokenizer.ggml.token_type"               },
-    { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT,     "tokenizer.ggml.token_type_count"         },
-    { LLM_KV_TOKENIZER_SCORES,               "tokenizer.ggml.scores"                   },
-    { LLM_KV_TOKENIZER_MERGES,               "tokenizer.ggml.merges"                   },
-    { LLM_KV_TOKENIZER_BOS_ID,               "tokenizer.ggml.bos_token_id"             },
-    { LLM_KV_TOKENIZER_EOS_ID,               "tokenizer.ggml.eos_token_id"             },
-    { LLM_KV_TOKENIZER_UNK_ID,               "tokenizer.ggml.unknown_token_id"         },
-    { LLM_KV_TOKENIZER_SEP_ID,               "tokenizer.ggml.seperator_token_id"       },
-    { LLM_KV_TOKENIZER_PAD_ID,               "tokenizer.ggml.padding_token_id"         },
-    { LLM_KV_TOKENIZER_CLS_ID,               "tokenizer.ggml.cls_token_id"             },
"tokenizer.ggml.cls_token_id" }, - { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, - { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, - { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, - { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, - { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, - { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, - { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, - { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, - { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, - { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, - { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, - { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, - { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, - - { LLM_KV_ADAPTER_TYPE, "adapter.type" }, - { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, + { LLM_KV_ROPE_FREQ_BASE, "%s.rope.freq_base" }, + { LLM_KV_ROPE_SCALE_LINEAR, "%s.rope.scale_linear" }, + { LLM_KV_ROPE_SCALING_TYPE, "%s.rope.scaling.type" }, + { LLM_KV_ROPE_SCALING_FACTOR, "%s.rope.scaling.factor" }, + { LLM_KV_ROPE_SCALING_ATTN_FACTOR, "%s.rope.scaling.attn_factor" }, + { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" }, + { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" }, + { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" }, + + { LLM_KV_SPLIT_NO, "split.no" }, + { LLM_KV_SPLIT_COUNT, "split.count" }, + { LLM_KV_SPLIT_TENSORS_COUNT, "split.tensors.count" }, + + { LLM_KV_SSM_CONV_KERNEL, "%s.ssm.conv_kernel" }, + { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, + { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, + { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, + + { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, + + { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, + { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, + { LLM_KV_TOKENIZER_LIST, "tokenizer.ggml.tokens" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE, "tokenizer.ggml.token_type" }, + { LLM_KV_TOKENIZER_TOKEN_TYPE_COUNT, "tokenizer.ggml.token_type_count" }, + { LLM_KV_TOKENIZER_SCORES, "tokenizer.ggml.scores" }, + { LLM_KV_TOKENIZER_MERGES, "tokenizer.ggml.merges" }, + { LLM_KV_TOKENIZER_BOS_ID, "tokenizer.ggml.bos_token_id" }, + { LLM_KV_TOKENIZER_EOS_ID, "tokenizer.ggml.eos_token_id" }, + { LLM_KV_TOKENIZER_EOT_ID, "tokenizer.ggml.eot_token_id" }, + { LLM_KV_TOKENIZER_EOM_ID, "tokenizer.ggml.eom_token_id" }, + { LLM_KV_TOKENIZER_UNK_ID, "tokenizer.ggml.unknown_token_id" }, + { LLM_KV_TOKENIZER_SEP_ID, "tokenizer.ggml.seperator_token_id" }, + { LLM_KV_TOKENIZER_PAD_ID, "tokenizer.ggml.padding_token_id" }, + { LLM_KV_TOKENIZER_CLS_ID, "tokenizer.ggml.cls_token_id" }, + { LLM_KV_TOKENIZER_MASK_ID, "tokenizer.ggml.mask_token_id" }, + { LLM_KV_TOKENIZER_ADD_BOS, "tokenizer.ggml.add_bos_token" }, + { LLM_KV_TOKENIZER_ADD_EOS, "tokenizer.ggml.add_eos_token" }, + { LLM_KV_TOKENIZER_ADD_PREFIX, "tokenizer.ggml.add_space_prefix" }, + { LLM_KV_TOKENIZER_REMOVE_EXTRA_WS, "tokenizer.ggml.remove_extra_whitespaces" }, + { LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP, "tokenizer.ggml.precompiled_charsmap" }, + { LLM_KV_TOKENIZER_HF_JSON, "tokenizer.huggingface.json" }, + { LLM_KV_TOKENIZER_RWKV, "tokenizer.rwkv.world" }, + { LLM_KV_TOKENIZER_FIM_PRE_ID, 
"tokenizer.ggml.fim_pre_token_id" }, + { LLM_KV_TOKENIZER_FIM_SUF_ID, "tokenizer.ggml.fim_suf_token_id" }, + { LLM_KV_TOKENIZER_FIM_MID_ID, "tokenizer.ggml.fim_mid_token_id" }, + { LLM_KV_TOKENIZER_FIM_PAD_ID, "tokenizer.ggml.fim_pad_token_id" }, + { LLM_KV_TOKENIZER_FIM_REP_ID, "tokenizer.ggml.fim_rep_token_id" }, + { LLM_KV_TOKENIZER_FIM_SEP_ID, "tokenizer.ggml.fim_sep_token_id" }, + + { LLM_KV_ADAPTER_TYPE, "adapter.type" }, + { LLM_KV_ADAPTER_LORA_ALPHA, "adapter.lora.alpha" }, + + // deprecated + { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, + { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, + { LLM_KV_TOKENIZER_MIDDLE_ID, "tokenizer.ggml.middle_token_id" }, }; struct LLM_KV { @@ -2412,7 +2418,7 @@ struct llama_hparams { // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 - llama_token dec_start_token_id = -1; + llama_token dec_start_token_id = LLAMA_TOKEN_NULL; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE; enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE; @@ -2941,9 +2947,6 @@ struct llama_sbatch_seq { llama_seq_id * seq_id; size_t offset; size_t length; - - // helper for smoother batch API transition -- can be deprecated in the future - llama_seq_id all_seq_id; // used if seq_id == NULL }; // sequence-length-aware batch splitting @@ -3038,30 +3041,18 @@ struct llama_sbatch { } else { ubatch.embd = nullptr; } - // from here on, the else branches are deprecated; - // they are helpers for smoother batch API transition - if (batch->pos) { - if (ubatch.equal_seqs) { - for (size_t i = 0; i < length; ++i) { - ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; - } - } else { - // simple split - ubatch.pos = batch->pos + seq.offset; - } - } else { + if (ubatch.equal_seqs) { for (size_t i = 0; i < length; ++i) { - llama_pos bi = ids[seq.offset + i]; - ubatch.pos[ubatch.n_tokens + i] = batch->all_pos_0 + (bi * batch->all_pos_1); + ubatch.pos[ubatch.n_tokens + i] = batch->pos[ids[seq.offset + i]]; } + } else { + // simple split + ubatch.pos = batch->pos + seq.offset; } if (ubatch.equal_seqs) { ubatch.n_seq_id[ubatch.n_seqs] = seq.n_seq_id; if (seq.seq_id) { ubatch.seq_id[ubatch.n_seqs] = seq.seq_id; - } else { - GGML_ASSERT(seq.n_seq_id == 1); - ubatch.seq_id[ubatch.n_seqs] = &seq.all_seq_id; } } else { // simple split @@ -3074,10 +3065,6 @@ struct llama_sbatch { } if (batch->seq_id) { ubatch.seq_id = batch->seq_id + seq.offset; - } else { - for (size_t i = 0; i < length; ++i) { - ubatch.seq_id[ubatch.n_seqs + i] = &seq.all_seq_id; - } } } if (logits_all) { @@ -3196,7 +3183,6 @@ struct llama_sbatch { s.seq_id = nullptr; s.offset = 0; s.length = n_tokens; - s.all_seq_id = batch.all_seq_id; return; } std::sort(ids.begin(), ids.end(), @@ -3219,7 +3205,7 @@ struct llama_sbatch { if (batch.pos) { return batch.pos[a] < batch.pos[b]; } - // no pos, sort by id (assuming batch.all_pos_1 is positive) + // no pos, sort by id return a < b; } // shared prompts go first @@ -3229,30 +3215,25 @@ struct llama_sbatch { // init seq llama_sbatch_seq * last_seq = nullptr; - if (batch.n_seq_id != nullptr && batch.seq_id != nullptr) { - for (size_t i = 0; i < n_tokens; ++i) { - const size_t bi = ids[i]; - const int32_t n_seqs = batch.n_seq_id[bi]; - llama_seq_id * seq_ids = batch.seq_id[bi]; - if (last_seq != nullptr) { - bool same = n_seqs == last_seq->n_seq_id; - for (int32_t j = 0; same && j < n_seqs; ++j) { - if (seq_ids[j] != last_seq->seq_id[j]) { - same = false; - } - } - if 
-                    if (same) {
-                        last_seq->length += 1;
-                        continue;
+        for (size_t i = 0; i < n_tokens; ++i) {
+            const size_t bi = ids[i];
+            const int32_t n_seqs = batch.n_seq_id[bi];
+            llama_seq_id * seq_ids = batch.seq_id[bi];
+            if (last_seq != nullptr) {
+                bool same = n_seqs == last_seq->n_seq_id;
+                for (int32_t j = 0; same && j < n_seqs; ++j) {
+                    if (seq_ids[j] != last_seq->seq_id[j]) {
+                        same = false;
                     }
                 }
-                llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1, batch.all_seq_id};
-                seq.push_back(new_seq);
-                last_seq = &seq.back();
+                if (same) {
+                    last_seq->length += 1;
+                    continue;
+                }
             }
-        } else {
-            llama_sbatch_seq new_seq = {1, nullptr, 0, n_tokens, batch.all_seq_id};
+            llama_sbatch_seq new_seq = {n_seqs, seq_ids, i, 1};
             seq.push_back(new_seq);
+            last_seq = &seq.back();
         }
         // keep shared prompts first at the end, then sort by length descending.
         std::sort(seq.begin(), seq.end(),
@@ -3292,12 +3273,8 @@ struct llama_context {
     std::unordered_map<struct llama_lora_adapter *, float> lora_adapters;
 
     std::vector<ggml_backend_t> backends;
-#ifdef GGML_USE_METAL
-    ggml_backend_t backend_metal = nullptr;
-#endif
-#ifdef GGML_USE_BLAS
-    ggml_backend_t backend_blas = nullptr;
-#endif
+    std::vector<std::pair<ggml_backend_t, ggml_backend_set_n_threads_t>> set_n_threads_fns;
+
     ggml_backend_t backend_cpu = nullptr;
 
     ggml_threadpool_t threadpool       = nullptr;
@@ -3420,16 +3397,6 @@ static int llama_get_device_count(const llama_model & model) {
     count += (int) model.rpc_servers.size();
 #endif
 
-#if defined(GGML_USE_METAL)
-    count += 1;
-#elif defined(GGML_USE_SYCL)
-    count += ggml_backend_sycl_get_device_count();
-#elif defined(GGML_USE_VULKAN)
-    count += ggml_backend_vk_get_device_count();
-#elif defined(GGML_USE_CANN)
-    count += ggml_backend_cann_get_device_count();
-#endif
-
     return count;
 
     GGML_UNUSED(model);
@@ -3447,20 +3414,8 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(const llama_mode
         }
     }
 
-#if defined(GGML_USE_SYCL)
-    if (host_buffer) {
-        buft = ggml_backend_sycl_host_buffer_type();
-    }
-#elif defined(GGML_USE_CANN)
-    if (host_buffer) {
-        buft = ggml_backend_cann_host_buffer_type();
-    }
-#elif defined(GGML_USE_CPU_HBM)
+#if defined(GGML_USE_CPU_HBM)
     buft = ggml_backend_cpu_hbm_buffer_type();
-#elif defined(GGML_USE_VULKAN)
-    if (host_buffer) {
-        buft = ggml_backend_vk_host_buffer_type();
-    }
 #endif
 
     if (buft == nullptr) {
@@ -3474,30 +3429,13 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_offload(const llama_mode
     ggml_backend_buffer_type_t buft = nullptr;
 
-#if defined(GGML_USE_RPC)
-    int rpc_count = (int)model.rpc_servers.size();
-    if (device < rpc_count) {
-        const char * endpoint = model.rpc_servers[device].c_str();
-        return ggml_backend_rpc_buffer_type(endpoint);
-    }
-    device -= rpc_count;
-#endif
-
     if (device < (int)model.devices.size()) {
         return ggml_backend_dev_buffer_type(model.devices[device]);
     }
     device -= (int)model.devices.size();
 
-#if defined(GGML_USE_METAL)
-    buft = ggml_backend_metal_buffer_type();
-#elif defined(GGML_USE_VULKAN)
-    buft = ggml_backend_vk_buffer_type(device);
-#elif defined(GGML_USE_SYCL)
-    buft = ggml_backend_sycl_buffer_type(device);
-#elif defined(GGML_USE_KOMPUTE)
+#if defined(GGML_USE_KOMPUTE)
     buft = ggml_backend_kompute_buffer_type(device);
-#elif defined(GGML_USE_CANN)
-    buft = ggml_backend_cann_buffer_type(device);
 #endif
 
     if (buft == nullptr) {
@@ -3524,12 +3462,6 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_mo
         }
     }
 
-#ifdef GGML_USE_SYCL
-    if (ggml_backend_sycl_get_device_count() > 1) {
-        buft = ggml_backend_sycl_split_buffer_type(tensor_split);
-    }
-#endif
-
     if (buft == nullptr) {
         buft = llama_default_buffer_type_offload(model, fallback_gpu);
     }
@@ -3539,18 +3471,6 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_split(const llama_mo
 }
 
 static size_t llama_get_device_memory(const llama_model & model, int device) {
-#if defined(GGML_USE_RPC)
-    int rpc_count = (int)model.rpc_servers.size();
-    if (device < rpc_count) {
-        size_t total;
-        size_t free;
-        const char * endpoint = model.rpc_servers[device].c_str();
-        ggml_backend_rpc_get_device_memory(endpoint, &free, &total);
-        return free;
-    }
-    device = device - rpc_count;
-#endif
-
     if (device < (int)model.devices.size()) {
         ggml_backend_dev_t dev = model.devices[device];
         size_t total;
@@ -3559,24 +3479,14 @@ static size_t llama_get_device_memory(const llama_model & model, int device) {
         return free;
     }
 
-#if defined(GGML_USE_SYCL)
-    size_t total;
-    size_t free;
-    ggml_backend_sycl_get_device_memory(device, &free, &total);
-    return free;
-#elif defined(GGML_USE_VULKAN)
-    size_t total;
-    size_t free;
-    ggml_backend_vk_get_device_memory(device, &free, &total);
-    return free;
-#elif defined(GGML_USE_CANN)
-    size_t total;
-    size_t free;
-    ggml_backend_cann_get_device_memory(device, &free, &total);
-    return free;
-#else
+    if (model.devices.size() > 0) {
+        ggml_backend_reg_t reg = ggml_backend_dev_backend_reg(model.devices[0]);
+        LLAMA_LOG_WARN("%s: failed to get free memory of device %d of backend %s: device id is out of range\n", __func__, device, ggml_backend_reg_name(reg));
+    } else {
+        LLAMA_LOG_WARN("%s: failed to get free memory of device: no devices in the model\n", __func__);
+    }
     return 1;
-#endif
+
     GGML_UNUSED(model);
     GGML_UNUSED(device);
 }
@@ -5267,6 +5177,57 @@ struct llama_model_loader {
     }
 };
 
+// temporarily allocate memory for the input batch if needed
+static const llama_seq_id batch_default_seq_id = 0;
+struct llama_batch_allocr {
+    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
+    std::vector<llama_pos>      pos;
+    std::vector<int32_t>        n_seq_id;
+    std::vector<llama_seq_id *> seq_id;
+    std::vector<int8_t>         logits;
+    struct llama_batch          batch;
+    // optionally fulfill the batch returned by llama_batch_get_one
+    llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) {
+        batch = in_batch;
+        GGML_ASSERT(batch.n_tokens > 0);
+        if (!batch.pos) {
+            // determine the last position in KV cache
+            llama_pos last_pos = -1;
+            for (const auto & cell : ctx.kv_self.cells) {
+                if (cell.has_seq_id(batch_default_seq_id)) {
+                    last_pos = std::max(last_pos, cell.pos);
+                }
+            }
+            last_pos++; // next position
+            pos.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                pos[i] = i+last_pos;
+            }
+            batch.pos = pos.data();
+        }
+        if (!batch.n_seq_id) {
+            n_seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                n_seq_id[i] = seq_id_0.size();
+            }
+            batch.n_seq_id = n_seq_id.data();
+        }
+        if (!batch.seq_id) {
+            seq_id.resize(batch.n_tokens + 1);
+            seq_id[batch.n_tokens] = NULL;
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
+                seq_id[i] = seq_id_0.data();
+            }
+            batch.seq_id = seq_id.data();
+        }
+        if (!batch.logits) {
+            logits.resize(batch.n_tokens);
+            logits[logits.size() - 1] = true;
+            batch.logits = logits.data();
+        }
+    }
+};
+
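Taken together, this constructor means callers no longer fill the auxiliary batch arrays by hand. A hedged sketch of the resulting call pattern; it assumes the matching llama.h change (not shown in this hunk) that reduces llama_batch_get_one() to a token pointer plus count, and `tokens` stands for any std::vector<llama_token>:

    // tokens only; pos/n_seq_id/seq_id/logits are left NULL ...
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    // ... so that inside llama_decode(), conceptually:
    //   llama_batch_allocr batch_allocr(lctx, batch);
    //   batch = batch_allocr.batch; // pos continues from sequence 0's last KV cell,
    //                               // n_seq_id[i] = 1, seq_id[i] = {0}, logits only on the last token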
 template<>
 bool llama_model_loader::get_key(const enum llm_kv kid, enum llama_pooling_type & result, const bool required) {
     uint32_t tmp;
@@ -6209,14 +6170,14 @@ static void llm_load_vocab(
         vocab.type = LLAMA_VOCAB_TYPE_NONE;
 
         // default special tokens
-        vocab.special_bos_id  = -1;
-        vocab.special_eos_id  = -1;
-        vocab.special_unk_id  = -1;
-        vocab.special_sep_id  = -1;
-        vocab.special_pad_id  = -1;
-        vocab.special_cls_id  = -1;
-        vocab.special_mask_id = -1;
-        vocab.linefeed_id     = -1;
+        vocab.special_bos_id  = LLAMA_TOKEN_NULL;
+        vocab.special_eos_id  = LLAMA_TOKEN_NULL;
+        vocab.special_unk_id  = LLAMA_TOKEN_NULL;
+        vocab.special_sep_id  = LLAMA_TOKEN_NULL;
+        vocab.special_pad_id  = LLAMA_TOKEN_NULL;
+        vocab.special_cls_id  = LLAMA_TOKEN_NULL;
+        vocab.special_mask_id = LLAMA_TOKEN_NULL;
+        vocab.linefeed_id     = LLAMA_TOKEN_NULL;
 
         // read vocab size from metadata
         if (!ml.get_key(LLM_KV_VOCAB_SIZE, vocab.n_vocab, false)) {
@@ -6233,16 +6194,16 @@ static void llm_load_vocab(
             vocab.special_bos_id  = 1;
             vocab.special_eos_id  = 2;
             vocab.special_unk_id  = 0;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = -1;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
+            vocab.special_sep_id  = LLAMA_TOKEN_NULL;
+            vocab.special_pad_id  = LLAMA_TOKEN_NULL;
+            vocab.special_cls_id  = LLAMA_TOKEN_NULL;
+            vocab.special_mask_id = LLAMA_TOKEN_NULL;
         } else if (tokenizer_model == "bert") {
             vocab.type = LLAMA_VOCAB_TYPE_WPM;
 
             // default special tokens
-            vocab.special_bos_id  = -1;
-            vocab.special_eos_id  = -1;
+            vocab.special_bos_id  = LLAMA_TOKEN_NULL;
+            vocab.special_eos_id  = LLAMA_TOKEN_NULL;
             vocab.special_unk_id  = 100;
             vocab.special_sep_id  = 102;
             vocab.special_pad_id  = 0;
@@ -6278,22 +6239,22 @@ static void llm_load_vocab(
             // default special tokens
             vocab.special_bos_id  = 11;
             vocab.special_eos_id  = 11;
-            vocab.special_unk_id  = -1;
-            vocab.special_sep_id  = -1;
-            vocab.special_pad_id  = -1;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
+            vocab.special_unk_id  = LLAMA_TOKEN_NULL;
+            vocab.special_sep_id  = LLAMA_TOKEN_NULL;
+            vocab.special_pad_id  = LLAMA_TOKEN_NULL;
+            vocab.special_cls_id  = LLAMA_TOKEN_NULL;
+            vocab.special_mask_id = LLAMA_TOKEN_NULL;
         } else if (tokenizer_model == "t5") {
             vocab.type = LLAMA_VOCAB_TYPE_UGM;
 
             // default special tokens
-            vocab.special_bos_id  = -1;
+            vocab.special_bos_id  = LLAMA_TOKEN_NULL;
             vocab.special_eos_id  = 1;
             vocab.special_unk_id  = 2;
-            vocab.special_sep_id  = -1;
+            vocab.special_sep_id  = LLAMA_TOKEN_NULL;
             vocab.special_pad_id  = 0;
-            vocab.special_cls_id  = -1;
-            vocab.special_mask_id = -1;
+            vocab.special_cls_id  = LLAMA_TOKEN_NULL;
+            vocab.special_mask_id = LLAMA_TOKEN_NULL;
 
             const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str());
             if (precompiled_charsmap_keyidx != -1) {
@@ -6316,11 +6277,11 @@ static void llm_load_vocab(
             vocab.type = LLAMA_VOCAB_TYPE_RWKV;
 
             // default special tokens
-            vocab.special_bos_id = -1;
-            vocab.special_eos_id = -1;
-            vocab.special_unk_id = -1;
-            vocab.special_sep_id = -1;
-            vocab.special_pad_id = -1;
+            vocab.special_bos_id = LLAMA_TOKEN_NULL;
+            vocab.special_eos_id = LLAMA_TOKEN_NULL;
+            vocab.special_unk_id = LLAMA_TOKEN_NULL;
+            vocab.special_sep_id = LLAMA_TOKEN_NULL;
+            vocab.special_pad_id = LLAMA_TOKEN_NULL;
         } else {
             throw std::runtime_error(format("unknown tokenizer: '%s'", tokenizer_model.c_str()));
         }
@@ -6404,7 +6365,7 @@ static void llm_load_vocab(
             } else if (
                 tokenizer_pre == "chatglm-bpe") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CHATGLM4;
-                vocab.special_bos_id  = -1;
+                vocab.special_bos_id = LLAMA_TOKEN_NULL;
             } else if (
                 tokenizer_pre == "viking") {
                 vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_VIKING;
@@ -6530,44 +6491,6 @@ static void llm_load_vocab(
 
     // determine the newline token: LLaMA "<0x0A>" == 10 == '\n', Falcon 193 == '\n'
     if (vocab.type == LLAMA_VOCAB_TYPE_SPM) {
-        // For Fill-In-the-Middle (FIM)/infill models which were converted
-        // prior to support of FIM special tokens in GGUF, the following
-        // will allow those models to continue to work. The general names
-        // of the known models are currently CodeLlama (LLM_ARCH_LLAMA) and
-        // CodeGemma (LLM_ARCH_GEMMA). This can potentially be removed once
-        // new versions of these models have been published.
-        std::string gen_name;
-        ml.get_key(LLM_KV_GENERAL_NAME, gen_name, false);
-
-        std::transform(gen_name.begin(), gen_name.end(), gen_name.begin(),
-            [](unsigned char c){ return std::tolower(c); });
-
-        if (gen_name.find("code") != std::string::npos) {
-            if (model.arch == LLM_ARCH_LLAMA
-              && 32010 < vocab.id_to_token.size()
-              && vocab.id_to_token[32007].text.find("<PRE>") != std::string::npos
-              && vocab.id_to_token[32008].text.find("<SUF>") != std::string::npos
-              && vocab.id_to_token[32009].text.find("<MID>") != std::string::npos
-              && vocab.id_to_token[32010].text.find("<EOT>") != std::string::npos) {
-                vocab.special_prefix_id = 32007;
-                vocab.special_suffix_id = 32008;
-                vocab.special_middle_id = 32009;
-                vocab.special_eot_id    = 32010;
-            } else if (model.arch == LLM_ARCH_GEMMA
-              && 107 < vocab.id_to_token.size()
-              && vocab.id_to_token[67].text == "<|fim_prefix|>"
-              && vocab.id_to_token[69].text == "<|fim_suffix|>"
-              && vocab.id_to_token[68].text == "<|fim_middle|>"
-              && vocab.id_to_token[107].text == "<end_of_turn>") {
-                vocab.special_prefix_id = 67;
-                vocab.special_suffix_id = 69;
-                vocab.special_middle_id = 68;
-                // TODO: this is not EOT, it is "file separator" token, needs fix
-                //       https://huggingface.co/google/codegemma-7b-it/blob/9b1d9231388358c04d90bd003458f5070d97db44/tokenizer_config.json#L565-L572
-                //vocab.special_eot_id    = 70;
-                vocab.special_eot_id    = 107;
-            }
-        }
         try {
             vocab.linefeed_id = llama_byte_to_token_impl(vocab, '\n');
         } catch (const std::exception & e) {
@@ -6595,18 +6518,26 @@ static void llm_load_vocab(
     // special tokens
     {
         const std::vector<std::pair<enum llm_kv, int32_t &>> special_token_types = {
-            { LLM_KV_TOKENIZER_BOS_ID,    vocab.special_bos_id    },
-            { LLM_KV_TOKENIZER_EOS_ID,    vocab.special_eos_id    },
-            { LLM_KV_TOKENIZER_UNK_ID,    vocab.special_unk_id    },
-            { LLM_KV_TOKENIZER_SEP_ID,    vocab.special_sep_id    },
-            { LLM_KV_TOKENIZER_PAD_ID,    vocab.special_pad_id    },
-            { LLM_KV_TOKENIZER_CLS_ID,    vocab.special_cls_id    },
-            { LLM_KV_TOKENIZER_MASK_ID,   vocab.special_mask_id   },
-            { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_prefix_id },
-            { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_suffix_id },
-            { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_middle_id },
-            { LLM_KV_TOKENIZER_EOT_ID,    vocab.special_eot_id    },
-            { LLM_KV_TOKENIZER_EOM_ID,    vocab.special_eom_id    },
+            { LLM_KV_TOKENIZER_BOS_ID,     vocab.special_bos_id     },
+            { LLM_KV_TOKENIZER_EOS_ID,     vocab.special_eos_id     },
+            { LLM_KV_TOKENIZER_EOT_ID,     vocab.special_eot_id     },
+            { LLM_KV_TOKENIZER_EOM_ID,     vocab.special_eom_id     },
+            { LLM_KV_TOKENIZER_UNK_ID,     vocab.special_unk_id     },
+            { LLM_KV_TOKENIZER_SEP_ID,     vocab.special_sep_id     },
+            { LLM_KV_TOKENIZER_PAD_ID,     vocab.special_pad_id     },
+            { LLM_KV_TOKENIZER_CLS_ID,     vocab.special_cls_id     },
+            { LLM_KV_TOKENIZER_MASK_ID,    vocab.special_mask_id    },
+            { LLM_KV_TOKENIZER_FIM_PRE_ID, vocab.special_fim_pre_id },
+            { LLM_KV_TOKENIZER_FIM_SUF_ID, vocab.special_fim_suf_id },
+            { LLM_KV_TOKENIZER_FIM_MID_ID, vocab.special_fim_mid_id },
+            { LLM_KV_TOKENIZER_FIM_PAD_ID, vocab.special_fim_pad_id },
+            { LLM_KV_TOKENIZER_FIM_REP_ID, vocab.special_fim_rep_id },
+            { LLM_KV_TOKENIZER_FIM_SEP_ID, vocab.special_fim_sep_id },
+
+            // deprecated
+            { LLM_KV_TOKENIZER_PREFIX_ID, vocab.special_fim_pre_id },
+            { LLM_KV_TOKENIZER_SUFFIX_ID, vocab.special_fim_suf_id },
+            { LLM_KV_TOKENIZER_MIDDLE_ID, vocab.special_fim_mid_id },
         };
 
         for (const auto & it : special_token_types) {
@@ -6637,46 +6568,140 @@ static void llm_load_vocab(
             }
         }
 
-        // find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
-        //
-        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOT_ID
-        //       for now, we apply this workaround to find the EOT token based on its text
-        if (vocab.special_eot_id == -1) {
-            for (const auto & t : vocab.token_to_id) {
+        // auto-detect special tokens by text
+        // TODO: convert scripts should provide these tokens through the KV metadata LLM_KV_TOKENIZER_...
+        //       for now, we apply this workaround to find the tokens based on their text
+
+        for (const auto & t : vocab.token_to_id) {
+            // find EOT token: "<|eot_id|>", "<|im_end|>", "<end_of_turn>", etc.
+            if (vocab.special_eot_id == LLAMA_TOKEN_NULL) {
                 if (false
-                        // TODO: gemma "<end_of_turn>" is exported as a normal token, so the following check does not work
-                        //       need to fix convert script
-                        //vocab.id_to_token[t.second].type == LLAMA_TOKEN_TYPE_CONTROL &&
                         || t.first == "<|eot_id|>"
                         || t.first == "<|im_end|>"
                         || t.first == "<|end|>"
+                        || t.first == "<end_of_turn>"
                         || t.first == "<|endoftext|>"
+                        || t.first == "<EOT>"
+                        || t.first == "<|end▁of▁sentence|>" // DeepSeek
                    ) {
                     vocab.special_eot_id = t.second;
                     if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                        LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                                __func__, t.first.c_str());
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find EOM token: "<|eom_id|>"
+            if (vocab.special_eom_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|eom_id|>"
+                        ) {
+                    vocab.special_eom_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PRE token: "<|fim_prefix|>", "<fim-prefix>", "<PRE>", etc.
+            if (vocab.special_fim_pre_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_prefix|>"  // Qwen
+                        || t.first == "<fim-prefix>"
+                        || t.first == "<|fim▁begin|>" // DeepSeek
+                        || t.first == "<PRE>"
+                        ) {
+                    vocab.special_fim_pre_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
                         vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                     }
-                    break;
                 }
             }
-        }
 
-        // find EOM token: "<|eom_id|>"
-        //
-        // TODO: convert scripts should provide this token through the KV metadata LLAMA_KV_TOKENIZER_EOM_ID
-        //       for now, we apply this workaround to find the EOM token based on its text
-        if (vocab.special_eom_id == -1) {
-            const auto & t = vocab.token_to_id.find("<|eom_id|>");
-            if (t != vocab.token_to_id.end()) {
-                vocab.special_eom_id = t->second;
-                if ((vocab.id_to_token[t->second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                        __func__, t->first.c_str());
-                    vocab.id_to_token[t->second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+            // find FIM_SUF token: "<|fim_suffix|>", "<fim-suffix>", "<SUF>", etc.
+            if (vocab.special_fim_suf_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_suffix|>" // Qwen
+                        || t.first == "<fim-suffix>"
+                        || t.first == "<|fim▁hole|>" // DeepSeek
+                        || t.first == "<SUF>"
+                        ) {
+                    vocab.special_fim_suf_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_MID token: "<|fim_middle|>", "<fim-middle>", "<MID>", etc.
+            if (vocab.special_fim_mid_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_middle|>" // Qwen
+                        || t.first == "<fim-middle>"
+                        || t.first == "<|fim▁end|>"  // DeepSeek
+                        || t.first == "<MID>"
+                        ) {
+                    vocab.special_fim_mid_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_PAD token: "<|fim_pad|>", "<fim-pad>", "<PAD>", etc.
+            if (vocab.special_fim_pad_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_pad|>" // Qwen
+                        || t.first == "<fim-pad>"
+                        || t.first == "<PAD>"
+                        ) {
+                    vocab.special_fim_pad_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_REP token: "<|fim_repo|>", "<fim-repo>", "<REP>", etc.
+            if (vocab.special_fim_rep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|fim_repo|>"  // Qwen
+                        || t.first == "<|repo_name|>"
+                        || t.first == "<fim-repo>"
+                        || t.first == "<REP>"
+                        ) {
+                    vocab.special_fim_rep_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
+                }
+            }
+
+            // find FIM_SEP token: "<|file_sep|>"
+            if (vocab.special_fim_sep_id == LLAMA_TOKEN_NULL) {
+                if (false
+                        || t.first == "<|file_sep|>" // Qwen
+                        ) {
+                    vocab.special_fim_sep_id = t.second;
+                    if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
+                        LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                                __func__, t.second, t.first.c_str());
+                        vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
+                    }
                 }
             }
         }
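
Aside: each auto-detection branch in the hunk above repeats the same adopt-and-override step. A condensed sketch of that shared pattern, assuming the llama_vocab layout used throughout this file (the helper name set_special_token is hypothetical, for illustration only):

    // adopt a detected special token id and force the CONTROL attribute if the
    // GGUF metadata did not mark it as such, mirroring each branch above
    static void set_special_token(llama_vocab & vocab, llama_token & dst, llama_token id) {
        dst = id;
        if ((vocab.id_to_token[id].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
            vocab.id_to_token[id].attr = LLAMA_TOKEN_ATTR_CONTROL;
        }
    }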
@@ -6685,6 +6710,19 @@ static void llm_load_vocab(
         // this is currently determined based on the token text, which is obviously not ideal
         // ref: https://github.com/ggerganov/llama.cpp/issues/9606
         vocab.special_eog_ids.clear();
+
+        if (vocab.special_fim_pad_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_pad_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_pad_id);
+        }
+
+        if (vocab.special_fim_rep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_rep_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_rep_id);
+        }
+
+        if (vocab.special_fim_sep_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_fim_sep_id) == 0) {
+            vocab.special_eog_ids.insert(vocab.special_fim_sep_id);
+        }
+
         for (const auto & t : vocab.token_to_id) {
             if (false
                     || t.first == "<|eot_id|>"
@@ -6697,24 +6735,31 @@ static void llm_load_vocab(
                ) {
                 vocab.special_eog_ids.insert(t.second);
                 if ((vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {
-                    LLAMA_LOG_WARN("%s: control-looking token: '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
-                            __func__, t.first.c_str());
+                    LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n",
+                            __func__, t.second, t.first.c_str());
                     vocab.id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL;
                 }
+            } else {
+                // token is control, but not marked as EOG -> print a debug log
+                if (vocab.id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && vocab.special_eog_ids.count(t.second) == 0) {
+                    LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n",
+                            __func__, t.second, t.first.c_str());
+                }
             }
         }
 
-        if (vocab.special_eos_id != -1 && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
+        // sanity checks
+        if (vocab.special_eos_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eos_id) == 0) {
             vocab.special_eog_ids.insert(vocab.special_eos_id);
             LLAMA_LOG_WARN("%s: special_eos_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
         }
 
-        if (vocab.special_eot_id != -1 && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
+        if (vocab.special_eot_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eot_id) == 0) {
             vocab.special_eog_ids.insert(vocab.special_eot_id);
             LLAMA_LOG_WARN("%s: special_eot_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
         }
 
-        if (vocab.special_eom_id != -1 && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
+        if (vocab.special_eom_id != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(vocab.special_eom_id) == 0) {
             vocab.special_eog_ids.insert(vocab.special_eom_id);
             LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__);
         }
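
Aside: the net effect of the EOG bookkeeping above is a single set lookup at generation time; llama.cpp exposes it as llama_token_is_eog(). A minimal sketch of the check this set feeds:

    // EOS/EOT/EOM and the FIM PAD/REP/SEP terminators all funnel into special_eog_ids
    static bool token_is_eog(const llama_vocab & vocab, llama_token token) {
        return token != LLAMA_TOKEN_NULL && vocab.special_eog_ids.count(token) > 0;
    }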
@@ -6908,20 +6953,24 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
     LLAMA_LOG_INFO("%s: general.name     = %s\n",    __func__, model.name.c_str());
 
     // special tokens
-    if (vocab.special_bos_id    != -1) { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,  vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
-    if (vocab.special_eos_id    != -1) { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,  vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
-    if (vocab.special_unk_id    != -1) { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,  vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
-    if (vocab.special_sep_id    != -1) { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,  vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
-    if (vocab.special_pad_id    != -1) { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,  vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
-    if (vocab.special_cls_id    != -1) { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,  vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
-    if (vocab.special_mask_id   != -1) { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id, vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
-
-    if (vocab.linefeed_id       != -1) { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,       vocab.id_to_token[vocab.linefeed_id].text.c_str() );       }
-    if (vocab.special_prefix_id != -1) { LLAMA_LOG_INFO( "%s: PRE token        = %d '%s'\n", __func__, vocab.special_prefix_id, vocab.id_to_token[vocab.special_prefix_id].text.c_str() ); }
-    if (vocab.special_suffix_id != -1) { LLAMA_LOG_INFO( "%s: SUF token        = %d '%s'\n", __func__, vocab.special_suffix_id, vocab.id_to_token[vocab.special_suffix_id].text.c_str() ); }
-    if (vocab.special_middle_id != -1) { LLAMA_LOG_INFO( "%s: MID token        = %d '%s'\n", __func__, vocab.special_middle_id, vocab.id_to_token[vocab.special_middle_id].text.c_str() ); }
-    if (vocab.special_eot_id    != -1) { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,    vocab.id_to_token[vocab.special_eot_id].text.c_str() );    }
-    if (vocab.special_eom_id    != -1) { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, vocab.special_eom_id,    vocab.id_to_token[vocab.special_eom_id].text.c_str() );    }
+    if (vocab.special_bos_id  != -1)    { LLAMA_LOG_INFO( "%s: BOS token        = %d '%s'\n", __func__, vocab.special_bos_id,     vocab.id_to_token[vocab.special_bos_id].text.c_str() );  }
+    if (vocab.special_eos_id  != -1)    { LLAMA_LOG_INFO( "%s: EOS token        = %d '%s'\n", __func__, vocab.special_eos_id,     vocab.id_to_token[vocab.special_eos_id].text.c_str() );  }
+    if (vocab.special_eot_id  != -1)    { LLAMA_LOG_INFO( "%s: EOT token        = %d '%s'\n", __func__, vocab.special_eot_id,     vocab.id_to_token[vocab.special_eot_id].text.c_str() );  }
+    if (vocab.special_eom_id  != -1)    { LLAMA_LOG_INFO( "%s: EOM token        = %d '%s'\n", __func__, vocab.special_eom_id,     vocab.id_to_token[vocab.special_eom_id].text.c_str() );  }
+    if (vocab.special_unk_id  != -1)    { LLAMA_LOG_INFO( "%s: UNK token        = %d '%s'\n", __func__, vocab.special_unk_id,     vocab.id_to_token[vocab.special_unk_id].text.c_str() );  }
+    if (vocab.special_sep_id  != -1)    { LLAMA_LOG_INFO( "%s: SEP token        = %d '%s'\n", __func__, vocab.special_sep_id,     vocab.id_to_token[vocab.special_sep_id].text.c_str() );  }
+    if (vocab.special_pad_id  != -1)    { LLAMA_LOG_INFO( "%s: PAD token        = %d '%s'\n", __func__, vocab.special_pad_id,     vocab.id_to_token[vocab.special_pad_id].text.c_str() );  }
+    if (vocab.special_cls_id  != -1)    { LLAMA_LOG_INFO( "%s: CLS token        = %d '%s'\n", __func__, vocab.special_cls_id,     vocab.id_to_token[vocab.special_cls_id].text.c_str() );  }
+    if (vocab.special_mask_id != -1)    { LLAMA_LOG_INFO( "%s: MASK token       = %d '%s'\n", __func__, vocab.special_mask_id,    vocab.id_to_token[vocab.special_mask_id].text.c_str() ); }
+
+    if (vocab.linefeed_id != -1)        { LLAMA_LOG_INFO( "%s: LF token         = %d '%s'\n", __func__, vocab.linefeed_id,        vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
+
+    if (vocab.special_fim_pre_id != -1) { LLAMA_LOG_INFO( "%s: FIM PRE token    = %d '%s'\n", __func__, vocab.special_fim_pre_id, vocab.id_to_token[vocab.special_fim_pre_id].text.c_str() ); }
+    if (vocab.special_fim_suf_id != -1) { LLAMA_LOG_INFO( "%s: FIM SUF token    = %d '%s'\n", __func__, vocab.special_fim_suf_id, vocab.id_to_token[vocab.special_fim_suf_id].text.c_str() ); }
+    if (vocab.special_fim_mid_id != -1) { LLAMA_LOG_INFO( "%s: FIM MID token    = %d '%s'\n", __func__, vocab.special_fim_mid_id, vocab.id_to_token[vocab.special_fim_mid_id].text.c_str() ); }
+    if (vocab.special_fim_pad_id != -1) { LLAMA_LOG_INFO( "%s: FIM PAD token    = %d '%s'\n", __func__, vocab.special_fim_pad_id, vocab.id_to_token[vocab.special_fim_pad_id].text.c_str() ); }
+    if (vocab.special_fim_rep_id != -1) { LLAMA_LOG_INFO( "%s: FIM REP token    = %d '%s'\n", __func__, vocab.special_fim_rep_id, vocab.id_to_token[vocab.special_fim_rep_id].text.c_str() ); }
+    if (vocab.special_fim_sep_id != -1) { LLAMA_LOG_INFO( "%s: FIM SEP token    = %d '%s'\n", __func__, vocab.special_fim_sep_id, vocab.id_to_token[vocab.special_fim_sep_id].text.c_str() ); }
 
     for (const auto & id : vocab.special_eog_ids) {
         LLAMA_LOG_INFO( "%s: EOG token        = %d '%s'\n", __func__, id, vocab.id_to_token[id].text.c_str() );
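
Aside: the FIM ids logged above are what a consumer uses to assemble an infill prompt. A hedged sketch of the conventional prefix-suffix-middle (PSM) layout; build_fim_prompt is illustrative only and not part of this patch:

    // lay out "<FIM_PRE> prefix <FIM_SUF> suffix <FIM_MID>" as token ids;
    // the model is then asked to complete the middle section
    static std::vector<llama_token> build_fim_prompt(
            const llama_vocab & vocab,
            const std::vector<llama_token> & prefix,
            const std::vector<llama_token> & suffix) {
        std::vector<llama_token> out;
        out.push_back(vocab.special_fim_pre_id);
        out.insert(out.end(), prefix.begin(), prefix.end());
        out.push_back(vocab.special_fim_suf_id);
        out.insert(out.end(), suffix.begin(), suffix.end());
        out.push_back(vocab.special_fim_mid_id);
        return out;
    }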
@@ -6987,7 +7036,14 @@ static bool llm_load_tensors(
 
     // assign cpu layers
     for (int i = 0; i < i_gpu_start; ++i) {
+#ifdef GGML_USE_AMX
+        model.buft_layer[i] = {
+            ggml_backend_amx_buffer_type(),
+            llama_default_buffer_type_cpu(model, true)
+        };
+#else
         model.buft_layer[i] = llama_default_buffer_type_cpu(model, true);
+#endif
     }
 
     if (split_mode == LLAMA_SPLIT_MODE_LAYER) {
@@ -8918,48 +8974,40 @@ static bool llm_load_tensors(
         llama_buf_map bufs;
         bufs.reserve(n_max_backend_buffer);
 
-        // only the mmap region containing the tensors in the model is mapped to the backend buffer
-        // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
-        // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
-        if (ml.use_mmap && use_mmap_buffer && buft == llama_default_buffer_type_cpu(model, true)) {
-            for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
-                void * addr = nullptr;
-                size_t first, last;
-                ml.get_mapping_range(&first, &last, &addr, idx, ctx);
-                if (first >= last) {
-                    continue;
-                }
-                ggml_backend_buffer_t buf = ggml_backend_cpu_buffer_from_ptr((char *) addr + first, last - first);
-                if (buf == nullptr) {
-                    throw std::runtime_error("unable to allocate backend CPU buffer");
-                }
-                model.bufs.push_back(buf);
-                bufs.emplace(idx, buf);
-            }
+        // check if this backend device supports buffer_from_host_ptr
+        // when using a host buffer as the CPU backend buffer, use the CPU device to prioritize using buffer_from_host_ptr over the host buffer
+        ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft == llama_default_buffer_type_cpu(model, true) ? ggml_backend_cpu_buffer_type() : buft);
+        bool buffer_from_host_ptr_supported = false;
+        if (dev) {
+            ggml_backend_dev_props props;
+            ggml_backend_dev_get_props(dev, &props);
+            buffer_from_host_ptr_supported = props.caps.buffer_from_host_ptr;
         }
-#ifdef GGML_USE_METAL
-        else if (ml.use_mmap && use_mmap_buffer && buft == ggml_backend_metal_buffer_type()) {
+
+        if (ml.use_mmap && use_mmap_buffer && buffer_from_host_ptr_supported) {
             for (uint32_t idx = 0; idx < ml.files.size(); idx++) {
-                const size_t max_size = ggml_get_max_tensor_size(ctx);
+                // only the mmap region containing the tensors in the model is mapped to the backend buffer
+                // this is important for metal with apple silicon: if the entire model could be mapped to a metal buffer, then we could just use metal for all layers
+                // this allows using partial offloading when the model size exceeds the metal buffer size, but not the RAM size
                 void * addr = nullptr;
-                size_t first, last;
+                size_t first, last; // NOLINT
                 ml.get_mapping_range(&first, &last, &addr, idx, ctx);
                 if (first >= last) {
                     continue;
                 }
-                ggml_backend_buffer_t buf = ggml_backend_metal_buffer_from_ptr((char *) addr + first, last - first, max_size);
+                const size_t max_size = ggml_get_max_tensor_size(ctx);
+                ggml_backend_buffer_t buf = ggml_backend_dev_buffer_from_host_ptr(dev, (char *) addr + first, last - first, max_size);
                 if (buf == nullptr) {
-                    throw std::runtime_error("unable to allocate backend metal buffer");
+                    throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
                 }
                 model.bufs.push_back(buf);
                 bufs.emplace(idx, buf);
             }
         }
-#endif
         else {
             ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft);
             if (buf == nullptr) {
-                throw std::runtime_error("unable to allocate backend buffer");
+                throw std::runtime_error(format("unable to allocate %s buffer", ggml_backend_buft_name(buft)));
             }
             model.bufs.push_back(buf);
             if (use_mlock && ggml_backend_buffer_is_host(buf)) {
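
Aside: the capability query introduced in this hunk is self-contained. The same check in isolation, using only the ggml-backend calls that appear above:

    // true  -> the mmap-ed file region can be wrapped zero-copy via
    //          ggml_backend_dev_buffer_from_host_ptr(), as in the branch above
    // false -> fall back to ggml_backend_alloc_ctx_tensors_from_buft()
    static bool supports_buffer_from_host_ptr(ggml_backend_dev_t dev) {
        if (dev == nullptr) {
            return false;
        }
        ggml_backend_dev_props props;
        ggml_backend_dev_get_props(dev, &props);
        return props.caps.buffer_from_host_ptr;
    }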
@@ -9570,20 +9618,16 @@ static struct ggml_tensor * llm_build_kqv(
         cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
                                   hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_GEMMA2) {
-            ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
-        }
+        ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
 
         cur = ggml_reshape_2d(ctx, cur, n_embd_head_v*n_head, n_tokens);
     } else {
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
         cb(kq, "kq", il);
 
-        if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2 || model.arch == LLM_ARCH_NEMOTRON || model.arch == LLM_ARCH_CHATGLM) {
-            // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
-            // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
-            ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-        }
+        // note: this op tends to require high floating point range
+        //       while for some models F16 is enough, for others it is not, so we default to F32 here
+        ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
 
         if (model.arch == LLM_ARCH_GROK) {
             // need to do the following:
@@ -9592,9 +9636,6 @@ static struct ggml_tensor * llm_build_kqv(
             // kq = 30 * tanh(kq / 30)
             // before the softmax below
 
-            //try from phi2
-            //ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
-
             kq = ggml_tanh(ctx, ggml_scale(ctx, kq, 0.08838834764831845f/30.0f));
             kq = ggml_scale(ctx, kq, 30);
         }
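
Aside: the Grok branch above evaluates 30 * tanh((0.08838834764831845f * x) / 30) per logit; the constant is numerically 1/sqrt(128), i.e. the usual 1/sqrt(d_head) attention scale for 128-dim heads folded into the tanh argument. A scalar equivalent of the two ggml ops (tanhf from <cmath>, d_head = 128 assumed):

    static float grok_softcap(float kq) {
        const float kq_scale = 0.08838834764831845f; // 1/sqrt(128)
        return 30.0f * tanhf(kq * kq_scale / 30.0f); // bounds the logit to (-30, 30)
    }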
@@ -10020,7 +10061,7 @@ struct llm_build_context {
           llama_context  & lctx;
     const llama_hparams  & hparams;
     const llama_cparams  & cparams;
-    const llama_ubatch   & batch;
+    const llama_ubatch   & ubatch;
     const llama_kv_cache & kv_self;
 
     const int64_t n_embd;
@@ -10066,14 +10107,14 @@ struct llm_build_context {
     // TODO: consider making the entire interface noexcept
     llm_build_context(
         llama_context  & lctx,
-    const llama_ubatch & batch,
+    const llama_ubatch & ubatch,
     const llm_build_cb & cb,
                   bool   worst_case) :
         model            (lctx.model),
         lctx             (lctx),
         hparams          (model.hparams),
         cparams          (lctx.cparams),
-        batch            (batch),
+        ubatch           (ubatch),
         kv_self          (lctx.kv_self),
         n_embd           (hparams.n_embd),
         n_layer          (hparams.n_layer),
@@ -10095,7 +10136,7 @@ struct llm_build_context {
         beta_slow        (cparams.yarn_beta_slow),
         norm_eps         (hparams.f_norm_eps),
         norm_rms_eps     (hparams.f_norm_rms_eps),
-        n_tokens         (batch.n_tokens),
+        n_tokens         (ubatch.n_tokens),
         n_kv             (worst_case ? kv_self.size : kv_self.n),
         n_outputs        (worst_case ? n_tokens : lctx.n_outputs),
         n_outputs_enc    (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
@@ -10464,7 +10505,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -10624,7 +10665,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = model.type == MODEL_7B ? build_inp_pos() : nullptr;
@@ -10739,7 +10780,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -10843,7 +10884,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -10965,7 +11006,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // multiply by embedding_multiplier_scale of 78.38367176906169
         inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
@@ -11123,7 +11164,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -11245,7 +11286,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -11348,7 +11389,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -11450,7 +11491,7 @@ struct llm_build_context {
         }
 
         // construct input embeddings (token, type, position)
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // token types are hardcoded to zero ("Sentence A")
         struct ggml_tensor * type_row0 = ggml_view_1d(ctx0, model.type_embd, n_embd, 0);
@@ -11637,7 +11678,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -11739,7 +11780,7 @@ struct llm_build_context {
         struct ggml_tensor * pos;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -11877,7 +11918,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12027,7 +12068,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12140,7 +12181,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12255,7 +12296,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12400,7 +12441,7 @@ struct llm_build_context {
         struct ggml_tensor * ffn_output;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12519,7 +12560,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12647,7 +12688,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12752,7 +12793,7 @@ struct llm_build_context {
         struct ggml_tensor * pos;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12857,7 +12898,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -12967,7 +13008,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -13085,7 +13126,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -13212,7 +13253,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // scale the input embeddings
         inpL = ggml_scale(ctx0, inpL, scale_embd);
@@ -13356,7 +13397,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // scale the input embeddings
         inpL = ggml_scale(ctx0, inpL, scale_embd);
@@ -13557,7 +13598,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
         cb(inpL, "inp_scaled", -1);
@@ -13665,7 +13706,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd));
         cb(inpL, "inp_scaled", -1);
@@ -13803,7 +13844,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -13919,7 +13960,7 @@ struct llm_build_context {
         struct ggml_tensor * inpL;
 
         // {n_embd, n_tokens}
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         struct ggml_tensor * state_copy = build_inp_s_copy();
         struct ggml_tensor * state_mask = build_inp_s_mask();
@@ -13931,7 +13972,7 @@ struct llm_build_context {
                     LLM_NORM_RMS, cb, il);
             cb(cur, "attn_norm", il);
 
-            cur = llm_build_mamba(ctx0, lctx, batch, gf, cur,
+            cur = llm_build_mamba(ctx0, lctx, ubatch, gf, cur,
                     state_copy, state_mask,
                     kv_head, n_kv, cb, il);
 
@@ -13977,7 +14018,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -14134,7 +14175,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -14262,7 +14303,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -14381,7 +14422,7 @@ struct llm_build_context {
 
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -14508,7 +14549,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -14653,7 +14694,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -14794,7 +14835,7 @@ struct llm_build_context {
         struct ggml_tensor * inpL;
 
         // {n_embd, n_tokens}
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -15009,7 +15050,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -15163,7 +15204,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         GGML_ASSERT(lctx.is_encoding);
         struct ggml_tensor * pos_bucket_enc = llm_build_pos_bucket(false);
@@ -15295,7 +15336,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         GGML_ASSERT(!lctx.is_encoding);
         GGML_ASSERT(n_outputs_enc > 0 && "call llama_encode() first");
@@ -15497,7 +15538,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
@@ -15589,7 +15630,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -15703,7 +15744,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -15827,7 +15868,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -15947,11 +15988,11 @@ struct llm_build_context {
         // Token shift state dimensions should be 2 * n_emb
         GGML_ASSERT(n_embd == hparams.n_embd_k_s() / 2);
 
-        const int64_t n_seqs = batch.n_seqs;
-        const int64_t n_seq_tokens = batch.n_seq_tokens;
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_seqs = ubatch.n_seqs;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
         GGML_ASSERT(n_seqs != 0);
-        GGML_ASSERT(batch.equal_seqs);
+        GGML_ASSERT(ubatch.equal_seqs);
         GGML_ASSERT(n_tokens == n_seq_tokens * n_seqs);
 
         struct ggml_tensor * cur;
@@ -15959,7 +16000,7 @@ struct llm_build_context {
         struct ggml_tensor * state_copy = build_inp_s_copy();
         struct ggml_tensor * state_mask = build_inp_s_mask();
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
         inpL = llm_build_norm(ctx0, inpL, hparams, model.tok_norm, model.tok_norm_b, LLM_NORM, cb, -1);
 
         for (int il = 0; il < n_layer; ++il) {
@@ -16044,9 +16085,11 @@ struct llm_build_context {
         cur = ggml_get_rows(ctx0, cur, inp_out_ids);
 
         cur = llm_build_norm(ctx0, cur, hparams, model.output_norm, model.output_norm_b, LLM_NORM, cb, -1);
-        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
+        cb(cur, "result_norm", -1);
 
+        cur = llm_build_lora_mm(lctx, ctx0, model.output, cur);
         cb(cur, "result_output", -1);
+
         ggml_build_forward_expand(gf, cur);
 
         return gf;
@@ -16071,7 +16114,7 @@ struct llm_build_context {
         struct ggml_tensor * cur;
         struct ggml_tensor * inpL;
 
-        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
 
         // inp_pos - contains the positions
         struct ggml_tensor * inp_pos = build_inp_pos();
@@ -16267,7 +16310,7 @@ static struct ggml_cgraph * llama_build_graph_k_shift(llama_context & lctx) {
 
 static struct ggml_cgraph * llama_build_graph(
          llama_context & lctx,
-    const llama_ubatch & batch,
+    const llama_ubatch & ubatch,
                   bool   worst_case) {
     const auto & model = lctx.model;
 
@@ -16289,7 +16332,7 @@ static struct ggml_cgraph * llama_build_graph(
         // norm may be automatically assigned to the backend of the previous layer, increasing data transfer between backends
         // FIXME: fix in ggml_backend_sched
         const bool full_offload = lctx.model.n_gpu_layers > (int)lctx.model.hparams.n_layer;
-        if (batch.n_tokens < 32 || full_offload) {
+        if (ubatch.n_tokens < 32 || full_offload) {
             if (il != -1 && strcmp(name, "norm") == 0) {
                 for (auto * backend : lctx.backends) {
                     if (ggml_backend_supports_buft(backend, lctx.model.buft_layer[il].buft) &&
@@ -16304,7 +16347,7 @@ static struct ggml_cgraph * llama_build_graph(
 
     struct ggml_cgraph * result = NULL;
 
-    struct llm_build_context llm(lctx, batch, cb, worst_case);
+    struct llm_build_context llm(lctx, ubatch, cb, worst_case);
 
     llm.init();
 
@@ -16555,7 +16598,7 @@ static int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t
     return relative_bucket;
 }
 
-static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
+static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch) {
     //
     // set input data
     //
@@ -16564,28 +16607,28 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     const auto & cparams = lctx.cparams;
     const auto & kv_self = lctx.kv_self;
 
-    if (batch.token) {
-        const int64_t n_tokens = batch.n_tokens;
+    if (ubatch.token) {
+        const int64_t n_tokens = ubatch.n_tokens;
 
-        ggml_backend_tensor_set(lctx.inp_tokens, batch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
+        ggml_backend_tensor_set(lctx.inp_tokens, ubatch.token, 0, n_tokens*ggml_element_size(lctx.inp_tokens));
     }
 
-    if (batch.embd) {
+    if (ubatch.embd) {
         const int64_t n_embd   = hparams.n_embd;
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
 
-        ggml_backend_tensor_set(lctx.inp_embd, batch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
+        ggml_backend_tensor_set(lctx.inp_embd, ubatch.embd, 0, n_tokens*n_embd*ggml_element_size(lctx.inp_embd));
     }
 
-    if (batch.pos && lctx.inp_pos) {
-        const int64_t n_tokens = batch.n_tokens;
+    if (ubatch.pos && lctx.inp_pos) {
+        const int64_t n_tokens = ubatch.n_tokens;
 
-        ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
+        ggml_backend_tensor_set(lctx.inp_pos, ubatch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos));
     }
 
     if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) {
         GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs");
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_out_ids->buffer));
         int32_t * data = (int32_t *) lctx.inp_out_ids->data;
@@ -16594,10 +16637,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             for (int i = 0; i < n_tokens; ++i) {
                 data[i] = i;
             }
-        } else if (batch.output) {
+        } else if (ubatch.output) {
             int32_t n_outputs = 0;
             for (int i = 0; i < n_tokens; ++i) {
-                if (batch.output[i]) {
+                if (ubatch.output[i]) {
                     data[n_outputs++] = i;
                 }
             }
@@ -16622,9 +16665,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         // NOTE: hparams.causal_attn indicates the model is capable of generation and uses the kv cache.
         if (cparams.causal_attn && !lctx.is_encoding) {
             const int64_t n_kv         = kv_self.n;
-            const int64_t n_tokens     = batch.n_tokens;
-            const int64_t n_seq_tokens = batch.n_seq_tokens;
-            const int64_t n_seqs       = batch.n_seqs;
+            const int64_t n_tokens     = ubatch.n_tokens;
+            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+            const int64_t n_seqs       = ubatch.n_seqs;
 
 
             float * data     = nullptr;
@@ -16641,14 +16684,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             }
 
             // For causal attention, use only the previous KV cells
-            // of the correct sequence for each token of the batch.
+            // of the correct sequence for each token of the ubatch.
             // It's assumed that if a token in the batch has multiple sequences, they are equivalent.
             for (int h = 0; h < 1; ++h) {
                 for (int s = 0; s < n_seqs; ++s) {
-                    const llama_seq_id seq_id = batch.seq_id[s][0];
+                    const llama_seq_id seq_id = ubatch.seq_id[s][0];
 
                     for (int j = 0; j < n_seq_tokens; ++j) {
-                        const llama_pos pos = batch.pos[s*n_seq_tokens + j];
+                        const llama_pos pos = ubatch.pos[s*n_seq_tokens + j];
 
                         for (int i = 0; i < n_kv; ++i) {
                             float f;
@@ -16694,9 +16737,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
                 }
             }
         } else {
-            const int64_t n_tokens     = batch.n_tokens;
-            const int64_t n_seq_tokens = batch.n_seq_tokens;
-            const int64_t n_seqs       = batch.n_seqs;
+            const int64_t n_tokens     = ubatch.n_tokens;
+            const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+            const int64_t n_seqs       = ubatch.n_seqs;
             // when using kv cache, the mask needs to match the kv cache size
             const int64_t n_stride = hparams.causal_attn && !lctx.is_encoding ? kv_self.n : n_tokens;
 
@@ -16706,7 +16749,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
 
             for (int h = 0; h < 1; ++h) {
                 for (int s1 = 0; s1 < n_seqs; ++s1) {
-                    const llama_seq_id seq_id = batch.seq_id[s1][0];
+                    const llama_seq_id seq_id = ubatch.seq_id[s1][0];
 
                     for (int j = 0; j < n_seq_tokens; ++j) {
                         const int32_t tj = s1*n_seq_tokens + j;
@@ -16716,10 +16759,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
                                 const int32_t ti = s0*n_seq_tokens + i;
                                 float f = -INFINITY;
 
-                                for (int s = 0; s < batch.n_seq_id[s0]; ++s) {
-                                    if (batch.seq_id[s0][s] == seq_id) {
+                                for (int s = 0; s < ubatch.n_seq_id[s0]; ++s) {
+                                    if (ubatch.seq_id[s0][s] == seq_id) {
                                         if (hparams.use_alibi) {
-                                            f = -std::abs(batch.pos[ti] - batch.pos[tj]);
+                                            f = -std::abs(ubatch.pos[ti] - ubatch.pos[tj]);
                                         } else {
                                             f = 0.0f;
                                         }
@@ -16741,9 +16784,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
-        const int64_t n_tokens     = batch.n_tokens;
-        const int64_t n_seq_tokens = batch.n_seq_tokens;
-        const int64_t n_seqs       = batch.n_seqs;
+        const int64_t n_tokens     = ubatch.n_tokens;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_seqs       = ubatch.n_seqs;
 
         GGML_ASSERT(lctx.inp_mean);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_mean->buffer));
@@ -16754,12 +16797,12 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         std::vector<uint64_t> sum(n_tokens, 0);
 
         for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
 
-            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
+            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
             GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == MEAN");
 
-            sum[seq_id] += batch.n_seq_tokens;
+            sum[seq_id] += ubatch.n_seq_tokens;
         }
 
         std::vector<float> div(n_tokens, 0.0f);
@@ -16771,7 +16814,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         }
 
         for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
 
             for (int i = 0; i < n_seq_tokens; ++i) {
                 data[seq_id*n_tokens + s*n_seq_tokens + i] = div[seq_id];
@@ -16782,9 +16825,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     if (cparams.embeddings && (
                 cparams.pooling_type == LLAMA_POOLING_TYPE_CLS ||
                 cparams.pooling_type == LLAMA_POOLING_TYPE_RANK)) {
-        const int64_t n_tokens     = batch.n_tokens;
-        const int64_t n_seq_tokens = batch.n_seq_tokens;
-        const int64_t n_seqs       = batch.n_seqs;
+        const int64_t n_tokens     = ubatch.n_tokens;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_seqs       = ubatch.n_seqs;
 
         GGML_ASSERT(lctx.inp_cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -16793,13 +16836,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls));
 
         for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
 
-            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
+            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
             GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == CLS or RANK");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = batch.pos[s*n_seq_tokens + i];
+                const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
 
                 if (pos == 0) {
                     data[seq_id] = s*n_seq_tokens + i;
@@ -16809,9 +16852,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     }
 
     if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) {
-        const int64_t n_tokens     = batch.n_tokens;
-        const int64_t n_seq_tokens = batch.n_seq_tokens;
-        const int64_t n_seqs       = batch.n_seqs;
+        const int64_t n_tokens     = ubatch.n_tokens;
+        const int64_t n_seq_tokens = ubatch.n_seq_tokens;
+        const int64_t n_seqs       = ubatch.n_seqs;
 
         GGML_ASSERT(lctx.inp_cls);
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer));
@@ -16823,13 +16866,13 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
         std::vector<int> last_row(n_tokens, -1);
 
         for (int s = 0; s < n_seqs; ++s) {
-            const llama_seq_id seq_id = batch.seq_id[s][0];
+            const llama_seq_id seq_id = ubatch.seq_id[s][0];
 
-            // TODO: adapt limits to n_seqs when batch.equal_seqs is true
+            // TODO: adapt limits to n_seqs when ubatch.equal_seqs is true
             GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST");
 
             for (int i = 0; i < n_seq_tokens; ++i) {
-                const llama_pos pos = batch.pos[s*n_seq_tokens + i];
+                const llama_pos pos = ubatch.pos[s*n_seq_tokens + i];
 
                 if (pos >= last_pos[seq_id]) {
                     last_pos[seq_id] = pos;
@@ -16891,10 +16934,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
     }
 
     if (lctx.inp_pos_bucket) {
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_pos_bucket->buffer));
-        GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
+        GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
 
         int32_t * data = (int32_t *) lctx.inp_pos_bucket->data;
 
@@ -16903,7 +16946,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     for (int i = 0; i < n_kv; ++i) {
-                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
+                        data[h*(n_kv*n_tokens) + j*n_kv + i] = llama_relative_position_bucket(lctx.kv_self.cells[i].pos, ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
                     }
                 }
             }
@@ -16911,7 +16954,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             for (int h = 0; h < 1; ++h) {
                 for (int j = 0; j < n_tokens; ++j) {
                     for (int i = 0; i < n_tokens; ++i) {
-                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(batch.pos[i], batch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
+                        data[h*(n_tokens*n_tokens) + j*n_tokens + i] = llama_relative_position_bucket(ubatch.pos[i], ubatch.pos[j], hparams.n_rel_attn_bkts, lctx.is_encoding);
                     }
                 }
             }
@@ -16927,10 +16970,10 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
 
     if (!lctx.is_encoding && lctx.inp_KQ_mask_cross) {
         const int64_t n_output_enc = lctx.embd_enc.size() / hparams.n_embd;
-        const int64_t n_tokens = batch.n_tokens;
+        const int64_t n_tokens = ubatch.n_tokens;
 
         GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_mask_cross->buffer));
-        GGML_ASSERT(!batch.equal_seqs); // TODO: use batch.n_seqs instead of failing
+        GGML_ASSERT(!ubatch.equal_seqs); // TODO: use ubatch.n_seqs instead of failing
 
         float * data = (float *) lctx.inp_KQ_mask_cross->data;
 
@@ -16938,8 +16981,8 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & batch) {
             for (int j = 0; j < n_tokens; ++j) {
                 for (int i = 0; i < n_output_enc; ++i) {
                     float f = -INFINITY;
-                    for (int s = 0; s < batch.n_seq_id[j]; ++s) {
-                        const llama_seq_id seq_id = batch.seq_id[j][s];
+                    for (int s = 0; s < ubatch.n_seq_id[j]; ++s) {
+                        const llama_seq_id seq_id = ubatch.seq_id[j][s];
                         if (lctx.seq_ids_enc[i].find(seq_id) != lctx.seq_ids_enc[i].end()) {
                             f = 0.0f;
                         }
@@ -17068,17 +17111,19 @@ static void llama_graph_compute(
                     int   n_threads,
         ggml_threadpool * threadpool) {
     if (lctx.backend_cpu != nullptr) {
-        ggml_backend_cpu_set_n_threads(lctx.backend_cpu, n_threads);
         ggml_backend_cpu_set_threadpool(lctx.backend_cpu, threadpool);
         ggml_backend_cpu_set_abort_callback(lctx.backend_cpu, lctx.abort_callback, lctx.abort_callback_data);
     }
-#ifdef GGML_USE_BLAS
-    if (lctx.backend_blas != nullptr) {
-        ggml_backend_blas_set_n_threads(lctx.backend_blas, n_threads);
+
+    // set the number of threads for all the backends
+    for (const auto & set_n_threads_fn : lctx.set_n_threads_fns) {
+        set_n_threads_fn.second(set_n_threads_fn.first, n_threads);
     }
-#endif
 
-    ggml_backend_sched_graph_compute_async(lctx.sched, gf);
+    auto err = ggml_backend_sched_graph_compute_async(lctx.sched, gf);
+    if (err != GGML_STATUS_SUCCESS) {
+        LLAMA_LOG_ERROR("%s: ggml_backend_sched_graph_compute_async failed with error %d\n", __func__, err);
+    }
 
     // fprintf(stderr, "splits: %d\n", ggml_backend_sched_get_n_splits(lctx.sched));
 }
@@ -17094,26 +17139,30 @@ static void llama_graph_compute(
 //
 static int llama_decode_internal(
          llama_context & lctx,
-           llama_batch   batch_all) { // TODO: rename back to batch
+           llama_batch   inp_batch) {
 
     lctx.is_encoding = false;
-    const uint32_t n_tokens_all = batch_all.n_tokens;
 
-    if (n_tokens_all == 0) {
+    if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    // temporarily allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+    const uint32_t n_tokens_all = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
 
-    GGML_ASSERT((!batch_all.token && batch_all.embd) || (batch_all.token && !batch_all.embd)); // NOLINT
+    GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT
 
-    if (batch_all.token) {
+    if (batch.token) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            if (batch_all.token[i] < 0 || (uint32_t)batch_all.token[i] >= model.vocab.n_vocab) {
-                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch_all.token[i]);
+            if (batch.token[i] < 0 || (uint32_t)batch.token[i] >= model.vocab.n_vocab) {
+                LLAMA_LOG_ERROR("%s: invalid token[%d] = %d\n", __func__, i, batch.token[i]);
                 return -1;
             }
         }
@@ -17144,9 +17193,9 @@ static int llama_decode_internal(
     lctx.embd_seq.clear();
 
     // count outputs
-    if (batch_all.logits && !embd_pooled) {
+    if (batch.logits && !embd_pooled) {
         for (uint32_t i = 0; i < n_tokens_all; ++i) {
-            n_outputs += batch_all.logits[i] != 0;
+            n_outputs += batch.logits[i] != 0;
         }
     } else if (lctx.logits_all || embd_pooled) {
         n_outputs = n_tokens_all;
@@ -17155,7 +17204,7 @@ static int llama_decode_internal(
         n_outputs = 1;
     }
 
-    lctx.sbatch.from_batch(batch_all, n_embd,
+    lctx.sbatch.from_batch(batch, n_embd,
         /* simple_split */ !kv_self.recurrent,
         /* logits_all   */ n_outputs == n_tokens_all);
 
@@ -17408,17 +17457,20 @@ static int llama_decode_internal(
 //
 static int llama_encode_internal(
          llama_context & lctx,
-           llama_batch   batch) {
+           llama_batch   inp_batch) {
 
     lctx.is_encoding = true;
 
-    const uint32_t n_tokens = batch.n_tokens;
-
-    if (n_tokens == 0) {
+    if (inp_batch.n_tokens == 0) {
         LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
         return -1;
     }
 
+    // temporarily allocate memory for the input batch if needed
+    llama_batch_allocr batch_allocr(lctx, inp_batch);
+    const llama_batch & batch = batch_allocr.batch;
+    const uint32_t n_tokens = batch.n_tokens;
+
     const auto & model   = lctx.model;
     const auto & hparams = model.hparams;
     const auto & cparams = lctx.cparams;
@@ -17894,10 +17946,9 @@ static void llama_tensor_dequantize_internal(
     }
     float * f32_output = (float *) output.data();
 
-    ggml_type_traits_t qtype;
+    const ggml_type_traits * qtype = ggml_get_type_traits(tensor->type);
     if (ggml_is_quantized(tensor->type)) {
-        qtype = ggml_internal_get_type_traits(tensor->type);
-        if (qtype.to_float == NULL) {
+        if (qtype->to_float == NULL) {
             throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(tensor->type)));
         }
     } else if (tensor->type != GGML_TYPE_F16 &&
@@ -17911,7 +17962,7 @@ static void llama_tensor_dequantize_internal(
         } else if (tensor->type == GGML_TYPE_BF16) {
             ggml_bf16_to_fp32_row((ggml_bf16_t *)tensor->data, f32_output, nelements);
         } else if (ggml_is_quantized(tensor->type)) {
-            qtype.to_float(tensor->data, f32_output, nelements);
+            qtype->to_float(tensor->data, f32_output, nelements);
         } else {
             GGML_ABORT("fatal error"); // unreachable
         }
@@ -17947,7 +17998,7 @@ static void llama_tensor_dequantize_internal(
             } else if (typ == GGML_TYPE_BF16) {
                 ggml_bf16_to_fp32_row((ggml_bf16_t *)inbuf, outbuf, nels);
             } else {
-                qtype.to_float(inbuf, outbuf, nels);
+                qtype->to_float(inbuf, outbuf, nels);
             }
         };
         workers.emplace_back(compute, tensor->type, (uint8_t *) tensor->data + in_buff_offs, f32_output + out_buff_offs, thr_elems);
@@ -19041,16 +19092,20 @@ bool llama_supports_mlock(void) {
 }
 
 bool llama_supports_gpu_offload(void) {
-#if defined(GGML_USE_METAL)   || defined(GGML_USE_VULKAN) || \
-    defined(GGML_USE_SYCL) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_RPC)
+#if defined(GGML_USE_KOMPUTE)
     // Defined when llama.cpp is compiled with support for offloading model layers to GPU.
     return true;
 #else
     return ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU) != nullptr ||
-        ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU_FULL) != nullptr;
+           ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU_FULL) != nullptr ||
+           llama_supports_rpc();
 #endif
 }
 
+bool llama_supports_rpc(void) {
+    return ggml_backend_reg_by_name("RPC") != nullptr;
+}
+
 void llama_backend_init(void) {
     ggml_time_init();
 
@@ -19125,14 +19180,56 @@ struct llama_model * llama_load_model_from_file(
         model->rpc_servers.push_back(servers);
     }
 
+    // add RPC devices
+    if (!model->rpc_servers.empty()) {
+        ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
+        if (!rpc_reg) {
+            LLAMA_LOG_ERROR("%s: failed to find RPC backend\n", __func__);
+            llama_free_model(model);
+            return nullptr;
+        }
+
+        // ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
+        using ggml_backend_rpc_add_device_t = ggml_backend_dev_t (*)(const char *);
+        ggml_backend_rpc_add_device_t ggml_backend_rpc_add_device_fn = (ggml_backend_rpc_add_device_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_device");
+        if (!ggml_backend_rpc_add_device_fn) {
+            LLAMA_LOG_ERROR("%s: failed to find RPC device add function\n", __func__);
+            llama_free_model(model);
+            return nullptr;
+        }
+
+        for (const std::string & server : model->rpc_servers) {
+            ggml_backend_dev_t dev = ggml_backend_rpc_add_device_fn(server.c_str());
+            if (dev) {
+                model->devices.push_back(dev);
+            } else {
+                LLAMA_LOG_ERROR("%s: failed to add RPC device for server '%s'\n", __func__, server.c_str());
+                llama_free_model(model);
+                return nullptr;
+            }
+        }
+    }
+
     // create list of devices to use with this model
     // currently, we use all available devices
     // TODO: rework API to give user more control over device selection
     for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
         ggml_backend_dev_t dev = ggml_backend_dev_get(i);
-        // skip the CPU backend since it is handled separately
-        if (ggml_backend_dev_type(dev) != GGML_BACKEND_DEVICE_TYPE_CPU_FULL) {
-            model->devices.push_back(dev);
+        switch (ggml_backend_dev_type(dev)) {
+            case GGML_BACKEND_DEVICE_TYPE_CPU:
+            case GGML_BACKEND_DEVICE_TYPE_CPU_FULL:
+                // skip CPU backends since they are handled separately
+                break;
+
+            case GGML_BACKEND_DEVICE_TYPE_GPU:
+            case GGML_BACKEND_DEVICE_TYPE_GPU_FULL:
+            {
+                size_t free, total; // NOLINT
+                ggml_backend_dev_memory(dev, &free, &total);
+                LLAMA_LOG_INFO("%s: using device %s (%s) - %zu MiB free\n", __func__, ggml_backend_dev_name(dev), ggml_backend_dev_description(dev), free/1024/1024);
+                model->devices.push_back(dev);
+                break;
+            }
         }
     }
 
@@ -19144,7 +19241,7 @@ struct llama_model * llama_load_model_from_file(
         } else if (status == -2) {
             LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
         }
-        delete model;
+        llama_free_model(model);
         return nullptr;
     }
 
@@ -19184,7 +19281,7 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }
 
-    if (params.type_v != GGML_TYPE_F16 && !params.flash_attn) {
+    if (ggml_is_quantized(params.type_v) && !params.flash_attn) {
         LLAMA_LOG_ERROR("%s: V cache quantization requires flash_attn\n", __func__);
         return nullptr;
     }
@@ -19327,81 +19424,7 @@ struct llama_context * llama_new_context_with_model(
             main_gpu -= (int)model->devices.size();
         }
 
-#if defined(GGML_USE_RPC)
-        if (model->n_gpu_layers > 0) {
-            for (const auto & endpoint : model->rpc_servers) {
-                ggml_backend_t backend = ggml_backend_rpc_init(endpoint.c_str());
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize RPC to '%s'\n", __func__, endpoint.c_str());
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-        if (main_gpu >= (int)model->rpc_servers.size()) {
-            main_gpu -= (int)model->rpc_servers.size();
-        }
-#endif
-
-#if defined(GGML_USE_METAL)
-        if (model->n_gpu_layers > 0) {
-            ctx->backend_metal = ggml_backend_metal_init();
-            if (ctx->backend_metal == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize Metal backend\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(ctx->backend_metal);
-        }
-#elif defined(GGML_USE_VULKAN)
-        if (model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            LLAMA_LOG_ERROR("%s: Row split not supported. Failed to initialize Vulkan backend\n", __func__);
-            llama_free(ctx);
-            return nullptr;
-        }
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE) {
-            ggml_backend_t backend = ggml_backend_vk_init(main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize Vulkan backend\n", __func__);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) {
-                ggml_backend_t backend = ggml_backend_vk_init(device);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize Vulkan%d backend\n", __func__, device);
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#elif defined(GGML_USE_SYCL)
-        // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            ggml_backend_t backend = ggml_backend_sycl_init(main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d backend\n", __func__, main_gpu);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            // LLAMA_SPLIT_LAYER requires a backend for each GPU
-            for (int i = 0; i < ggml_backend_sycl_get_device_count(); ++i) {
-                ggml_backend_t backend = ggml_backend_sycl_init(i);
-                if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize SYCL%d for No.%d backend\n", __func__, i, i);
-                    llama_free(ctx);
-                    return nullptr;
-                }
-                ctx->backends.push_back(backend);
-            }
-        }
-#elif defined(GGML_USE_KOMPUTE)
+#if defined(GGML_USE_KOMPUTE)
         if (model->n_gpu_layers > 0) {
             auto * backend = ggml_backend_kompute_init(main_gpu);
             if (backend == nullptr) {
@@ -19411,40 +19434,21 @@ struct llama_context * llama_new_context_with_model(
             }
             ctx->backends.push_back(backend);
         }
-#elif defined(GGML_USE_CANN)
-        // with split_mode LLAMA_SPLIT_MODE_NONE or LLAMA_SPLIT_MODE_ROW, only the main GPU backend is used
-        // TODO: ggml_backend_cann is not support split tensor now, just leave code here.
-        if (model->split_mode == LLAMA_SPLIT_MODE_NONE || model->split_mode == LLAMA_SPLIT_MODE_ROW) {
-            ggml_backend_t backend = ggml_backend_cann_init(main_gpu);
-            if (backend == nullptr) {
-                LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, main_gpu);
-                llama_free(ctx);
-                return nullptr;
-            }
-            ctx->backends.push_back(backend);
-        } else {
-            // LLAMA_SPLIT_MODE_LAYER requires a backend for each GPU
-            // TODO: currently, CANN can't use multi-gpus, just leave code here for further cann version.
-            for (int32_t device = 0; device < ggml_backend_cann_get_device_count(); ++device) {
-                ggml_backend_t backend = ggml_backend_cann_init(device);
+#endif
+
+        // add other backends (such as BLAS)
+        for (size_t i = 0; i < ggml_backend_dev_count(); ++i) {
+            ggml_backend_dev_t dev = ggml_backend_dev_get(i);
+            if (ggml_backend_dev_type(dev) == GGML_BACKEND_DEVICE_TYPE_CPU) {
+                ggml_backend_t backend = ggml_backend_dev_init(dev, nullptr);
                 if (backend == nullptr) {
-                    LLAMA_LOG_ERROR("%s: failed to initialize CANN%d backend\n", __func__, device);
+                    LLAMA_LOG_ERROR("%s: failed to initialize %s backend\n", __func__, ggml_backend_dev_name(dev));
                     llama_free(ctx);
                     return nullptr;
                 }
                 ctx->backends.push_back(backend);
             }
         }
-#endif
-
-#ifdef GGML_USE_BLAS
-        ctx->backend_blas = ggml_backend_blas_init();
-        if (ctx->backend_blas == nullptr) {
-            LLAMA_LOG_WARN("%s: failed to initialize BLAS backend\n", __func__);
-        } else {
-            ctx->backends.push_back(ctx->backend_blas);
-        }
-#endif
 
         ctx->backend_cpu = ggml_backend_cpu_init();
         if (ctx->backend_cpu == nullptr) {
@@ -19454,6 +19458,18 @@ struct llama_context * llama_new_context_with_model(
         }
         ctx->backends.push_back(ctx->backend_cpu);
 
+        // create a list of the set_n_threads functions in the backends
+        for (auto * backend : ctx->backends) {
+            ggml_backend_dev_t dev = ggml_backend_get_device(backend);
+            ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : nullptr;
+            if (reg) {
+                auto ggml_backend_set_n_threads_fn = (ggml_backend_set_n_threads_t) ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
+                if (ggml_backend_set_n_threads_fn) {
+                    ctx->set_n_threads_fns.emplace_back(backend, ggml_backend_set_n_threads_fn);
+                }
+            }
+        }
+
         if (!llama_kv_cache_init(ctx->kv_self, ctx, type_k, type_v, kv_size, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
@@ -19473,7 +19489,7 @@ struct llama_context * llama_new_context_with_model(
             }
 
             LLAMA_LOG_INFO("%s: KV self size  = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-                (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
+                      (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f),
                 ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
                 ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
         }
@@ -21069,9 +21085,7 @@ void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
 
 struct llama_batch llama_batch_get_one(
              llama_token * tokens,
-                 int32_t   n_tokens,
-               llama_pos   pos_0,
-            llama_seq_id   seq_id) {
+                 int32_t   n_tokens) {
     return {
         /*n_tokens       =*/ n_tokens,
         /*tokens         =*/ tokens,
@@ -21080,9 +21094,6 @@ struct llama_batch llama_batch_get_one(
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
         /*logits         =*/ nullptr,
-        /*all_pos_0      =*/ pos_0,
-        /*all_pos_1      =*/ 1,
-        /*all_seq_id     =*/ seq_id,
     };
 }
 
@@ -21095,9 +21106,6 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_
         /*n_seq_id       =*/ nullptr,
         /*seq_id         =*/ nullptr,
         /*logits         =*/ nullptr,
-        /*all_pos_0      =*/ 0,
-        /*all_pos_1      =*/ 0,
-        /*all_seq_id     =*/ 0,
     };
 
     if (embd) {
@@ -21137,7 +21145,7 @@ int32_t llama_encode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
     const int ret = llama_encode_internal(*ctx, batch);
-    if (ret < 0) {
+    if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
 
@@ -21148,7 +21156,7 @@ int32_t llama_decode(
         struct llama_context * ctx,
           struct llama_batch   batch) {
     const int ret = llama_decode_internal(*ctx, batch);
-    if (ret < 0) {
+    if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
 
@@ -21327,6 +21335,10 @@ llama_token llama_token_eos(const struct llama_model * model) {
     return llama_token_eos_impl(model->vocab);
 }
 
+llama_token llama_token_eot(const struct llama_model * model) {
+    return llama_token_eot_impl(model->vocab);
+}
+
 llama_token llama_token_cls(const struct llama_model * model) {
     return llama_token_cls_impl(model->vocab);
 }
@@ -21363,8 +21375,28 @@ llama_token llama_token_suffix(const struct llama_model * model) {
     return llama_token_suffix_impl(model->vocab);
 }
 
-llama_token llama_token_eot(const struct llama_model * model) {
-    return llama_token_eot_impl(model->vocab);
+llama_token llama_token_fim_pre(const struct llama_model * model) {
+    return llama_token_fim_pre_impl(model->vocab);
+}
+
+llama_token llama_token_fim_suf(const struct llama_model * model) {
+    return llama_token_fim_suf_impl(model->vocab);
+}
+
+llama_token llama_token_fim_mid(const struct llama_model * model) {
+    return llama_token_fim_mid_impl(model->vocab);
+}
+
+llama_token llama_token_fim_pad(const struct llama_model * model) {
+    return llama_token_fim_pad_impl(model->vocab);
+}
+
+llama_token llama_token_fim_rep(const struct llama_model * model) {
+    return llama_token_fim_rep_impl(model->vocab);
+}
+
+llama_token llama_token_fim_sep(const struct llama_model * model) {
+    return llama_token_fim_sep_impl(model->vocab);
 }
 
 //
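The FIM getters above replace the deprecated prefix/middle/suffix helpers. A minimal sketch of assembling a fill-in-the-middle prompt with them, assuming `prefix_tokens` and `suffix_tokens` are already-tokenized spans and using the common PSM (prefix-suffix-middle) layout:

    // hypothetical PSM prompt: <fim_pre> prefix <fim_suf> suffix <fim_mid>
    std::vector<llama_token> prompt;
    prompt.push_back(llama_token_fim_pre(model));
    prompt.insert(prompt.end(), prefix_tokens.begin(), prefix_tokens.end());
    prompt.push_back(llama_token_fim_suf(model));
    prompt.insert(prompt.end(), suffix_tokens.begin(), suffix_tokens.end());
    prompt.push_back(llama_token_fim_mid(model));
    // decoding from here generates the missing middle span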
@@ -21664,6 +21696,16 @@ static int32_t llama_chat_apply_template_internal(
         if (add_ass) {
             ss << "[|assistant|]";
         }
+    } else if (tmpl == "rwkv-world" || tmpl_contains("rwkv-world")) {
+        // this template requires the model to have "\n\n" as its EOT token
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "user") {
+                ss << "User: " << message->content << "\n\nAssistant:";
+            } else {
+                ss << message->content << "\n\n";
+            }
+        }
     } else {
         // template not supported
         return -1;
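As a worked example of the rwkv-world branch above, the two-message chat {role: user, content: "Hi"} followed by {role: assistant, content: "Hello"} renders to:

    "User: Hi\n\nAssistant:Hello\n\n"

so the assistant reply is terminated by the "\n\n" that the model treats as its EOT token.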
@@ -21722,6 +21764,14 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
     return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
 
+struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model) {
+    return llama_sampler_init_infill_impl(model->vocab);
+}
+
+struct llama_sampler * llama_sampler_init_dry(const struct llama_model * model, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) {
+    return llama_sampler_init_dry_impl(model->vocab, llama_n_ctx_train(model), dry_multiplier, dry_base, dry_allowed_length, dry_penalty_last_n, seq_breakers, num_breakers);
+}
+
 //
 // model split
 //
@@ -21761,6 +21811,7 @@ const char * llama_print_system_info(void) {
     s += "AVX512_VBMI = " + std::to_string(ggml_cpu_has_avx512_vbmi()) + " | ";
     s += "AVX512_VNNI = " + std::to_string(ggml_cpu_has_avx512_vnni()) + " | ";
     s += "AVX512_BF16 = " + std::to_string(ggml_cpu_has_avx512_bf16()) + " | ";
+    s += "AMX_INT8 = "    + std::to_string(ggml_cpu_has_amx_int8())    + " | ";
     s += "FMA = "         + std::to_string(ggml_cpu_has_fma())         + " | ";
     s += "NEON = "        + std::to_string(ggml_cpu_has_neon())        + " | ";
     s += "SVE = "         + std::to_string(ggml_cpu_has_sve())         + " | ";
diff --git a/examples/talk-llama/llama.h b/examples/talk-llama/llama.h
index 7cae1bbe2e5..b2d1e7d5ae1 100644
--- a/examples/talk-llama/llama.h
+++ b/examples/talk-llama/llama.h
@@ -217,6 +217,7 @@ extern "C" {
 
     typedef struct llama_token_data_array {
         // TODO: consider SoA
+        // NOTE: this pointer can be modified by the samplers
         llama_token_data * data;
         size_t size;
         int64_t selected; // this is the index in the data array (i.e. not the token id)
@@ -232,8 +233,11 @@ extern "C" {
     // - token  : the token ids of the input (used when embd is NULL)
     // - embd   : token embeddings (i.e. float vector of size n_embd) (used when token is NULL)
     // - pos    : the positions of the respective token in the sequence
+    //            (if set to NULL, the token position will be tracked automatically by llama_decode)
     // - seq_id : the sequence to which the respective token belongs
+    //            (if set to NULL, the sequence ID will be assumed to be 0)
     // - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output
+    //            (if set to NULL, only the logits for the last token will be returned)
     //
     typedef struct llama_batch {
         int32_t n_tokens;
@@ -244,15 +248,6 @@ extern "C" {
         int32_t      *  n_seq_id;
         llama_seq_id ** seq_id;
         int8_t       *  logits; // TODO: rename this to "output"
-
-        // NOTE: helpers for smooth API transition - can be deprecated in the future
-        //       for future-proof code, use the above fields instead and ignore everything below
-        //
-        // pos[i] = all_pos_0 + i*all_pos_1
-        //
-        llama_pos    all_pos_0;  // used if pos == NULL
-        llama_pos    all_pos_1;  // used if pos == NULL
-        llama_seq_id all_seq_id; // used if seq_id == NULL
     } llama_batch;
 
     enum llama_model_kv_override_type {
@@ -433,6 +428,7 @@ extern "C" {
     LLAMA_API bool llama_supports_mmap       (void);
     LLAMA_API bool llama_supports_mlock      (void);
     LLAMA_API bool llama_supports_gpu_offload(void);
+    LLAMA_API bool llama_supports_rpc        (void);
 
     LLAMA_API uint32_t llama_n_ctx      (const struct llama_context * ctx);
     LLAMA_API uint32_t llama_n_batch    (const struct llama_context * ctx);
@@ -775,15 +771,15 @@ extern "C" {
     // Decoding
     //
 
-    // Return batch for single sequence of tokens starting at pos_0
+    // Return batch for single sequence of tokens
+    // The sequence ID will be fixed to 0
+    // The position of the tokens will be tracked automatically by llama_decode
     //
     // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it
     //
     LLAMA_API struct llama_batch llama_batch_get_one(
                   llama_token * tokens,
-                      int32_t   n_tokens,
-                    llama_pos   pos_0,
-                 llama_seq_id   seq_id);
+                      int32_t   n_tokens);
 
     // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens
     // Each token can be assigned up to n_seq_max sequence ids
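A minimal sketch of the simplified helper in use; `tokens` is an assumed vector of prompt tokens and `ctx` an initialized llama_context:

    std::vector<llama_token> tokens; // assumed filled by llama_tokenize
    llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());

    // pos and seq_id are NULL: positions are tracked by llama_decode, the sequence
    // ID defaults to 0, and only the logits for the last token are returned
    if (llama_decode(ctx, batch) != 0) {
        // handle the error
    }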
@@ -896,6 +892,7 @@ extern "C" {
     // Special tokens
     LLAMA_API llama_token llama_token_bos(const struct llama_model * model); // beginning-of-sentence
     LLAMA_API llama_token llama_token_eos(const struct llama_model * model); // end-of-sentence
+    LLAMA_API llama_token llama_token_eot(const struct llama_model * model); // end-of-turn
     LLAMA_API llama_token llama_token_cls(const struct llama_model * model); // classification
     LLAMA_API llama_token llama_token_sep(const struct llama_model * model); // sentence separator
     LLAMA_API llama_token llama_token_nl (const struct llama_model * model); // next-line
@@ -904,11 +901,17 @@ extern "C" {
     LLAMA_API bool llama_add_bos_token(const struct llama_model * model);
     LLAMA_API bool llama_add_eos_token(const struct llama_model * model);
 
-    // Codellama infill tokens
-    LLAMA_API llama_token llama_token_prefix(const struct llama_model * model); // Beginning of infill prefix
-    LLAMA_API llama_token llama_token_middle(const struct llama_model * model); // Beginning of infill middle
-    LLAMA_API llama_token llama_token_suffix(const struct llama_model * model); // Beginning of infill suffix
-    LLAMA_API llama_token llama_token_eot   (const struct llama_model * model); // End of infill middle
+    // infill tokens
+    DEPRECATED(LLAMA_API llama_token llama_token_prefix(const struct llama_model * model), "use llama_token_fim_pre instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_middle(const struct llama_model * model), "use llama_token_fim_mid instead");
+    DEPRECATED(LLAMA_API llama_token llama_token_suffix(const struct llama_model * model), "use llama_token_fim_suf instead");
+
+    LLAMA_API llama_token llama_token_fim_pre(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_suf(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_mid(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_pad(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_rep(const struct llama_model * model);
+    LLAMA_API llama_token llama_token_fim_sep(const struct llama_model * model);
 
     //
     // Tokenization
@@ -1067,12 +1070,13 @@ extern "C" {
 
     // available samplers:
 
-    LLAMA_API struct llama_sampler * llama_sampler_init_greedy     (void);
-    LLAMA_API struct llama_sampler * llama_sampler_init_dist       (uint32_t seed);
+    LLAMA_API struct llama_sampler * llama_sampler_init_greedy(void);
+    LLAMA_API struct llama_sampler * llama_sampler_init_dist  (uint32_t seed);
 
     /// @details Sorts candidate tokens by their logits in descending order and calculates probabilities based on the logits.
     /// NOTE: Avoid using on the full vocabulary as the sorting can become slow. For example, apply top-k or top-p sampling first.
-    LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void);
+    DEPRECATED(LLAMA_API struct llama_sampler * llama_sampler_init_softmax    (void),
+        "will be removed in the future (see https://github.com/ggerganov/llama.cpp/pull/9896#discussion_r1800920915)");
 
     /// @details Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
     LLAMA_API struct llama_sampler * llama_sampler_init_top_k      (int32_t k);
@@ -1088,11 +1092,16 @@ extern "C" {
 
     /// @details Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
     LLAMA_API struct llama_sampler * llama_sampler_init_typical    (float   p, size_t min_keep);
+
+    /// @details Updates the logits l_i' = l_i/t. When t <= 0.0f, the maximum logit is kept at its original value and the rest are set to -inf
     LLAMA_API struct llama_sampler * llama_sampler_init_temp       (float   t);
 
     /// @details Dynamic temperature implementation (a.k.a. entropy) described in the paper https://arxiv.org/abs/2309.02772.
     LLAMA_API struct llama_sampler * llama_sampler_init_temp_ext   (float   t, float   delta, float exponent);
 
+    /// @details XTC sampler as described in https://github.com/oobabooga/text-generation-webui/pull/6335
+    LLAMA_API struct llama_sampler * llama_sampler_init_xtc        (float   p, float   t,     size_t min_keep, uint32_t seed);
+
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
     /// @param tau  The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
@@ -1132,11 +1141,43 @@ extern "C" {
                                 bool   penalize_nl,     // consider newlines as a repeatable token
                                 bool   ignore_eos);     // ignore the end-of-sequence token
 
+    ///  @details DRY sampler, designed by p-e-w, as described in: https://github.com/oobabooga/text-generation-webui/pull/5677, porting Koboldcpp implementation authored by pi6am: https://github.com/LostRuins/koboldcpp/pull/982
+    LLAMA_API struct llama_sampler *    llama_sampler_init_dry(
+            const struct llama_model *  model,
+                               float    dry_multiplier,
+                               float    dry_base,
+                             int32_t    dry_allowed_length,
+                             int32_t    dry_penalty_last_n,
+                          const char ** seq_breakers,
+                              size_t    num_breakers);
+
     LLAMA_API struct llama_sampler * llama_sampler_init_logit_bias(
                              int32_t   n_vocab,
                              int32_t   n_logit_bias,
               const llama_logit_bias * logit_bias);
 
+    // this sampler is meant to be used for fill-in-the-middle infilling
+    // it's supposed to be used after top_k + top_p sampling
+    //
+    // 1. if the sum of the EOG probs times the number of candidates is higher than the sum of the other probs -> pick EOG
+    // 2. combine probs of tokens that have the same prefix
+    //
+    // example:
+    //
+    // - before:
+    //   "hel":   0.5
+    //   "hell":  0.2
+    //   "hello": 0.1
+    //   "dummy": 0.1
+    //
+    // - after:
+    //   "hel":   0.8
+    //   "dummy": 0.1
+    //
+    // 3. discard non-EOG tokens with low prob
+    // 4. if no tokens are left -> pick EOT
+    //
+    LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model);
 
     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
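A sketch of wiring the new samplers into a chain; llama_sampler_chain_init, llama_sampler_chain_default_params and llama_sampler_chain_add come from the same sampler API, and all parameter values below are illustrative, not recommendations:

    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(40));
    llama_sampler_chain_add(chain, llama_sampler_init_xtc(0.5f, 0.1f, 1, LLAMA_DEFAULT_SEED));

    const char * breakers[] = { "\n", ":", "\"", "*" };
    llama_sampler_chain_add(chain, llama_sampler_init_dry(model, 0.8f, 1.75f, 2, 256, breakers, 4));

    llama_sampler_chain_add(chain, llama_sampler_init_temp(0.7f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));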
diff --git a/examples/talk-llama/unicode-data.cpp b/examples/talk-llama/unicode-data.cpp
index 07424bbab54..04dcd7fcfbc 100644
--- a/examples/talk-llama/unicode-data.cpp
+++ b/examples/talk-llama/unicode-data.cpp
@@ -2311,7 +2311,7 @@ const std::unordered_set<uint32_t> unicode_set_whitespace = {
 0x003000,
 };
 
-// list is always in ascending order, to enable binary searh
+// list is always in ascending order, to enable binary search
 const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase = {
 {0x000041, 0x000061},
 {0x000042, 0x000062},
@@ -3748,7 +3748,7 @@ const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_lowercase
 {0x01E921, 0x01E943},
 };
 
-// list is always in ascending order, to enable binary searh
+// list is always in ascending order, to enable binary search
 const std::initializer_list<std::pair<uint32_t, uint32_t>> unicode_map_uppercase = {
 {0x000061, 0x000041},
 {0x000062, 0x000042},
diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt
index 89fdf9d1c11..cfa6e3f70e4 100644
--- a/ggml/CMakeLists.txt
+++ b/ggml/CMakeLists.txt
@@ -99,6 +99,9 @@ option(GGML_AVX512      "ggml: enable AVX512"           OFF)
 option(GGML_AVX512_VBMI "ggml: enable AVX512-VBMI"      OFF)
 option(GGML_AVX512_VNNI "ggml: enable AVX512-VNNI"      OFF)
 option(GGML_AVX512_BF16 "ggml: enable AVX512-BF16"      OFF)
+option(GGML_AMX_TILE    "ggml: enable AMX-TILE"         OFF)
+option(GGML_AMX_INT8    "ggml: enable AMX-INT8"         OFF)
+option(GGML_AMX_BF16    "ggml: enable AMX-BF16"         OFF)
 option(GGML_FMA         "ggml: enable FMA"              ${INS_ENB})
 if (NOT MSVC)
     option(GGML_F16C    "ggml: enable F16C"             ${INS_ENB}) # in MSVC F16C is implied with AVX2/AVX512
@@ -158,6 +161,7 @@ set   (GGML_METAL_MACOSX_VERSION_MIN "" CACHE STRING
 set   (GGML_METAL_STD "" CACHE STRING       "ggml: metal standard version (-std flag)")
 option(GGML_OPENMP                          "ggml: use OpenMP"                                ON)
 option(GGML_RPC                             "ggml: use RPC"                                   OFF)
+option(GGML_AMX                             "ggml: use AMX"                                   OFF)
 option(GGML_SYCL                            "ggml: use SYCL"                                  OFF)
 option(GGML_SYCL_F16                        "ggml: use 16 bit floats for sycl calculations"   OFF)
 set   (GGML_SYCL_TARGET "INTEL" CACHE STRING
diff --git a/ggml/include/ggml-amx.h b/ggml/include/ggml-amx.h
new file mode 100644
index 00000000000..22b3f70f43a
--- /dev/null
+++ b/ggml/include/ggml-amx.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-backend.h"
+
+
+#ifdef  __cplusplus
+extern "C" {
+#endif
+
+// buffer_type API
+GGML_API ggml_backend_buffer_type_t ggml_backend_amx_buffer_type(void);
+
+GGML_API bool ggml_backend_is_amx(ggml_backend_t backend);
+
+// backend API
+GGML_API ggml_backend_t ggml_backend_amx_init(void);
+
+GGML_API void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads);
+
+GGML_API ggml_backend_reg_t ggml_backend_amx_reg(void);
+
+#ifdef  __cplusplus
+}
+#endif
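A minimal standalone usage sketch for the new header, assuming an AMX-capable CPU and a build with GGML_AMX enabled:

    ggml_backend_t amx = ggml_backend_amx_init();
    if (amx != NULL) {
        ggml_backend_amx_set_n_threads(amx, 8);
        // ... allocate tensors and run graphs with ggml_backend_graph_compute(amx, graph) ...
        ggml_backend_free(amx);
    }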
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 4d7d2716e7a..5933b8e8f63 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -127,6 +127,8 @@ extern "C" {
         bool async;
         // pinned host buffer
         bool host_buffer;
+        // creating buffers from host ptr
+        bool buffer_from_host_ptr;
         // event synchronization
         bool events;
     };
@@ -168,6 +170,7 @@ extern "C" {
 
     // Functions that may be obtained using ggml_backend_reg_get_proc_address
     typedef ggml_backend_buffer_type_t (*ggml_backend_split_buffer_type_t)(const float *);
+    typedef void (*ggml_backend_set_n_threads_t)(ggml_backend_t, int);
 
     //
     // Backend registry
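The new typedef lets callers adjust the thread count of any backend through the registry, without linking against that backend directly. A sketch of the lookup, mirroring what llama_graph_compute now does for every backend in the context (`backend` and `n_threads` are assumed to be in scope):

    ggml_backend_dev_t dev = ggml_backend_get_device(backend);
    ggml_backend_reg_t reg = dev ? ggml_backend_dev_backend_reg(dev) : NULL;
    if (reg) {
        auto set_n_threads = (ggml_backend_set_n_threads_t)
            ggml_backend_reg_get_proc_address(reg, "ggml_backend_set_n_threads");
        if (set_n_threads) {
            set_n_threads(backend, n_threads);
        }
    }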
diff --git a/ggml/include/ggml-blas.h b/ggml/include/ggml-blas.h
index dd612860d61..25b2e637fb4 100644
--- a/ggml/include/ggml-blas.h
+++ b/ggml/include/ggml-blas.h
@@ -17,6 +17,8 @@ GGML_API bool ggml_backend_is_blas(ggml_backend_t backend);
 // for openblas and blis, this will also set the number of threads used for blas operations
 GGML_API void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads);
 
+GGML_API ggml_backend_reg_t ggml_backend_blas_reg(void);
+
 
 #ifdef  __cplusplus
 }
diff --git a/ggml/include/ggml-cann.h b/ggml/include/ggml-cann.h
index 95bdaf10d17..52897549388 100644
--- a/ggml/include/ggml-cann.h
+++ b/ggml/include/ggml-cann.h
@@ -34,6 +34,8 @@ extern "C" {
  */
 #define GGML_CANN_MAX_DEVICES 16
 
+GGML_API ggml_backend_reg_t ggml_backend_cann_reg(void);
+
 /**
  * @brief Initializes the CANN backend for a specified device.
  *
diff --git a/ggml/include/ggml-metal.h b/ggml/include/ggml-metal.h
index c3ec572b2dc..b8d3f678b71 100644
--- a/ggml/include/ggml-metal.h
+++ b/ggml/include/ggml-metal.h
@@ -43,7 +43,9 @@ GGML_API ggml_backend_t ggml_backend_metal_init(void);
 
 GGML_API bool ggml_backend_is_metal(ggml_backend_t backend);
 
-GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size);
+GGML_DEPRECATED(
+        GGML_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
+        "obsoleted by the new device interface - https://github.com/ggerganov/llama.cpp/pull/9713");
 
 GGML_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
 
@@ -57,6 +59,8 @@ GGML_API bool ggml_backend_metal_supports_family(ggml_backend_t backend, int fam
 // capture all command buffers committed the next time `ggml_backend_graph_compute` is called
 GGML_API void ggml_backend_metal_capture_next_compute(ggml_backend_t backend);
 
+GGML_API ggml_backend_reg_t ggml_backend_metal_reg(void);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/ggml/include/ggml-rpc.h b/ggml/include/ggml-rpc.h
index 64cde7f13d3..d5796736821 100644
--- a/ggml/include/ggml-rpc.h
+++ b/ggml/include/ggml-rpc.h
@@ -17,7 +17,11 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
 
 GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total);
 
-GGML_API void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+GGML_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem);
+
+GGML_API ggml_backend_reg_t ggml_backend_rpc_reg(void);
+
+GGML_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint);
 
 #ifdef  __cplusplus
 }
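A sketch of both sides of the renamed RPC API; the endpoint strings and the advertised memory sizes are placeholders:

    // server side: expose a local CPU backend over the network
    ggml_backend_t backend = ggml_backend_cpu_init();
    size_t free_mem  = 8ull  * 1024 * 1024 * 1024; // placeholder
    size_t total_mem = 16ull * 1024 * 1024 * 1024; // placeholder
    ggml_backend_rpc_start_server(backend, "0.0.0.0:50052", free_mem, total_mem);

    // client side: register the remote endpoint as a regular device
    ggml_backend_dev_t dev = ggml_backend_rpc_add_device("192.168.1.2:50052");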
diff --git a/ggml/include/ggml-sycl.h b/ggml/include/ggml-sycl.h
index 03b698e61b9..af521f59930 100644
--- a/ggml/include/ggml-sycl.h
+++ b/ggml/include/ggml-sycl.h
@@ -19,6 +19,8 @@ extern "C" {
 // backend API
 GGML_API ggml_backend_t ggml_backend_sycl_init(int device);
 
+GGML_API bool ggml_backend_is_sycl(ggml_backend_t backend);
+
 // device buffer
 GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device);
 
@@ -29,14 +31,19 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const fl
 GGML_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void);
 
 GGML_API void ggml_backend_sycl_print_sycl_devices(void);
-GGML_API void ggml_sycl_get_gpu_list(int *id_list, int max_len);
-GGML_API void ggml_sycl_get_device_description(int device, char *description, size_t description_size);
+GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len);
+GGML_API void ggml_backend_sycl_get_device_description(int device,
+                                                       char *description,
+                                                       size_t description_size);
 GGML_API int  ggml_backend_sycl_get_device_count();
 GGML_API void ggml_backend_sycl_get_device_memory(int device, size_t *free, size_t *total);
 
 // SYCL doesn't support registering host memory, keep here for reference
 // GGML_API bool ggml_backend_sycl_register_host_buffer(void * buffer, size_t size);
 // GGML_API void ggml_backend_sycl_unregister_host_buffer(void * buffer);
+
+GGML_API ggml_backend_reg_t ggml_backend_sycl_reg(void);
+
 #ifdef  __cplusplus
 }
 #endif
diff --git a/ggml/include/ggml-vulkan.h b/ggml/include/ggml-vulkan.h
index e074042efae..c03bbfe5e37 100644
--- a/ggml/include/ggml-vulkan.h
+++ b/ggml/include/ggml-vulkan.h
@@ -24,6 +24,8 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num);
 // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU
 GGML_API ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type(void);
 
+GGML_API ggml_backend_reg_t ggml_backend_vk_reg(void);
+
 #ifdef  __cplusplus
 }
 #endif
diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 8d36b3d4d42..e5862246c8c 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -2489,6 +2489,7 @@ extern "C" {
     GGML_API int ggml_cpu_has_avx512_vbmi(void);
     GGML_API int ggml_cpu_has_avx512_vnni(void);
     GGML_API int ggml_cpu_has_avx512_bf16(void);
+    GGML_API int ggml_cpu_has_amx_int8   (void);
     GGML_API int ggml_cpu_has_fma        (void);
     GGML_API int ggml_cpu_has_neon       (void);
     GGML_API int ggml_cpu_has_sve        (void);
@@ -2536,7 +2537,7 @@ extern "C" {
     typedef void (*ggml_gemm_t)     (int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT x,
                                        const void * GGML_RESTRICT y, int nr, int nc);
 
-    typedef struct {
+    struct ggml_type_traits {
         const char             * type_name;
         int64_t                  blck_size;
         int64_t                  blck_size_interleave; // interleave elements in blocks
@@ -2552,9 +2553,9 @@ extern "C" {
         int64_t                  ncols; // number of columns to process simultaneously
         ggml_gemv_t              gemv;
         ggml_gemm_t              gemm;
-    } ggml_type_traits_t;
+    };
 
-    GGML_API ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type);
+    GGML_API const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type);
 
 #ifdef  __cplusplus
 }
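For callers, the migration is mechanical: the traits are now returned by const pointer instead of by value. A before/after sketch for dequantizing n elements (src and dst assumed allocated):

    // before: ggml_type_traits_t qtype = ggml_internal_get_type_traits(type);
    //         qtype.to_float(src, dst, n);
    const struct ggml_type_traits * qtype = ggml_get_type_traits(type);
    if (qtype->to_float != NULL) {
        qtype->to_float(src, dst, n);
    }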
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 286bec255df..aa405e4d0fb 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -163,8 +163,8 @@ if (GGML_OPENMP)
         list(APPEND GGML_EXTRA_LIBS_PRIVATE OpenMP::OpenMP_C OpenMP::OpenMP_CXX)
 
         if (GGML_MUSA)
-            list(APPEND GGML_EXTRA_INCLUDES     "/usr/lib/llvm-10/include/openmp")
-            list(APPEND GGML_EXTRA_LIBS_PRIVATE "/usr/lib/llvm-10/lib/libomp.so")
+            list(APPEND GGML_EXTRA_INCLUDES     "/usr/lib/llvm-14/lib/clang/14.0.0/include")
+            list(APPEND GGML_EXTRA_LIBS_PRIVATE "/usr/lib/llvm-14/lib/libomp.so")
         endif()
     else()
         message(WARNING "OpenMP not found")
@@ -190,22 +190,24 @@ if (GGML_BLAS)
             # see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
             find_package(PkgConfig REQUIRED)
             if (${GGML_BLAS_VENDOR} MATCHES "Generic")
-                pkg_check_modules(DepBLAS REQUIRED blas)
+                pkg_check_modules(DepBLAS blas)
             elseif (${GGML_BLAS_VENDOR} MATCHES "OpenBLAS")
                 # As of openblas v0.3.22, the 64-bit is named openblas64.pc
                 pkg_check_modules(DepBLAS openblas64)
                 if (NOT DepBLAS_FOUND)
-                    pkg_check_modules(DepBLAS REQUIRED openblas)
+                    pkg_check_modules(DepBLAS openblas)
                 endif()
             elseif (${GGML_BLAS_VENDOR} MATCHES "FLAME")
-                pkg_check_modules(DepBLAS REQUIRED blis)
+                add_compile_definitions(GGML_BLAS_USE_BLIS)
+                pkg_check_modules(DepBLAS blis)
             elseif (${GGML_BLAS_VENDOR} MATCHES "ATLAS")
-                pkg_check_modules(DepBLAS REQUIRED blas-atlas)
+                pkg_check_modules(DepBLAS blas-atlas)
             elseif (${GGML_BLAS_VENDOR} MATCHES "FlexiBLAS")
-                pkg_check_modules(DepBLAS REQUIRED flexiblas_api)
+                pkg_check_modules(DepBLAS flexiblas_api)
             elseif (${GGML_BLAS_VENDOR} MATCHES "Intel")
+                add_compile_definitions(GGML_BLAS_USE_MKL)
                 # all Intel* libraries share the same include path
-                pkg_check_modules(DepBLAS REQUIRED mkl-sdl)
+                pkg_check_modules(DepBLAS mkl-sdl)
             elseif (${GGML_BLAS_VENDOR} MATCHES "NVHPC")
                 # this doesn't provide pkg-config
                 # suggest to assign BLAS_INCLUDE_DIRS on your own
@@ -265,6 +267,26 @@ if (GGML_LLAMAFILE)
     set(GGML_SOURCES_LLAMAFILE llamafile/sgemm.cpp)
 endif()
 
+if (GGML_AMX)
+    if (CMAKE_COMPILER_IS_GNUCC AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 11.0)
+    else()
+        set(GGML_AMX OFF)
+        message(WARNING "AMX requires gcc version > 11.0. Turning off GGML_AMX.")
+    endif()
+
+    if (GGML_AMX)
+        message(STATUS "Using AMX")
+
+        list(APPEND GGML_CDEF_PUBLIC GGML_USE_AMX)
+
+        file(GLOB   GGML_HEADERS_AMX "ggml-amx/*.h")
+        list(APPEND GGML_HEADERS_AMX "../include/ggml-amx.h")
+
+        file(GLOB   GGML_SOURCES_AMX "ggml-amx/*.cpp")
+        list(APPEND GGML_SOURCES_AMX "ggml-amx.cpp")
+    endif()
+endif()
+
 if (GGML_CUDA)
     cmake_minimum_required(VERSION 3.18)  # for CMAKE_CUDA_ARCHITECTURES
 
@@ -1178,6 +1200,18 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
                 add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AVX512BF16__>)
                 add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AVX512BF16__>)
             endif()
+            if (GGML_AMX_TILE)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_TILE__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_TILE__>)
+            endif()
+            if (GGML_AMX_INT8)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_INT8__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_INT8__>)
+            endif()
+            if (GGML_AMX_BF16)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:C>:__AMX_BF16__>)
+                add_compile_definitions($<$<COMPILE_LANGUAGE:CXX>:__AMX_BF16__>)
+            endif()
         elseif (GGML_AVX2)
             list(APPEND ARCH_FLAGS /arch:AVX2)
         elseif (GGML_AVX)
@@ -1213,6 +1247,15 @@ elseif (CMAKE_OSX_ARCHITECTURES STREQUAL "x86_64" OR CMAKE_GENERATOR_PLATFORM_LW
         if (GGML_AVX512_BF16)
             list(APPEND ARCH_FLAGS -mavx512bf16)
         endif()
+        if (GGML_AMX_TILE)
+            list(APPEND ARCH_FLAGS -mamx-tile)
+        endif()
+        if (GGML_AMX_INT8)
+            list(APPEND ARCH_FLAGS -mamx-int8)
+        endif()
+        if (GGML_AMX_BF16)
+            list(APPEND ARCH_FLAGS -mamx-bf16)
+        endif()
     endif()
 elseif (${CMAKE_SYSTEM_PROCESSOR} MATCHES "ppc64")
     message(STATUS "PowerPC detected")
@@ -1338,6 +1381,7 @@ add_library(ggml
             ${GGML_SOURCES_ROCM}      ${GGML_HEADERS_ROCM}
             ${GGML_SOURCES_BLAS}      ${GGML_HEADERS_BLAS}
             ${GGML_SOURCES_LLAMAFILE} ${GGML_HEADERS_LLAMAFILE}
+            ${GGML_SOURCES_AMX}       ${GGML_HEADERS_AMX}
             ${GGML_SOURCES_CANN}      ${GGML_HEADERS_CANN}
             ggml-aarch64.c            ggml-aarch64.h
             )
@@ -1361,6 +1405,10 @@ if (MATH_LIBRARY)
     endif()
 endif()
 
+if (CMAKE_SYSTEM_NAME MATCHES "Android")
+    list(APPEND GGML_EXTRA_LIBS_PRIVATE dl) # Must be linked explicitly
+endif()
+
 list(REMOVE_DUPLICATES GGML_EXTRA_LIBS_PRIVATE)
 list(REMOVE_DUPLICATES GGML_EXTRA_LIBS_PUBLIC)
 target_link_libraries(ggml PRIVATE ${GGML_EXTRA_LIBS_PRIVATE} PUBLIC ${GGML_EXTRA_LIBS_PUBLIC})
diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c
index 70187b9b65f..041de9e3efc 100644
--- a/ggml/src/ggml-alloc.c
+++ b/ggml/src/ggml-alloc.c
@@ -14,7 +14,7 @@
 
 //#define GGML_ALLOCATOR_DEBUG
 
-//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
+//#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__)
 #define AT_PRINTF(...)
 
 
@@ -89,7 +89,7 @@ void ggml_tallocr_alloc(struct ggml_tallocr * talloc, struct ggml_tensor * tenso
     size = GGML_PAD(size, talloc->alignment);
 
     if (talloc->offset + size > ggml_backend_buffer_get_size(talloc->buffer)) {
-        fprintf(stderr, "%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n",
+        GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %s (needed %zu, available %zu)\n",
                 __func__, tensor->name, size, ggml_backend_buffer_get_size(talloc->buffer) - talloc->offset);
         GGML_ABORT("not enough space in the buffer");
     }
@@ -172,7 +172,7 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz
             best_fit_block = alloc->n_free_blocks - 1;
         } else {
             // this should never happen
-            fprintf(stderr, "%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
+            GGML_LOG_ERROR("%s: not enough space in the buffer to allocate %zu bytes, largest block available %zu bytes\n",
                     __func__, size, max_avail);
             GGML_ABORT("not enough space in the buffer");
         }
@@ -209,16 +209,16 @@ static size_t ggml_dyn_tallocr_alloc(struct ggml_dyn_tallocr * alloc, size_t siz
                 }
             }
         }
-        fprintf(stderr, "max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
+        GGML_LOG_DEBUG("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
         for (int i = 0; i < 1024; i++) {
             if (alloc->allocated_tensors[i].tensor) {
-                fprintf(stderr, "%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name,
+                GGML_LOG_DEBUG("%s [%zx-%zx] (%.2f MB) ", alloc->allocated_tensors[i].tensor->name,
                     alloc->allocated_tensors[i].offset,
                     alloc->allocated_tensors[i].offset + ggml_nbytes(alloc->allocated_tensors[i].tensor),
                     ggml_nbytes(alloc->allocated_tensors[i].tensor) / 1024.0 / 1024.0);
             }
         }
-        fprintf(stderr, "\n");
+        GGML_LOG_DEBUG("\n");
     }
 #endif
 
@@ -348,7 +348,6 @@ struct tensor_alloc {
 };
 
 struct leaf_alloc {
-    int buffer_id;
     struct tensor_alloc leaf;
 };
 
@@ -740,7 +739,6 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
     for (int i = 0; i < graph->n_leafs; i++) {
         struct ggml_tensor * leaf = graph->leafs[i];
         struct hash_node * hn = ggml_gallocr_hash_get(galloc, leaf);
-        galloc->leaf_allocs[i].buffer_id = hn->buffer_id;
         if (leaf->view_src || leaf->data) {
             galloc->leaf_allocs[i].leaf.buffer_id = -1;
             galloc->leaf_allocs[i].leaf.offset = SIZE_MAX;
@@ -768,13 +766,13 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
         // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views
         if (new_size > cur_size || galloc->buffers[i] == NULL) {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
+            GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
 #endif
 
             ggml_backend_buffer_free(galloc->buffers[i]);
             galloc->buffers[i] = ggml_backend_buft_alloc_buffer(galloc->bufts[i], new_size);
             if (galloc->buffers[i] == NULL) {
-                fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
+                GGML_LOG_ERROR("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), new_size);
                 return false;
             }
             ggml_backend_buffer_set_usage(galloc->buffers[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
@@ -825,14 +823,14 @@ static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_t
 static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph * graph) {
     if (galloc->n_nodes != graph->n_nodes) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: graph has different number of nodes\n", __func__);
+        GGML_LOG_DEBUG("%s: graph has different number of nodes\n", __func__);
 #endif
         return true;
     }
 
     if (galloc->n_leafs != graph->n_leafs) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: graph has different number of leafs\n", __func__);
+        GGML_LOG_DEBUG("%s: graph has different number of leafs\n", __func__);
 #endif
         return true;
     }
@@ -843,7 +841,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
 
         if (!ggml_gallocr_node_needs_realloc(galloc, node, &node_alloc->dst)) {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: node %s is not valid\n", __func__, node->name);
+            GGML_LOG_DEBUG("%s: node %s is not valid\n", __func__, node->name);
 #endif
             return true;
         }
@@ -855,7 +853,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
             }
             if (!ggml_gallocr_node_needs_realloc(galloc, src, &node_alloc->src[j])) {
 #ifndef NDEBUG
-                fprintf(stderr, "%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name);
+                GGML_LOG_DEBUG("%s: src %d (%s) of node %s is not valid\n", __func__, j, src->name, node->name);
 #endif
                 return true;
             }
@@ -869,14 +867,14 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
     if (ggml_gallocr_needs_realloc(galloc, graph)) {
         if (galloc->n_buffers == 1) {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: reallocating buffers automatically\n", __func__);
+            GGML_LOG_DEBUG("%s: reallocating buffers automatically\n", __func__);
 #endif
             if (!ggml_gallocr_reserve(galloc, graph)) {
                 return false;
             }
         } else {
 #ifndef NDEBUG
-            fprintf(stderr, "%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__);
+            GGML_LOG_DEBUG("%s: cannot reallocate multi buffer graph automatically, call reserve\n", __func__);
 #endif
             return false;
         }
@@ -940,7 +938,7 @@ static bool alloc_tensor_range(struct ggml_context * ctx,
     ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
     if (buffer == NULL) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
+        GGML_LOG_DEBUG("%s: failed to allocate %s buffer of size %zu\n", __func__, ggml_backend_buft_name(buft), size);
 #endif
         for (size_t i = 0; i < *n_buffers; i++) {
             ggml_backend_buffer_free((*buffers)[i]);
@@ -990,7 +988,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
         }
 
         if (this_size > max_size) {
-            fprintf(stderr, "%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
+            GGML_LOG_ERROR("%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
                     __func__, t->name,
                     ggml_backend_buft_name(buft),
                     this_size, max_size);
@@ -1022,7 +1020,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
 
     if (n_buffers == 0) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: all tensors in the context are already allocated\n", __func__);
+        GGML_LOG_DEBUG("%s: all tensors in the context are already allocated\n", __func__);
 #endif
         return NULL;
     }
diff --git a/ggml/src/ggml-amx.cpp b/ggml/src/ggml-amx.cpp
new file mode 100644
index 00000000000..ac6ec23426e
--- /dev/null
+++ b/ggml/src/ggml-amx.cpp
@@ -0,0 +1,453 @@
+#include "ggml-amx.h"
+#include "ggml-amx/common.h"
+#include "ggml-amx/mmq.h"
+#include "ggml-backend-impl.h"
+#include "ggml-impl.h"
+
+#if defined(__gnu_linux__)
+#include <sys/syscall.h>
+#include <unistd.h>
+#endif
+
+#include <cstdlib>
+#include <cstring>
+#include <memory>
+
+#if defined(__AMX_INT8__)
+
+// AMX buffer interface
+static const char * ggml_backend_amx_buffer_get_name(ggml_backend_buffer_t buffer) {
+    return "AMX";
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_amx_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    free(buffer->context);
+}
+
+static void * ggml_backend_amx_buffer_get_base(ggml_backend_buffer_t buffer) {
+    return (void *)(buffer->context);
+}
+
+static void ggml_backend_amx_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
+    memset((char *)tensor->data + offset, value, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_amx_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+    if (qtype_has_amx_kernels(tensor->type)) {
+        ggml_backend_amx_convert_weight(tensor, data, offset, size);
+    } else {
+        memcpy((char *)tensor->data + offset, data, size);
+    }
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_amx_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
+    GGML_ASSERT(!qtype_has_amx_kernels(tensor->type));
+    memcpy(data, (const char *)tensor->data + offset, size);
+
+    GGML_UNUSED(buffer);
+}
+
+static bool ggml_backend_amx_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
+    if (ggml_backend_buffer_is_host(src->buffer)) {
+        if (qtype_has_amx_kernels(src->type)) {
+            ggml_backend_amx_convert_weight(dst, src->data, 0, ggml_backend_amx_get_alloc_size(dst));
+        } else {
+            memcpy(dst->data, src->data, ggml_nbytes(src));
+        }
+        return true;
+    }
+    return false;
+
+    GGML_UNUSED(buffer);
+}
+
+static void ggml_backend_amx_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    memset(buffer->context, value, buffer->size);
+}
+
+static ggml_backend_buffer_i ggml_backend_amx_buffer_interface = {
+    /* .get_name        = */ ggml_backend_amx_buffer_get_name,
+    /* .free_buffer     = */ ggml_backend_amx_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_amx_buffer_get_base,
+    /* .init_tensor     = */ NULL, // no initialization required
+    /* .memset_tensor   = */ ggml_backend_amx_buffer_memset_tensor,
+    /* .set_tensor      = */ ggml_backend_amx_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_amx_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_amx_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_amx_buffer_clear,
+    /* .reset           = */ NULL,
+};
+
+static const char * ggml_backend_amx_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return "AMX";
+
+    GGML_UNUSED(buft);
+}
+
+static ggml_backend_buffer_t ggml_backend_amx_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    void * data = aligned_alloc(TENSOR_ALIGNMENT, size);
+    if (data == NULL) {
+        fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
+        return NULL;
+    }
+
+    return ggml_backend_buffer_init(buft, ggml_backend_amx_buffer_interface, data, size);
+}
+
+static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return TENSOR_ALIGNMENT;
+
+    GGML_UNUSED(buft);
+}
+
+static size_t ggml_backend_amx_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor* tensor) {
+    return ggml_backend_amx_get_alloc_size(tensor);
+
+    GGML_UNUSED(buft);
+}
+
+static bool ggml_backend_amx_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return false;
+
+    GGML_UNUSED(buft);
+}
+
+ggml_backend_buffer_type_t ggml_backend_amx_buffer_type() {
+    static struct ggml_backend_buffer_type ggml_backend_buffer_type_amx = {
+        /* .iface = */ {
+        /* .get_name         = */ ggml_backend_amx_buffer_type_get_name,
+        /* .alloc_buffer     = */ ggml_backend_amx_buffer_type_alloc_buffer,
+        /* .get_alignment    = */ ggml_backend_amx_buffer_type_get_alignment,
+        /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+        /* .get_alloc_size   = */ ggml_backend_amx_buffer_type_get_alloc_size,
+        /* .is_host          = */ ggml_backend_amx_buffer_type_is_host,
+        },
+        /* .device  = */ NULL,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_buffer_type_amx;
+}
+
+// backend interface
+
+static const char * ggml_backend_amx_name(ggml_backend_t backend) {
+    return "AMX";
+
+    GGML_UNUSED(backend);
+}
+
+static void ggml_backend_amx_free(ggml_backend_t backend) {
+    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
+    delete ctx;
+    delete backend;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_amx_get_default_buffer_type(ggml_backend_t backend) {
+    return ggml_backend_amx_buffer_type();
+
+    GGML_UNUSED(backend);
+}
+
+static enum ggml_status ggml_backend_amx_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
+    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend->context;
+
+    for (int i = 0; i < cgraph->n_nodes; i++) {
+        struct ggml_tensor * node = cgraph->nodes[i];
+
+        switch (node->op) {
+        case GGML_OP_MUL_MAT:
+            ggml_backend_amx_mul_mat(ctx, node);
+            break;
+
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            break;
+
+        default:
+            fprintf(stderr, "%s: unsupported op %s\n", __func__, ggml_op_desc(node));
+            GGML_ASSERT(false);
+        }
+    }
+
+    return GGML_STATUS_SUCCESS;
+
+    GGML_UNUSED(backend);
+}
+
+static struct ggml_backend_i ggml_backend_amx_i = {
+    /* .get_name                = */ ggml_backend_amx_name,
+    /* .free                    = */ ggml_backend_amx_free,
+    /* .get_default_buffer_type = */ ggml_backend_amx_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,
+    /* .get_tensor_async        = */ NULL,
+    /* .cpy_tensor_async        = */ NULL,
+    /* .synchronize             = */ NULL,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_amx_graph_compute,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
+    /* .offload_op              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_amx_guid() {
+    static ggml_guid guid = { 0x13, 0xb8, 0xa4, 0xc4, 0xba, 0xfe, 0x51, 0x67, 0x87, 0x44, 0x55, 0x15, 0xb2, 0x35, 0x62, 0x3e };
+    return &guid;
+}
+
+#define ARCH_GET_XCOMP_PERM     0x1022
+#define ARCH_REQ_XCOMP_PERM     0x1023
+#define XFEATURE_XTILECFG       17
+#define XFEATURE_XTILEDATA      18
+
+static bool ggml_amx_init() {
+#if defined(__gnu_linux__)
+    if (syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA)) {
+        fprintf(stderr, "AMX is not ready to be used!\n");
+        return false;
+    }
+    return true;
+#elif defined(_WIN32)
+    return true;
+#endif
+}
+
+ggml_backend_t ggml_backend_amx_init() {
+
+    // invoke a Linux system call to request access to AMX features
+    ggml_amx_init();
+
+    // backend context
+    ggml_backend_amx_context * ctx = new ggml_backend_amx_context;
+
+    // ggml amx backend
+    ggml_backend_t backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_amx_guid(),
+        /* .interface = */ ggml_backend_amx_i,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_amx_reg(), 0),
+        /* .context   = */ ctx,
+    };
+
+    return backend;
+}
+
+bool ggml_backend_is_amx(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_amx_guid());
+}
+
+void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
+    GGML_ASSERT(ggml_backend_is_amx(backend_amx));
+
+    ggml_backend_amx_context * ctx = (ggml_backend_amx_context *)backend_amx->context;
+    ctx->n_threads = n_threads;
+}
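+
+// usage sketch (illustrative): typical host-side setup before running a graph:
+//
+//   ggml_backend_t backend = ggml_backend_amx_init();
+//   if (ggml_backend_is_amx(backend)) {
+//       ggml_backend_amx_set_n_threads(backend, 8);
+//   }
+//   // ... allocate tensors from ggml_backend_amx_buffer_type() and
+//   // compute graphs through the usual ggml-backend API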
+
+// device interface
+
+static const char * ggml_backend_amx_device_get_name(ggml_backend_dev_t dev) {
+    return "AMX";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_amx_device_get_description(ggml_backend_dev_t dev) {
+    return "Intel Advanced Matrix Extensions";
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_amx_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_amx_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_CPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_amx_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_amx_device_get_name(dev);
+    props->description = ggml_backend_amx_device_get_description(dev);
+    props->type        = ggml_backend_amx_device_get_type(dev);
+    ggml_backend_amx_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+    // `buffer_from_host_ptr` is intended to be used with mmap, when the memory layout is unchanged
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_amx_device_init(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_amx_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_amx_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_amx_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+
+    // handle only 2d gemm for now
+    auto is_contiguous_2d = [](const struct ggml_tensor * t) {
+        return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1;
+    };
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+
+        case GGML_OP_MUL_MAT: {
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
+
+            const enum ggml_type type = src0->type;
+            const int64_t ne0 = op->ne[0];
+
+            bool is_training = src0->grad || src1->grad;
+
+            // amx kernels enabled for Q4_0, Q4_1, Q8_0, F16
+            // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256
+            bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16);
+
+            bool can_use_amx =
+                is_contiguous_2d(src0) &&       // src0 must be contiguous
+                is_contiguous_2d(src1) &&       // src1 must be contiguous
+                !is_training &&                 // inference only
+                src1->type == GGML_TYPE_F32 &&  // src1 must be float32
+                has_amx_kernels &&              // with amx kernel impls
+                ne0 % (TILE_N * 2) == 0;        // out_features is a multiple of 32
+
+            return can_use_amx;
+        }
+        default:
+            return false;
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_amx_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_amx_buffer_type_get_name;
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_amx_device_i = {
+    /* .get_name             = */ ggml_backend_amx_device_get_name,
+    /* .get_description      = */ ggml_backend_amx_device_get_description,
+    /* .get_memory           = */ ggml_backend_amx_device_get_memory,
+    /* .get_type             = */ ggml_backend_amx_device_get_type,
+    /* .get_props            = */ ggml_backend_amx_device_get_props,
+    /* .init_backend         = */ ggml_backend_amx_device_init,
+    /* .get_buffer_type      = */ ggml_backend_amx_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ NULL,
+    /* .supports_op          = */ ggml_backend_amx_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_amx_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_amx_reg_get_name(ggml_backend_reg_t reg) {
+    return "AMX";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_amx_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_amx_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_device ggml_backend_amx_device = {
+        /* .iface   = */ ggml_backend_amx_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ nullptr,
+    };
+
+    return &ggml_backend_amx_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static void * ggml_backend_amx_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_amx_set_n_threads;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_amx_reg_i = {
+    /* .get_name         = */ ggml_backend_amx_reg_get_name,
+    /* .get_device_count = */ ggml_backend_amx_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_amx_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_amx_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_amx_reg(void) {
+    static struct ggml_backend_reg ggml_backend_amx_reg = {
+        /* .iface   = */ ggml_backend_amx_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_amx_reg;
+}
+
+#else // if defined(__AMX_INT8__)
+
+ggml_backend_t ggml_backend_amx_init(void) {
+    fprintf(stderr, "GGML is not compiled with AMX support!\n");
+    return ggml_backend_t{};
+}
+
+void ggml_backend_amx_set_n_threads(ggml_backend_t backend_amx, int n_threads) {
+    fprintf(stderr, "GGML is not compiled with AMX support!\n");
+
+    GGML_UNUSED(backend_amx);
+    GGML_UNUSED(n_threads);
+}
+
+#endif
diff --git a/ggml/src/ggml-amx/common.h b/ggml/src/ggml-amx/common.h
new file mode 100644
index 00000000000..2b6c6352704
--- /dev/null
+++ b/ggml/src/ggml-amx/common.h
@@ -0,0 +1,93 @@
+#pragma once
+
+#include "ggml.h"
+#include "ggml-cpu-impl.h" // 
+
+#include <algorithm>
+#include <memory>
+#include <type_traits>
+
+#if defined(_OPENMP)
+#include <omp.h>
+#endif
+
+#define TILE_M 16
+#define TILE_N 16
+#define TILE_K 32
+#define VNNI_BLK 4
+
+#define AMX_BLK_SIZE 32
+
+#define TMM0 0
+#define TMM1 1
+#define TMM2 2
+#define TMM3 3
+#define TMM4 4
+#define TMM5 5
+#define TMM6 6
+#define TMM7 7
+
+// parallel routines
+template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
+inline T div_up(T x, T y) { return (x + y - 1) / y; }
+
+template <typename T>
+inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
+#if 0
+    // onednn partition pattern
+    T& n_my = n_end;
+    if (nth <= 1 || n == 0) {
+        n_start = 0;
+        n_my = n;
+    } else {
+        T n1 = div_up(n, nth);
+        T n2 = n1 - 1;
+        T T1 = n - n2 * nth;
+        n_my = ith < T1 ? n1 : n2;
+        n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
+    }
+    n_end += n_start;
+#else
+    // pytorch aten partition pattern
+    T n_my = div_up(n, nth);
+    n_start = ith * n_my;
+    n_end = std::min(n_start + n_my, n);
+#endif
+}
+
+template <typename func_t>
+inline void parallel_for(int nth, int n, const func_t& f) {
+#if defined(_OPENMP)
+#pragma omp parallel num_threads(nth)
+{
+    //int nth = omp_get_num_threads();
+    int ith = omp_get_thread_num();
+    int tbegin, tend;
+    balance211(n, nth, ith, tbegin, tend);
+    f(tbegin, tend);
+}
+#else
+    f(0, n);
+
+    GGML_UNUSED(nth);
+#endif
+}
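+
+// usage sketch (illustrative): hand each of 8 threads one contiguous
+// [begin, end) chunk of 1000 work items, as computed by balance211:
+//
+//   parallel_for(8, 1000, [&](int begin, int end) {
+//       for (int i = begin; i < end; ++i) {
+//           // process item i
+//       }
+//   });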
+
+// quantized types that have AMX support
+inline bool qtype_has_amx_kernels(const enum ggml_type type) {
+    // TODO: fix padding for vnni format
+    return (type == GGML_TYPE_Q4_0) ||
+        (type == GGML_TYPE_Q4_1);
+        //(type == GGML_TYPE_Q8_0) ||
+        //(type == GGML_TYPE_Q4_K) ||
+        //(type == GGML_TYPE_Q5_K) ||
+        //(type == GGML_TYPE_Q6_K) ||
+        //(type == GGML_TYPE_IQ4_XS);
+}
+
+// ggml backend context
+struct ggml_backend_amx_context {
+    int n_threads = GGML_DEFAULT_N_THREADS;
+    std::unique_ptr<char[]> work_data;
+    size_t work_size = 0;
+};
diff --git a/ggml/src/ggml-amx/mmq.cpp b/ggml/src/ggml-amx/mmq.cpp
new file mode 100644
index 00000000000..239d15121a6
--- /dev/null
+++ b/ggml/src/ggml-amx/mmq.cpp
@@ -0,0 +1,2509 @@
+
+#if defined(__GNUC__)
+#pragma GCC diagnostic ignored "-Wpedantic"
+#pragma GCC diagnostic ignored "-Wunused-local-typedefs"
+#endif
+
+#include "mmq.h"
+#include "ggml-impl.h"
+#include "ggml-quants.h"
+#include <algorithm>
+#include <type_traits>
+
+#if defined(__gnu_linux__)
+#include <sys/syscall.h>
+#include <unistd.h>
+#endif
+
+#if defined(_OPENMP)
+#include <omp.h>
+#endif
+
+#if (defined(_WIN32) || defined(_WIN64))
+#define RESTRICT __restrict
+#else
+#define RESTRICT __restrict__
+#endif
+
+#if (defined(_WIN32) || defined(_WIN64))
+#define ALWAYS_INLINE __forceinline
+#elif __has_attribute(always_inline) || defined(__GNUC__)
+#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
+#else
+#define ALWAYS_INLINE inline
+#endif
+
+#if defined(__AMX_INT8__)
+
+namespace {
+
+// Forced unrolling
+template <int n>
+struct Unroll {
+    template <typename Func, typename... Args>
+    ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
+        Unroll<n - 1>{}(f, args...);
+        f(std::integral_constant<int, n - 1>{}, args...);
+    }
+};
+
+template <>
+struct Unroll<1> {
+    template <typename Func, typename... Args>
+    ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
+        f(std::integral_constant<int, 0>{}, args...);
+    }
+};
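+
+// usage sketch (illustrative): Unroll<N>{}(f, args...) expands to
+// f(0, args...), f(1, args...), ..., f(N-1, args...), passing the index as a
+// compile-time std::integral_constant:
+//
+//   Unroll<4>{}([](auto idx, int base) {
+//       constexpr int i = decltype(idx)::value; // 0, 1, 2, 3 at compile time
+//       (void)(base + i);
+//   }, 100);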
+
+// type traits
+template <typename T> struct PackedTypes {};
+template <> struct PackedTypes<block_q4_0> { using type = int8_t; };
+template <> struct PackedTypes<block_q4_1> { using type = uint8_t; };
+template <> struct PackedTypes<block_q8_0> { using type = int8_t; };
+template <typename TB> using packed_B_type = typename PackedTypes<TB>::type;
+
+template <typename TB>
+struct do_compensate : std::integral_constant<bool,
+    std::is_same<TB, block_q8_0>::value> {};
+
+template <typename TB>
+struct do_unpack : std::integral_constant<bool,
+    std::is_same<TB, block_q4_0>::value ||
+    std::is_same<TB, block_q4_1>::value> {};
+
+template <typename TB>
+struct is_type_qkk : std::integral_constant<bool,
+    std::is_same<TB, block_q4_K>::value ||
+    std::is_same<TB, block_q5_K>::value ||
+    std::is_same<TB, block_q6_K>::value ||
+    std::is_same<TB, block_iq4_xs>::value> {};
+
+#define GGML_DISPATCH_FLOATING_TYPES(TYPE, ...)                                        \
+    [&] {                                                                              \
+        switch (TYPE) {                                                                \
+            case GGML_TYPE_F16: {                                                      \
+                using type = ggml_fp16_t;                                              \
+                constexpr int blck_size = 16;                                          \
+                return __VA_ARGS__();                                                  \
+            }                                                                          \
+            case GGML_TYPE_BF16: {                                                     \
+                using type = ggml_bf16_t;                                              \
+                constexpr int blck_size = 32;                                          \
+                return __VA_ARGS__();                                                  \
+            }                                                                          \
+            default:                                                                   \
+                fprintf(stderr, "Unsupported floating data type\n");                   \
+        }                                                                              \
+    }()
+
+#define GGML_DISPATCH_QTYPES(QT, ...)                                                  \
+    [&] {                                                                              \
+        switch (QT) {                                                                  \
+            case GGML_TYPE_Q4_0: {                                                     \
+                using type = block_q4_0;                                               \
+                using vec_dot_type = block_q8_0;                                       \
+                constexpr int blck_size = QK4_0;                                       \
+                return __VA_ARGS__();                                                  \
+            }                                                                          \
+            case GGML_TYPE_Q4_1: {                                                     \
+                using type = block_q4_1;                                               \
+                using vec_dot_type = block_q8_1;                                       \
+                constexpr int blck_size = QK4_1;                                       \
+                return __VA_ARGS__();                                                  \
+            }                                                                          \
+            case GGML_TYPE_Q8_0: {                                                     \
+                using type = block_q8_0;                                               \
+                using vec_dot_type = block_q8_0;                                       \
+                constexpr int blck_size = QK8_0;                                       \
+                return __VA_ARGS__();                                                  \
+            }                                                                          \
+            case GGML_TYPE_Q4_K: {                                                     \
+                using type = block_q4_K;                                               \
+                using vec_dot_type = block_q8_K;                                       \
+                constexpr int blck_size = QK_K;                                        \
+                return __VA_ARGS__();                                                  \
+            }                                                                          \
+            case GGML_TYPE_Q5_K: {                                                     \
+                using type = block_q5_K;                                               \
+                using vec_dot_type = block_q8_K;                                       \
+                constexpr int blck_size = QK_K;                                        \
+                return __VA_ARGS__();                                                  \
+            }                                                                          \
+            case GGML_TYPE_Q6_K: {                                                     \
+                using type = block_q6_K;                                               \
+                using vec_dot_type = block_q8_K;                                       \
+                constexpr int blck_size = QK_K;                                        \
+                return __VA_ARGS__();                                                  \
+            }                                                                          \
+            case GGML_TYPE_IQ4_XS: {                                                   \
+                using type = block_iq4_xs;                                             \
+                using vec_dot_type = block_q8_K;                                       \
+                constexpr int blck_size = QK_K;                                        \
+                return __VA_ARGS__();                                                  \
+            }                                                                          \
+            default:                                                                   \
+                fprintf(stderr, "Unsupported quantized data type: %d\n", int(QT));     \
+        }                                                                              \
+    }()
+
+#define GGML_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...)                                     \
+    [&] {                                                                              \
+        if (BOOL_V) {                                                                  \
+            constexpr bool BOOL_NAME = true;                                           \
+            return __VA_ARGS__();                                                      \
+        } else {                                                                       \
+            constexpr bool BOOL_NAME = false;                                          \
+            return __VA_ARGS__();                                                      \
+        }                                                                              \
+    }()
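+
+// usage sketch (illustrative): the dispatch macros map a runtime type tag onto
+// compile-time aliases visible inside the callback:
+//
+//   GGML_DISPATCH_QTYPES(src0->type, [&] {
+//       // here `type` is e.g. block_q4_0, `vec_dot_type` is block_q8_0,
+//       // and `blck_size` is the matching block size (QK4_0)
+//   });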
+
+// define amx tile config data structure
+struct tile_config_t {
+    uint8_t palette_id = 0;
+    uint8_t start_row = 0;
+    uint8_t reserved_0[14] = {0};
+    uint16_t colsb[16] = {0};
+    uint8_t rows[16] = {0};
+};
+
+// Notes: amx tile config
+//
+// Typically, TMUL multiplies A and B tiles of size 16 x 64 containing INT8 values
+// and accumulates the result into a 16 x 16 matrix C containing INT32 values.
+//
+// Since many GGUF quantized types have a `block_size` of 32, a 16-16-32 config is used
+// instead of the usual 16-16-64 config.
+//
+//    Block A: {16, 32}, dtype = int8_t
+//    Block B: {16, 32}, dtype = uint8_t/int8_t
+//    Block C: {16, 16}, dtype = int32_t
+//
+// Block B needs to be prepacked to vnni format before feeding into TMUL:
+//    packed_B: from {n, k} to {k/vnni_blk, n, vnni_blk}; viewed in 2d, we get {8, 64}
+//
+// Therefore, we get the tile config below (colsb is the tile row width in bytes):
+//             A    B    C
+//    rows    16    8   16
+//    colsb   32   64   64
+//
+// For tile distribution, a 2-2-4 pattern is used: A takes TMM2-TMM3, B takes TMM0-TMM1,
+// and C takes TMM4-TMM7:
+//            B TMM0  B TMM1
+//    A TMM2  C TMM4  C TMM6
+//    A TMM3  C TMM5  C TMM7
+//
+// Each `amx` kernel handles 4 blocks at a time: 2MB * 2NB; when m < 2 * BLOCK_M,
+// A has to be unpacked.
+//
+// Another commonly used pattern, 1-3-3, is skipped here, as it is mostly useful
+// when m <= 16, and the single-batch gemm (m = 1) has a special fast path with
+// `avx512-vnni`.
+//
+// ref: https://www.intel.com/content/www/us/en/developer/articles/code-sample/
+//    advanced-matrix-extensions-intrinsics-functions.html
+//
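+// worked size check (illustrative) for the config above: one TMUL step computes
+//    C[16][16] (int32) += A[16][32] (int8) x B[32][16] (uint8/int8)
+// so tile A holds 16 rows x 32 bytes, the vnni-packed tile B holds 8 rows x 64
+// bytes (the same 32 x 16 int8 values), and tile C holds 16 rows x 64 bytes
+// (16 int32 values per row), matching TC_CONFIG_TILE below.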
+
+#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb
+void ggml_tile_config_init(void) {
+    static thread_local bool is_first_time = true;
+
+    if (!is_first_time) {
+        return;
+    }
+
+    static thread_local tile_config_t tc;
+    tile_config_t current_tc;
+    _tile_storeconfig(&current_tc);
+
+    // load only when config changes
+    if (tc.palette_id == 0 || (memcmp(&current_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 &&
+                               memcmp(&current_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) {
+        tc.palette_id = 1;
+        tc.start_row = 0;
+        TC_CONFIG_TILE(TMM0, 8, 64);
+        TC_CONFIG_TILE(TMM1, 8, 64);
+        TC_CONFIG_TILE(TMM2, 16, 32);
+        TC_CONFIG_TILE(TMM3, 16, 32);
+        TC_CONFIG_TILE(TMM4, 16, 64);
+        TC_CONFIG_TILE(TMM5, 16, 64);
+        TC_CONFIG_TILE(TMM6, 16, 64);
+        TC_CONFIG_TILE(TMM7, 16, 64);
+        _tile_loadconfig(&tc);
+    }
+
+    is_first_time = false;
+}
+
+// we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation.
+// See the notes `s8s8 igemm compensation in avx512-vnni` for detail.
+template <typename TB>
+int get_tile_size() {
+    int tile_size = TILE_N * sizeof(TB);
+    if (do_compensate<TB>::value) {
+        tile_size += TILE_N * sizeof(int32_t);
+    }
+    if (std::is_same<TB, block_q4_K>::value ||
+        std::is_same<TB, block_q5_K>::value) {
+        tile_size += TILE_N * 4;
+    }
+    if (std::is_same<TB, block_iq4_xs>::value) {
+        tile_size += TILE_N * 2;
+    }
+    return tile_size;
+}
+
+template <typename TB, int BLOCK_K>
+int get_row_size(int K) {
+    int KB = K / BLOCK_K;
+    int row_size = KB * sizeof(TB);
+    if (do_compensate<TB>::value) {
+        row_size += KB * sizeof(int32_t);
+    }
+    if (std::is_same<TB, block_q4_K>::value ||
+        std::is_same<TB, block_q5_K>::value) {
+        row_size += KB * 4;
+    }
+    if (std::is_same<TB, block_iq4_xs>::value) {
+        row_size += KB * 2;
+    }
+    return row_size;
+}
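+
+// worked example (illustrative): for TB = block_q8_0 (34 bytes per 32-value block)
+// and K = 4096: KB = 4096 / 32 = 128, so get_row_size<block_q8_0, QK8_0>(4096)
+// = 128 * 34 (quants) + 128 * 4 (s8s8 compensation) = 4864 bytes per packed row.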
+
+// vectorized dtype conversion
+inline float FP16_TO_FP32(ggml_half val) {
+    __m256i v = _mm256_setr_epi16(
+        val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0);
+    __m512 o = _mm512_cvtph_ps(v);
+    return _mm512_cvtss_f32(o);
+}
+
+inline __m512 FP16_TO_FP32_VEC(ggml_half val) {
+    __m256i v = _mm256_set1_epi16(val);
+    return _mm512_cvtph_ps(v);
+}
+
+// horizontal reduce
+inline float _mm512_reduce_max_ps(const __m512 x) {
+    __m512 v = x;
+    __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E);
+    v = _mm512_max_ps(v, v1);
+    v1 = _mm512_shuffle_f32x4(v, v, 0xB1);
+    v = _mm512_max_ps(v, v1);
+    v1 = _mm512_shuffle_ps(v, v, 0x4E);
+    v = _mm512_max_ps(v, v1);
+    v1 = _mm512_shuffle_ps(v, v, 0xB1);
+    v = _mm512_max_ps(v, v1);
+    return _mm512_cvtss_f32(v);
+}
+
+// transpose utils
+#define SHUFFLE_EPI32(a, b, mask) \
+    _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask))
+inline void transpose_8x8_32bit(__m256i * v, __m256i * v1) {
+    // unpacking and 32-bit elements
+    v1[0] = _mm256_unpacklo_epi32(v[0], v[1]);
+    v1[1] = _mm256_unpackhi_epi32(v[0], v[1]);
+    v1[2] = _mm256_unpacklo_epi32(v[2], v[3]);
+    v1[3] = _mm256_unpackhi_epi32(v[2], v[3]);
+    v1[4] = _mm256_unpacklo_epi32(v[4], v[5]);
+    v1[5] = _mm256_unpackhi_epi32(v[4], v[5]);
+    v1[6] = _mm256_unpacklo_epi32(v[6], v[7]);
+    v1[7] = _mm256_unpackhi_epi32(v[6], v[7]);
+
+    // shuffling the 32-bit elements
+    v[0] = SHUFFLE_EPI32(v1[0], v1[2], 0x44);
+    v[1] = SHUFFLE_EPI32(v1[0], v1[2], 0xee);
+    v[2] = SHUFFLE_EPI32(v1[4], v1[6], 0x44);
+    v[3] = SHUFFLE_EPI32(v1[4], v1[6], 0xee);
+    v[4] = SHUFFLE_EPI32(v1[1], v1[3], 0x44);
+    v[5] = SHUFFLE_EPI32(v1[1], v1[3], 0xee);
+    v[6] = SHUFFLE_EPI32(v1[5], v1[7], 0x44);
+    v[7] = SHUFFLE_EPI32(v1[5], v1[7], 0xee);
+
+    // shuffling 128-bit elements
+    v1[0] = _mm256_permute2f128_si256(v[2], v[0], 0x02);
+    v1[1] = _mm256_permute2f128_si256(v[3], v[1], 0x02);
+    v1[2] = _mm256_permute2f128_si256(v[6], v[4], 0x02);
+    v1[3] = _mm256_permute2f128_si256(v[7], v[5], 0x02);
+    v1[4] = _mm256_permute2f128_si256(v[2], v[0], 0x13);
+    v1[5] = _mm256_permute2f128_si256(v[3], v[1], 0x13);
+    v1[6] = _mm256_permute2f128_si256(v[6], v[4], 0x13);
+    v1[7] = _mm256_permute2f128_si256(v[7], v[5], 0x13);
+}
+
+inline void transpose_16x4_32bit(__m512i * r, __m512i * d) {
+
+    static const __m512i index1 = _mm512_set_epi32(
+        0x0f, 0x0b, 0x07, 0x03,
+        0x0e, 0x0a, 0x06, 0x02,
+        0x0d, 0x09, 0x05, 0x01,
+        0x0c, 0x08, 0x04, 0x00);
+
+    d[0] = _mm512_permutexvar_epi32(index1, r[0]);
+    d[1] = _mm512_permutexvar_epi32(index1, r[1]);
+    d[2] = _mm512_permutexvar_epi32(index1, r[2]);
+    d[3] = _mm512_permutexvar_epi32(index1, r[3]);
+
+    r[0] = _mm512_shuffle_i32x4(d[0], d[1], 0x44);
+    r[1] = _mm512_shuffle_i32x4(d[0], d[1], 0xee);
+    r[2] = _mm512_shuffle_i32x4(d[2], d[3], 0x44);
+    r[3] = _mm512_shuffle_i32x4(d[2], d[3], 0xee);
+
+    d[0] = _mm512_shuffle_i32x4(r[0], r[2], 0x88);
+    d[1] = _mm512_shuffle_i32x4(r[0], r[2], 0xdd);
+    d[2] = _mm512_shuffle_i32x4(r[1], r[3], 0x88);
+    d[3] = _mm512_shuffle_i32x4(r[1], r[3], 0xdd);
+}
+
+inline void transpose_16x16_32bit(__m512i * v) {
+    __m512i v1[16];
+    v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
+    v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
+    v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
+    v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
+    v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
+    v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
+    v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
+    v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
+    v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
+    v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
+    v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
+    v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
+    v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
+    v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
+    v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
+    v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
+
+    v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
+    v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
+    v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
+    v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
+    v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
+    v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
+    v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
+    v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
+    v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
+    v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
+    v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
+    v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
+    v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
+    v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
+    v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
+    v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
+
+    v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
+    v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
+    v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
+    v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
+    v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
+    v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
+    v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
+    v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
+    v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
+    v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
+    v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
+    v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
+    v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
+    v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
+    v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
+    v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
+
+    v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
+    v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
+    v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
+    v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
+    v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
+    v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
+    v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
+    v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
+    v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
+    v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
+    v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
+    v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
+    v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
+    v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
+    v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
+    v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
+}
+
+void quantize_row_q8_K_vnni(const float * RESTRICT x, void * RESTRICT vy, int64_t k) {
+    assert(k % QK_K == 0);
+    const int KB = k / QK_K;
+    constexpr int kVecs = QK_K / 16;
+
+    block_q8_K * y = reinterpret_cast<block_q8_K *>(vy);
+
+    // hold 16 float vecs from x
+    __m512  v[kVecs];
+
+    // hold the quants vecs
+    __m512i vq[kVecs / 4];
+
+    // hold the packed quants vecs
+    __m512i vq_packed[kVecs / 4];
+
+    const __m512 signBit = _mm512_set1_ps(-0.f);
+
+    for (int i = 0; i < KB; ++i) {
+        // Compute max(abs(e)) for the block
+        __m512 vamax = _mm512_set1_ps(0.f);
+        for (int j = 0; j < kVecs; ++j) {
+            v[j] = _mm512_loadu_ps(x); x += 16;
+            vamax = _mm512_max_ps(vamax, _mm512_andnot_ps(signBit, v[j]));
+        }
+        const float amax = _mm512_reduce_max_ps(vamax);
+
+        // Quantize these floats
+        const float iscale = 127.f / amax;
+        y[i].d = GGML_FP32_TO_FP16(1 / iscale);
+        const float id = ( amax != 0.0f ) ? iscale : 0.f;
+        const __m512 vscale = _mm512_set1_ps(id);
+
+        // Apply multiplier and round to nearest integer
+        for (int j = 0; j < kVecs; ++j) {
+            v[j] = _mm512_mul_ps(v[j], vscale);
+            v[j] = _mm512_roundscale_ps(v[j], (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
+        }
+
+        // Pack to epi8 vecs
+        for (int j = 0; j < kVecs / 4; ++j) {
+            __m128i q8_0 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 0]));
+            __m128i q8_1 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 1]));
+            __m128i q8_2 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 2]));
+            __m128i q8_3 = _mm512_cvtepi32_epi8(_mm512_cvtps_epi32(v[j * 4 + 3]));
+
+            __m256i q8_01 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_0), (q8_1), 1);
+            __m256i q8_23 = _mm256_insertf128_si256(_mm256_castsi128_si256(q8_2), (q8_3), 1);
+
+            vq[j] = _mm512_inserti32x8(_mm512_castsi256_si512(q8_01), q8_23, 1);
+            _mm512_storeu_si512((__m512i *)(y[i].qs + j * 64), vq[j]);
+        }
+
+        // Compute the bsums with vnni
+        transpose_16x4_32bit(vq, vq_packed);
+
+        const __m512i one = _mm512_set1_epi8(1);
+        __m512i sum = _mm512_setzero_si512();
+        for (int k = 0; k < 4; ++k) {
+            sum = _mm512_dpbusd_epi32(sum, one, vq_packed[k]);
+        }
+        _mm256_storeu_si256((__m256i *)(y[i].bsums), _mm512_cvtepi32_epi16(sum));
+    }
+}
+
+// quantize A from float to `vec_dot_type`
+template <typename T>
+inline void from_float(const float * x, char * vy, int64_t k);
+
+template <>
+inline void from_float<block_q8_0>(const float * x, char * vy, int64_t k) {
+    quantize_row_q8_0(x, vy, k);
+}
+
+template <>
+inline void from_float<block_q8_1>(const float * x, char * vy, int64_t k) {
+    quantize_row_q8_1(x, vy, k);
+}
+
+template <>
+inline void from_float<block_q8_K>(const float * x, char * vy, int64_t k) {
+#if 1
+    // TODO: this is reference impl!
+    quantize_row_q8_K(x, vy, k);
+#else
+    quantize_row_q8_K_vnni(x, vy, k);
+#endif
+}
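+
+// usage sketch (illustrative): each row of A is quantized from f32 into the
+// `vec_dot_type` matching the weight type (see GGML_DISPATCH_QTYPES above):
+//
+//   from_float<block_q8_0>(a_row_f32, a_row_packed, K); // Q4_0 / Q8_0 weights
+//   from_float<block_q8_1>(a_row_f32, a_row_packed, K); // Q4_1 weights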
+
+// load A from memory to an array when nrows cannot fill a whole tile
+void unpack_A(int8_t * RESTRICT tile, const block_q8_0 * RESTRICT A, int lda, int nr) {
+    assert(nr != TILE_M);
+    for (int m = 0; m < nr; ++m) {
+        const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));
+        _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
+    }
+}
+
+void unpack_A(int8_t * RESTRICT tile, const block_q8_1 * RESTRICT A, int lda, int nr) {
+    assert(nr != TILE_M);
+    for (int m = 0; m < nr; ++m) {
+        const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs));
+        _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
+    }
+}
+
+template <typename TB>
+void unpack_A(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {
+    assert(nr <= TILE_M);
+    for (int m = 0; m < nr; ++m) {
+        const __m256i v = _mm256_loadu_si256((const __m256i *)(A[m * lda].qs + k * 32));
+        _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), v);
+    }
+}
+
+template <>
+void unpack_A<block_q6_K>(int8_t * RESTRICT tile, const block_q8_K * RESTRICT A, int lda, int k, int nr) {
+    assert(nr <= TILE_M);
+    // zero padding k from 16 to 32, so that we don't have to re-config amx
+    const __m128i zero = _mm_setzero_si128();
+    for (int m = 0; m < nr; ++m) {
+        const __m128i v = _mm_loadu_si128((const __m128i *)(A[m * lda].qs + k * 16));
+        const __m256i r = _mm256_insertf128_si256(_mm256_castsi128_si256(v), zero, 1);
+        _mm256_storeu_si256((__m256i *)(tile + m * TILE_K), r);
+    }
+}
+
+#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)
+inline __m256i bytes_from_nibbles_32(const uint8_t * rsi) {
+    const __m128i tmp = _mm_loadu_si128((const __m128i *)rsi);
+    const __m256i bytes = MM256_SET_M128I(_mm_srli_epi16(tmp, 4), tmp);
+    const __m256i lowMask = _mm256_set1_epi8(0xF);
+    return _mm256_and_si256(lowMask, bytes);
+}
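+
+// example (illustrative): for an input byte 0xAB, the low nibble 0x0B lands in
+// the lower 128-bit lane and the high nibble 0x0A in the upper lane, so 16
+// packed bytes expand to 32 separate values in [0, 15].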
+
+// used for block_q4_K
+inline __m512i bytes_from_nibbles_64(const uint8_t * rsi) {
+    const __m256i tmp = _mm256_loadu_si256((const __m256i *)rsi);
+    const __m256i lowMask = _mm256_set1_epi8(0xF);
+    const __m256i q4l = _mm256_and_si256(tmp, lowMask);
+    const __m256i q4h = _mm256_and_si256(_mm256_srli_epi16(tmp, 4), lowMask);
+    return _mm512_inserti32x8(_mm512_castsi256_si512(q4l), q4h, 1);
+}
+
+// used for block_q5_K
+inline __m512i bytes_from_nibbles_64(const uint8_t * qs, const uint8_t * qh, int k) {
+    const __m256i lowMask = _mm256_set1_epi8(0xF);
+    __m256i hmask = _mm256_set1_epi8(1);
+    hmask = _mm256_slli_epi16(hmask, k);
+
+    const __m256i q5bits = _mm256_loadu_si256((const __m256i *)qs);
+    const __m256i hbits = _mm256_loadu_si256((const __m256i *)qh);
+
+    const __m256i q5l_0 = _mm256_and_si256(q5bits, lowMask);
+    const __m256i q5h_0 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 0), 4);
+    const __m256i q5_0  = _mm256_add_epi8(q5l_0, q5h_0);
+    hmask = _mm256_slli_epi16(hmask, 1);
+
+    const __m256i q5l_1 = _mm256_and_si256(_mm256_srli_epi16(q5bits, 4), lowMask);
+    const __m256i q5h_1 = _mm256_slli_epi16(_mm256_srli_epi16(_mm256_and_si256(hbits, hmask), k + 1), 4);
+    const __m256i q5_1  = _mm256_add_epi8(q5l_1, q5h_1);
+
+    return _mm512_inserti32x8(_mm512_castsi256_si512(q5_0), q5_1, 1);
+}
+
+// used for block_q6_K
+inline void bytes_from_nibbles_128(__m512i& r0, __m512i& r1, const uint8_t * qs, const uint8_t * qh) {
+    const __m256i m4 = _mm256_set1_epi8(0xF);
+    const __m256i m2 = _mm256_set1_epi8(0x3);
+
+    const __m256i q6bits1 = _mm256_loadu_si256((const __m256i *)qs);
+    const __m256i q6bits2 = _mm256_loadu_si256((const __m256i *)(qs + 32));
+    const __m256i q6bitsH = _mm256_loadu_si256((const __m256i *)qh);
+
+    const __m256i q6h_0 = _mm256_slli_epi16(_mm256_and_si256(                  q6bitsH,     m2), 4);
+    const __m256i q6h_1 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 2), m2), 4);
+    const __m256i q6h_2 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 4), m2), 4);
+    const __m256i q6h_3 = _mm256_slli_epi16(_mm256_and_si256(_mm256_srli_epi16(q6bitsH, 6), m2), 4);
+
+    const __m256i q6_0 = _mm256_or_si256(_mm256_and_si256(q6bits1, m4), q6h_0);
+    const __m256i q6_1 = _mm256_or_si256(_mm256_and_si256(q6bits2, m4), q6h_1);
+    const __m256i q6_2 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits1, 4), m4), q6h_2);
+    const __m256i q6_3 = _mm256_or_si256(_mm256_and_si256(_mm256_srli_epi16(q6bits2, 4), m4), q6h_3);
+
+    r0 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_0), q6_1, 1);
+    r1 = _mm512_inserti32x8(_mm512_castsi256_si512(q6_2), q6_3, 1);
+}
+
+inline __m512i packNibbles(__m512i r0, __m512i r1) {
+    return _mm512_or_si512(r0, _mm512_slli_epi16(r1, 4));
+}
+
+template <typename TB>
+inline void pack_qs(void * RESTRICT packed_B, const TB * RESTRICT B, int KB) {
+    int8_t tmp[8 * 64];
+    __m256i v[8], v2[8];
+    for (int n = 0; n < 8; ++n) {
+        v[n] = bytes_from_nibbles_32(B[n * KB].qs);
+    }
+    transpose_8x8_32bit(v, v2);
+    for (int n = 0; n < 8; ++n) {
+        _mm256_storeu_si256((__m256i *)(tmp + n * 64), v2[n]);
+    }
+    for (int n = 0; n < 8; ++n) {
+        v[n] = bytes_from_nibbles_32(B[(n + 8) * KB].qs);
+    }
+    transpose_8x8_32bit(v, v2);
+    for (int n = 0; n < 8; ++n) {
+        _mm256_storeu_si256((__m256i *)(tmp + n * 64 + 32), v2[n]);
+    }
+
+    // pack again with 128 to fully utilize vector length
+    for (int n = 0; n < 8; n += 2) {
+        __m512i r0 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64));
+        __m512i r1 = _mm512_loadu_si512((const __m512i *)(tmp + n * 64 + 64));
+        __m512i r1r0 = packNibbles(r0, r1);
+        _mm512_storeu_si512((__m512i *)((char *)packed_B + n * 32), r1r0);
+    }
+}
+
+template <>
+inline void pack_qs(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {
+    __m256i v[8], v2[8];
+    for (int n = 0; n < 8; ++n) {
+        v[n] = _mm256_loadu_si256((const __m256i *)(B[n * KB].qs));
+    }
+    transpose_8x8_32bit(v, v2);
+    for (int n = 0; n < 8; ++n) {
+        _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64), v2[n]);
+    }
+    for (int n = 0; n < 8; ++n) {
+        v[n] = _mm256_loadu_si256((const __m256i *)(B[(n + 8) * KB].qs));
+    }
+    transpose_8x8_32bit(v, v2);
+    for (int n = 0; n < 8; ++n) {
+        _mm256_storeu_si256((__m256i *)((char *)packed_B + n * 64 + 32), v2[n]);
+    }
+}
+
+template <>
+inline void pack_qs(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {
+    __m512i v[16];
+    // QK_K 256 with 8 groups, handle 2 groups at a time
+    char * pb = (char *)packed_B;
+    for (int k = 0; k < QK_K / 64; ++k) {
+        // pack 2 groups { n, g,  k} to {g, k/4, 4n}
+        //          e.g. {16, 2, 32} to {2,   8, 64}
+        for (int n = 0; n < TILE_N; ++n) {
+            v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32);
+        }
+
+        transpose_16x16_32bit(v);
+
+        // pack again with 128 to fully utilize vector length
+        for (int n = 0; n < TILE_N; n += 2) {
+            _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));
+            pb += 64;
+        }
+    }
+}
+
+template <>
+inline void pack_qs(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {
+    __m512i v[16];
+    const __m512i lowMask = _mm512_set1_epi8(0xF);
+    // QK_K 256 with 8 groups, handle 2 groups at a time
+    char * pb = (char *)packed_B;
+    char * ph = (char *)packed_B + (QK_K / 2) * TILE_N;
+    for (int k = 0; k < QK_K / 64; ++k) {
+        // pack 2 groups { n, g,  k} to {g, k/4, 4n}
+        //          e.g. {16, 2, 32} to {2,   8, 64}
+        for (int n = 0; n < TILE_N; ++n) {
+            v[n] = bytes_from_nibbles_64(B[n * KB].qs + k * 32, B[n * KB].qh, /* group */2 * k);
+        }
+
+        transpose_16x16_32bit(v);
+
+        // 1. pack lower 4bits with 2 groups
+        for (int n = 0; n < TILE_N; n += 2) {
+            // get lower 4 bits
+            const __m512i r0 = _mm512_and_si512(v[n], lowMask);
+            const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);
+            _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;
+        }
+
+        // 2. pack higher 1bit with 2 groups
+        const __m512i hmask = _mm512_set1_epi8(0x10);
+        for (int g = 0; g < 2; ++g) {
+            __m512i hbits = _mm512_setzero_si512();
+            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 0], hmask), 4));
+            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 1], hmask), 3));
+            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 2], hmask), 2));
+            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 8 + 3], hmask), 1));
+            hbits = _mm512_add_epi8(hbits,                   _mm512_and_si512(v[g * 8 + 4], hmask)    );
+            hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 5], hmask), 1));
+            hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 6], hmask), 2));
+            hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 8 + 7], hmask), 3));
+            _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;
+        }
+    }
+}
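+// In the Q5_K packer above, the 5th bit of transposed row (g * 8 + r) is moved
+// to bit position r of the destination byte: rows 0..3 shift their 0x10 bit
+// right, row 4 keeps it in place, rows 5..7 shift it left. The Q5_K unpack_B
+// specialization reverses this with its hmask0/hmask1 walk.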
+
+template <>
+inline void pack_qs(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {
+    __m512i v[32];
+    const __m512i lowMask = _mm512_set1_epi8(0xF);
+    // QK_K 256 with 8 groups, handle 4 groups at a time
+    char * pb = (char *)packed_B;
+    char * ph = (char *)packed_B + (QK_K / 2) * TILE_N;
+    for (int k = 0; k < QK_K / 128; ++k) {
+        for (int n = 0; n < TILE_N; ++n) {
+            bytes_from_nibbles_128(v[n], v[n + 16], B[n * KB].ql + k * 64, B[n * KB].qh + k * 32);
+        }
+
+        // top half: group 0,1 or 4,5; bottom half: group 2,3 or 6,7
+        transpose_16x16_32bit(v);
+        transpose_16x16_32bit(v + 16);
+
+        // 1. pack lower 4bits with 4 groups
+        for (int n = 0; n < 32; n += 2) {
+            const __m512i r0 = _mm512_and_si512(v[n], lowMask);
+            const __m512i r1 = _mm512_and_si512(v[n + 1], lowMask);
+            _mm512_storeu_si512((__m512i *)pb, packNibbles(r0, r1)); pb += 64;
+        }
+
+        // 2. pack higher 2bit with 4 groups
+        const __m512i hmask = _mm512_set1_epi8(0x30);
+        for (int g = 0; g < 8; ++g) {
+            __m512i hbits = _mm512_setzero_si512();
+            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 0], hmask), 4));
+            hbits = _mm512_add_epi8(hbits, _mm512_srli_epi16(_mm512_and_si512(v[g * 4 + 1], hmask), 2));
+            hbits = _mm512_add_epi8(hbits,                   _mm512_and_si512(v[g * 4 + 2], hmask)    );
+            hbits = _mm512_add_epi8(hbits, _mm512_slli_epi16(_mm512_and_si512(v[g * 4 + 3], hmask), 2));
+            _mm512_storeu_si512((__m512i *)ph, hbits); ph += 64;
+        }
+    }
+}
+
+template <>
+inline void pack_qs(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {
+    __m512i v[16];
+    char * pb = (char *)packed_B;
+    for (int k = 0; k < QK_K / 64; ++k) {
+        for (int n = 0; n < TILE_N; ++n) {
+            __m256i r0 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 +  0);
+            __m256i r1 = bytes_from_nibbles_32(B[n * KB].qs + k * 32 + 16);
+            v[n] = _mm512_inserti32x8(_mm512_castsi256_si512(r0), r1, 1);
+        }
+
+        transpose_16x16_32bit(v);
+
+        // pack again with 128 to fully utilize vector length
+        for (int n = 0; n < TILE_N; n += 2) {
+            _mm512_storeu_si512((__m512i *)pb, packNibbles(v[n], v[n + 1]));
+            pb += 64;
+        }
+    }
+}
+
+// pack B to vnni formats in 4bits or 8 bits
+void pack_B(void * RESTRICT packed_B, const block_q4_0 * RESTRICT B, int KB) {
+    pack_qs(packed_B, B, KB);
+    ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);
+    for (int n = 0; n < TILE_N; ++n) {
+        d0[n] = B[n * KB].d;
+    }
+}
+
+void pack_B(void * RESTRICT packed_B, const block_q4_1 * RESTRICT B, int KB) {
+    pack_qs(packed_B, B, KB);
+    ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K / 2);
+    ggml_half * m0 = d0 + TILE_N;
+    for (int n = 0; n < TILE_N; ++n) {
+        d0[n] = B[n * KB].d;
+        m0[n] = B[n * KB].m;
+    }
+}
+
+inline void s8s8_compensation(void * RESTRICT packed_B) {
+    // packed_B layout:
+    //   quants {TILE_N, TILE_K}  int8_t
+    //   d0     {TILE_N}          ggml_half
+    //   comp   {TILE_N}          int32_t
+    const int offset = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);
+    __m512i vcomp = _mm512_setzero_si512();
+    const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
+    for (int k = 0; k < 8; ++k) {
+        __m512i vb = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + k * 64));
+        vcomp = _mm512_dpbusd_epi32(vcomp, off, vb);
+    }
+    _mm512_storeu_si512((__m512i *)((char *)(packed_B) + offset), vcomp);
+}
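+// The compensation term stores comp[n] = 128 * sum_k B[n][k]. At run time the
+// avx512-vnni q8_0 kernel computes (A + 128) * B with the unsigned-signed
+// _mm512_dpbusd_epi32 and subtracts comp once per tile, recovering the signed
+// product A * B without needing a signed-signed dot instruction.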
+
+void pack_B(void * RESTRICT packed_B, const block_q8_0 * RESTRICT B, int KB) {
+    pack_qs(packed_B, B, KB);
+    ggml_half * d0 = reinterpret_cast<ggml_half *>((char *)packed_B + TILE_N * TILE_K);
+    for (int n = 0; n < TILE_N; ++n) {
+        d0[n] = B[n * KB].d;
+    }
+    s8s8_compensation(packed_B);
+}
+
+// convert 8 * {min, scale} from int6 to int8
+inline void unpack_mins_and_scales(const uint8_t * scales, uint32_t * utmp) {
+    const uint32_t kmask1 = 0x3f3f3f3f;
+    const uint32_t kmask2 = 0x0f0f0f0f;
+    const uint32_t kmask3 = 0x03030303;
+
+    memcpy(utmp, scales, 12);
+    utmp[3] = ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4);
+    const uint32_t uaux = utmp[1] & kmask1;
+    utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
+    utmp[2] = uaux;
+    utmp[0] &= kmask1;
+}
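+// The 12 input bytes hold 8 scales and 8 mins as packed 6-bit fields (the
+// K-quant super-block encoding); the mask dance above rearranges them so that,
+// viewed as 16 bytes, u8[0..7] are the scales and u8[8..15] are the mins, each
+// widened to 8 bits.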
+
+// packed_B layout:
+//   quants {8, TILE_N, 16}  uint8
+//   scales {8, TILE_N}      uint8
+//   mins   {8, TILE_N}      uint8
+//   d      {TILE_N}     ggml_half
+//   dmin   {TILE_N}     ggml_half
+void pack_B(void * RESTRICT packed_B, const block_q4_K * RESTRICT B, int KB) {
+    pack_qs(packed_B, B, KB);
+
+    uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);
+    uint8_t * mins = scales + 8 * TILE_N;
+    ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N);
+    ggml_half * dmin = d + TILE_N;
+
+    union {
+        uint32_t u32[4];
+        uint8_t  u8[16];
+    } s;
+
+    for (int n = 0; n < TILE_N; ++n) {
+        unpack_mins_and_scales(B[n * KB].scales, s.u32);
+        for (int k = 0; k < 8; ++k) {
+            scales[k * TILE_N + n] = s.u8[k];
+            mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];
+        }
+        d[n] = B[n * KB].d;
+        dmin[n] = B[n * KB].dmin;
+    }
+}
+
+// packed_B layout:
+//   quants {8, TILE_N, 16}  uint8
+//   qh     {8, TILE_N,  4}  uint8
+//   scales {8, TILE_N}      uint8
+//   mins   {8, TILE_N}      uint8
+//   d      {TILE_N}     ggml_half
+//   dmin   {TILE_N}     ggml_half
+void pack_B(void * RESTRICT packed_B, const block_q5_K * RESTRICT B, int KB) {
+    pack_qs(packed_B, B, KB);
+
+    uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);
+    uint8_t * mins = scales + 8 * TILE_N;
+    ggml_half * d = reinterpret_cast<ggml_half *>(mins + 8 * TILE_N);
+    ggml_half * dmin = d + TILE_N;
+
+    union {
+        uint32_t u32[4];
+        uint8_t  u8[16];
+    } s;
+
+    for (int n = 0; n < TILE_N; ++n) {
+        unpack_mins_and_scales(B[n * KB].scales, s.u32);
+        for (int k = 0; k < 8; ++k) {
+            scales[k * TILE_N + n] = s.u8[k];
+            mins[(k >> 1) * TILE_N * 2 + n * 2 + (k & 0x1)] = s.u8[k + 8];
+        }
+        d[n] = B[n * KB].d;
+        dmin[n] = B[n * KB].dmin;
+    }
+}
+
+// packed_B layout:
+//   quants {16, TILE_N, 8}  uint8
+//   qh     {16, TILE_N, 4}  uint8
+//   scales {16, TILE_N}      uint8
+//   d      {TILE_N}     ggml_half
+void pack_B(void * RESTRICT packed_B, const block_q6_K * RESTRICT B, int KB) {
+    pack_qs(packed_B, B, KB);
+
+    uint8_t * scales = reinterpret_cast<uint8_t *>((char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);
+    ggml_half * d = reinterpret_cast<ggml_half *>(scales + 16 * TILE_N);
+    for (int n = 0; n < TILE_N; ++n) {
+        const int8_t * ps = B[n * KB].scales;
+        for (int k = 0; k < 16; ++k) {
+            scales[k * TILE_N + n] = ps[k];
+        }
+        d[n] = B[n * KB].d;
+    }
+}
+
+// packed_B layout:
+//   quants {8, TILE_N, 16}  uint8
+//   scales {8, TILE_N}       int8
+//   d      {TILE_N}     ggml_half
+void pack_B(void * RESTRICT packed_B, const block_iq4_xs * RESTRICT B, int KB) {
+    pack_qs(packed_B, B, KB);
+
+    int8_t * scales = reinterpret_cast<int8_t *>((char *)packed_B + (QK_K / 2) * TILE_N);
+    ggml_half * d = reinterpret_cast<ggml_half *>(scales + 8 * TILE_N);
+
+    // pack the scales
+    for (int n = 0; n < TILE_N; ++n) {
+        uint16_t sh = B[n * KB].scales_h;
+        for (int k = 0; k < 8; k += 2) {
+            const int16_t ls1 = ((B[n * KB].scales_l[k / 2] & 0xf) | ((sh << 4) & 0x30)) - 32;
+            const int16_t ls2 = ((B[n * KB].scales_l[k / 2] >>  4) | ((sh << 2) & 0x30)) - 32;
+            scales[(k + 0) * TILE_N + n] = ls1;
+            scales[(k + 1) * TILE_N + n] = ls2;
+            sh >>= 4;
+        }
+        d[n] = B[n * KB].d;
+    }
+}
+
+template <typename TB, typename packed_B_t = packed_B_type<TB>>
+void unpack_B(packed_B_t * RESTRICT tile, const void * RESTRICT packed_B) {
+    GGML_UNUSED(tile);
+    GGML_UNUSED(packed_B);
+}
+
+template <>
+void unpack_B<block_q4_0>(int8_t * RESTRICT tile, const void * RESTRICT packed_B) {
+    const __m512i off = _mm512_set1_epi8(8);
+    const __m512i lowMask = _mm512_set1_epi8(0xF);
+    for (int n = 0; n < 8; n += 2) {
+        __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));
+        const __m512i r0 = _mm512_sub_epi8(_mm512_and_si512(bytes, lowMask), off);
+        const __m512i r1 = _mm512_sub_epi8(_mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask), off);
+        _mm512_storeu_si512((__m512i *)(tile + n * 64 +  0), r0);
+        _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
+    }
+}
+
+template <>
+void unpack_B<block_q4_1>(uint8_t * RESTRICT tile, const void * RESTRICT packed_B) {
+    const __m512i lowMask = _mm512_set1_epi8(0xF);
+    for (int n = 0; n < 8; n += 2) {
+        __m512i bytes = _mm512_loadu_si512((const __m512i *)((const char *)packed_B + n * 32));
+        const __m512i r0 = _mm512_and_si512(bytes, lowMask);
+        const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
+        _mm512_storeu_si512((__m512i *)(tile + n * 64 +  0), r0);
+        _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
+    }
+}
+
+// packed_B_t for QKK is int8_t
+template <typename TB>
+void unpack_B(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
+    const int packed_B_group_size = QK_K / 2 * TILE_N / 8;
+    const char * packed_B_group = (const char *)packed_B + k * packed_B_group_size;
+    const __m512i lowMask = _mm512_set1_epi8(0xF);
+    for (int n = 0; n < 8; n += 2) {
+        __m512i bytes = _mm512_loadu_si512(packed_B_group + n * 32);
+        const __m512i r0 = _mm512_and_si512(bytes, lowMask);
+        const __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
+        _mm512_storeu_si512((__m512i *)(tile + n * 64 +  0), r0);
+        _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
+    }
+}
+
+template <>
+void unpack_B<block_q5_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
+    // lower 4bits, stride 256 bytes
+    const int packed_l4_group_size = QK_K / 2 * TILE_N / 8;
+    const char * pb = (const char *)packed_B + k * packed_l4_group_size;
+
+    // higher 1bit, stride 64 bytes
+    const int packed_h1_group_size = QK_K / 8 * TILE_N / 8;
+    const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h1_group_size;
+    const __m512i hbits = _mm512_loadu_si512(ph);
+
+    const __m512i lowMask = _mm512_set1_epi8(0xF);
+    __m512i hmask0 = _mm512_set1_epi8(0x1);
+    __m512i hmask1 = _mm512_set1_epi8(0x2);
+
+    for (int n = 0; n < 8; n += 2) {
+        __m512i bytes = _mm512_loadu_si512(pb + n * 32);
+        __m512i r0 = _mm512_and_si512(bytes, lowMask);
+        __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
+        __m512i h0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), n), 4);
+        __m512i h1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), n + 1), 4);
+
+        hmask0 = _mm512_slli_epi16(hmask0, 2);
+        hmask1 = _mm512_slli_epi16(hmask1, 2);
+        r0 = _mm512_add_epi8(r0, h0);
+        r1 = _mm512_add_epi8(r1, h1);
+        _mm512_storeu_si512((__m512i *)(tile + n * 64 +  0), r0);
+        _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
+    }
+}
+
+template <>
+void unpack_B<block_q6_K>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
+    // lower 4bits, stride 128 bytes
+    const int packed_l4_group_size = QK_K / 2 * TILE_N / 16;
+    const char * pb = (const char *)packed_B + k * packed_l4_group_size;
+
+    // higher 2bits, stride 64 bytes
+    const int packed_h2_group_size = QK_K / 4 * TILE_N / 16;
+    const char * ph = (const char *)packed_B + (QK_K / 2) * TILE_N + k * packed_h2_group_size;
+    const __m512i hbits = _mm512_loadu_si512(ph);
+
+    const __m512i off = _mm512_set1_epi8(32);
+    const __m512i lowMask = _mm512_set1_epi8(0xF);
+    __m512i hmask0 = _mm512_set1_epi8(0x3); // 0011
+    __m512i hmask1 = _mm512_set1_epi8(0xC); // 1100
+
+    // note: zero padding for rows 4..7 is skipped here, as it is already handled in `unpack_A`
+    __m512i bytes = _mm512_loadu_si512(pb);
+    __m512i r0 = _mm512_and_si512(bytes, lowMask);
+    __m512i r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
+    __m512i h0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask0), 4);
+    __m512i h1 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask1), 2);
+    _mm512_storeu_si512((__m512i *)(tile +  0), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));
+    _mm512_storeu_si512((__m512i *)(tile + 64), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));
+
+    hmask0 = _mm512_slli_epi16(hmask0, 4);
+    hmask1 = _mm512_slli_epi16(hmask1, 4);
+
+    bytes = _mm512_loadu_si512(pb + 64);
+    r0 = _mm512_and_si512(bytes, lowMask);
+    r1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
+    h0 =                   _mm512_and_si512(hbits, hmask0);
+    h1 = _mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), 2);
+    _mm512_storeu_si512((__m512i *)(tile + 128), _mm512_sub_epi8(_mm512_add_epi8(r0, h0), off));
+    _mm512_storeu_si512((__m512i *)(tile + 192), _mm512_sub_epi8(_mm512_add_epi8(r1, h1), off));
+}
+
+template <>
+void unpack_B<block_iq4_xs>(int8_t * RESTRICT tile, const void * RESTRICT packed_B, int k) {
+    static const __m512i values128 = _mm512_set_epi8(
+        113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
+        113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
+        113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
+        113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127
+    );
+
+    const int packed_B_group_size = QK_K / 2 * TILE_N / 8;
+    const char * pb = (const char *)packed_B + k * packed_B_group_size;
+    const __m512i lowMask = _mm512_set1_epi8(0xF);
+
+    for (int n = 0; n < 8; n += 2) {
+        __m512i bytes = _mm512_loadu_si512(pb + n * 32);
+        const __m512i r0 = _mm512_shuffle_epi8(values128, _mm512_and_si512(bytes, lowMask));
+        const __m512i r1 = _mm512_shuffle_epi8(values128, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));
+        _mm512_storeu_si512((__m512i *)(tile + n * 64 +  0), r0);
+        _mm512_storeu_si512((__m512i *)(tile + n * 64 + 64), r1);
+    }
+}
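+// values128 above replicates a 16-entry non-linear codebook across all four
+// 128-bit lanes so that _mm512_shuffle_epi8 acts as a per-byte LUT (the
+// shuffle only indexes within its own lane); the table is assumed to match the
+// kvalues_iq4nl codebook used by the scalar IQ4_XS code.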
+
+template <typename TA, typename TB, bool is_acc>
+struct acc_C {};
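+// acc_C converts an int32 AMX/VNNI result tile into float rows of C, applying
+// the per-block scales (and mins for the K-quants). The is_acc flag selects
+// between overwriting C and accumulating into it.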
+
+template <bool is_acc>
+struct acc_C<block_q8_0, block_q4_0, is_acc> {
+    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {
+        const int offset = TILE_N * TILE_K / 2;
+        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
+
+        for (int m = 0; m < nr; ++m) {
+            const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
+            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
+
+            __m512 vsum;
+            if (is_acc) {
+                vsum = _mm512_loadu_ps(C + m * ldc);
+            } else {
+                vsum = _mm512_set1_ps(0.f);
+            }
+            vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
+            _mm512_storeu_ps(C + m * ldc, vsum);
+        }
+    }
+};
+
+template <bool is_acc>
+struct acc_C<block_q8_1, block_q4_1, is_acc> {
+    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_1 * A, int lda, const void * packed_B, int nr) {
+        const int offset = TILE_N * TILE_K / 2;
+        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
+        const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset + TILE_N * sizeof(ggml_half))));
+
+        for (int m = 0; m < nr; ++m) {
+            const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
+            const __m512 vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].s));
+            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
+
+            __m512 vsum;
+            if (is_acc) {
+                vsum = _mm512_loadu_ps(C + m * ldc);
+            } else {
+                vsum = _mm512_set1_ps(0.f);
+            }
+            vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
+            vsum = _mm512_fmadd_ps(vm0, vs1, vsum);
+            _mm512_storeu_ps(C + m * ldc, vsum);
+        }
+    }
+};
+
+template <bool is_acc>
+struct acc_C<block_q8_0, block_q8_0, is_acc> {
+    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_0 * A, int lda, const void * packed_B, int nr) {
+        const int offset = TILE_N * TILE_K;
+        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)((const char *)packed_B + offset)));
+
+        for (int m = 0; m < nr; ++m) {
+            const __m512 vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[m * lda].d));
+            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
+
+            __m512 vsum;
+            if (is_acc) {
+                vsum = _mm512_loadu_ps(C + m * ldc);
+            } else {
+                vsum = _mm512_set1_ps(0.f);
+            }
+            vsum = _mm512_fmadd_ps(vtile, _mm512_mul_ps(vd0, vd1), vsum);
+            _mm512_storeu_ps(C + m * ldc, vsum);
+        }
+    }
+};
+
+template <bool is_acc>
+struct acc_C<block_q8_K, block_q4_K, is_acc> {
+    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
+        const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);
+        const uint8_t * mins = scales + 8 * TILE_N;
+        const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N);
+        const ggml_half * dmin = d0 + TILE_N;
+
+        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
+        const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));
+
+        for (int m = 0; m < nr; ++m) {
+            const float d1 = A[m * lda].d;
+            const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
+            const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);
+            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
+
+            __m512 vsum;
+            if (is_acc) {
+                vsum = _mm512_loadu_ps(C + m * ldc);
+            } else {
+                vsum = _mm512_set1_ps(0.f);
+            }
+
+            const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);
+            const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+
+            __m512i acc_m = _mm512_setzero_si512();
+            for (int k = 0; k < 4; ++k) {
+                __m512i vmask = _mm512_set1_epi32(k);
+                __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));
+                __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));
+                acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
+            }
+
+            vsum = _mm512_fmadd_ps(vtile, vd, vsum);
+            vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);
+            _mm512_storeu_ps(C + m * ldc, vsum);
+        }
+    }
+};
+
+template <bool is_acc>
+struct acc_C<block_q8_K, block_q5_K, is_acc> {
+    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
+        const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N);
+        const uint8_t * mins = scales + 8 * TILE_N;
+        const ggml_half * d0 = reinterpret_cast<const ggml_half *>(mins + 8 * TILE_N);
+        const ggml_half * dmin = d0 + TILE_N;
+
+        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
+        const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)dmin));
+
+        for (int m = 0; m < nr; ++m) {
+            const float d1 = A[m * lda].d;
+            const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
+            const __m512 vdm = _mm512_mul_ps(_mm512_set1_ps(-d1), vdmin);
+            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
+
+            __m512 vsum;
+            if (is_acc) {
+                vsum = _mm512_loadu_ps(C + m * ldc);
+            } else {
+                vsum = _mm512_set1_ps(0.f);
+            }
+
+            const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[m * lda].bsums);
+            const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+
+            __m512i acc_m = _mm512_setzero_si512();
+            for (int k = 0; k < 4; ++k) {
+                __m512i vmask = _mm512_set1_epi32(k);
+                __m512i va = _mm512_permutexvar_epi32(vmask, _mm512_castsi128_si512(q8s));
+                __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(mins + k * 32)));
+                acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
+            }
+
+            vsum = _mm512_fmadd_ps(vtile, vd, vsum);
+            vsum = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc_m), vdm, vsum);
+            _mm512_storeu_ps(C + m * ldc, vsum);
+        }
+    }
+};
+
+template <bool is_acc>
+struct acc_C<block_q8_K, block_q6_K, is_acc> {
+    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
+        const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N);
+        const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 16 * TILE_N);
+
+        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
+
+        for (int m = 0; m < nr; ++m) {
+            const float d1 = A[m * lda].d;
+            const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
+            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
+
+            __m512 vsum;
+            if (is_acc) {
+                vsum = _mm512_loadu_ps(C + m * ldc);
+            } else {
+                vsum = _mm512_set1_ps(0.f);
+            }
+
+            vsum = _mm512_fmadd_ps(vtile, vd, vsum);
+            _mm512_storeu_ps(C + m * ldc, vsum);
+        }
+    }
+};
+
+template <bool is_acc>
+struct acc_C<block_q8_K, block_iq4_xs, is_acc> {
+    static void apply(float * RESTRICT C, int ldc, const int32_t * RESTRICT tile, const block_q8_K * A, int lda, const void * packed_B, int nr) {
+        const int8_t * scales = reinterpret_cast<const int8_t *>((const char *)packed_B + (QK_K / 2) * TILE_N);
+        const ggml_half * d0 = reinterpret_cast<const ggml_half *>(scales + 8 * TILE_N);
+
+        const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)d0));
+
+        for (int m = 0; m < nr; ++m) {
+            const float d1 = A[m * lda].d;
+            const __m512 vd = _mm512_mul_ps(_mm512_set1_ps(d1), vd0);
+            const __m512 vtile = _mm512_cvtepi32_ps(_mm512_loadu_si512(tile + m * TILE_N));
+
+            __m512 vsum;
+            if (is_acc) {
+                vsum = _mm512_loadu_ps(C + m * ldc);
+            } else {
+                vsum = _mm512_set1_ps(0.f);
+            }
+
+            vsum = _mm512_fmadd_ps(vtile, vd, vsum);
+            _mm512_storeu_ps(C + m * ldc, vsum);
+        }
+    }
+};
+
+template <typename TB> constexpr int get_quants_size();
+template <> constexpr int get_quants_size<block_q4_K>()   { return (QK_K / 2) * TILE_N; }
+template <> constexpr int get_quants_size<block_q5_K>()   { return (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N; }
+template <> constexpr int get_quants_size<block_q6_K>()   { return (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N; }
+template <> constexpr int get_quants_size<block_iq4_xs>() { return (QK_K / 2) * TILE_N; }
+
+// used for QKK format
+template <typename TB, bool is_acc,
+          typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>
+inline void scale_C(const int32_t * RESTRICT tile, int32_t * RESTRICT sumi, const void * packed_B, int k, int nr) {
+    const uint8_t * scales = reinterpret_cast<const uint8_t *>((const char *)packed_B + get_quants_size<TB>());
+    const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(scales + k * TILE_N)));
+
+    for (int m = 0; m < nr; ++m) {
+        __m512i vsumi;
+        if (is_acc) {
+            vsumi = _mm512_loadu_si512(sumi + m * TILE_N);
+        } else {
+            vsumi = _mm512_setzero_si512();
+        }
+        __m512i vtile = _mm512_loadu_si512(tile + m * TILE_N);
+        vsumi = _mm512_add_epi32(vsumi, _mm512_mullo_epi32(vtile, vscale));
+        _mm512_storeu_si512((__m512i *)(sumi + m * TILE_N), vsumi);
+    }
+}
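+// scale_C folds the per-group (sub-block) scale of a K-quant into the running
+// int32 accumulator while everything is still integer; the final conversion to
+// float then only needs the per-super-block d (and dmin). A hypothetical call
+// for group k of a q4_K tile would look like:
+//   scale_C<block_q4_K, true>(tile, sumi, packed_B, k, TILE_M);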
+
+template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>
+struct tinygemm_kernel_avx {
+    static void apply(int K, const TA * RESTRICT A, const TB * RESTRICT B, TC * RESTRICT C, int ldc) {
+        GGML_UNUSED(K);
+        GGML_UNUSED(A);
+        GGML_UNUSED(B);
+        GGML_UNUSED(C);
+        GGML_UNUSED(ldc);
+    }
+};
+
+template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
+struct tinygemm_kernel_avx<float, ggml_fp16_t, float, BLOCK_M, BLOCK_N, BLOCK_K> {
+    static void apply(int K, const float * RESTRICT A, const ggml_fp16_t * RESTRICT B, float * RESTRICT C, int ldc) {
+        constexpr int ROWS = BLOCK_M;
+        constexpr int COLS = BLOCK_N;
+        assert(BLOCK_K == 16);
+
+        __m512 va;
+        __m512 vb[COLS];
+        __m512 vc[ROWS * COLS];
+
+        auto loadc = [&](int idx) {
+            vc[idx] = _mm512_setzero_ps();
+        };
+        Unroll<ROWS * COLS>{}(loadc);
+
+        auto compute = [&](int idx, int k) {
+            // TODO: use `constexpr` here to get rid of integer div
+            // when upgraded to C++17
+            const int row = idx / COLS;
+            const int col = idx % COLS;
+
+            if (col == 0) {
+                va = _mm512_loadu_ps(A + row * K + k);
+            }
+            if (row == 0) {
+                vb[col] =  _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(B + col * K + k)));
+            }
+            vc[idx] = _mm512_fmadd_ps(va, vb[col], vc[idx]);
+        };
+
+        for (int k = 0; k < K; k += 16) {
+            Unroll<ROWS * COLS>{}(compute, k);
+        }
+
+        auto storec = [&](int idx) {
+            const int row = idx / COLS;
+            const int col = idx % COLS;
+            C[row * ldc + col] = _mm512_reduce_add_ps(vc[idx]);
+        };
+        Unroll<ROWS * COLS>{}(storec);
+    }
+};
+
+#define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE)                                \
+    tinygemm_kernel_avx<float, type, float, MB_SIZE, NB_SIZE, blck_size>::apply(    \
+        K, (const float *)src1->data + mb_start * K,                                \
+        (const type *)src0->data + nb_start * K,                                    \
+        (float *)dst->data + mb_start * ldc + nb_start, ldc);
+
+
+// re-organize in the format {NB, KB, TILE_SIZE}:
+#define PACKED_INDEX(n, k, KB, tile_size) (((n) * (KB) + (k)) * (tile_size))
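+// e.g. the packed block for column tile n = 1 and depth block k = 2 with
+// KB = 8 starts at byte offset (1 * 8 + 2) * tile_size into packed_B.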
+
+template <typename TB, int BLOCK_K>
+void convert_B_packed_format(void * RESTRICT packed_B, const TB * RESTRICT B, int N, int K, int n_threads) {
+    const int NB = N / TILE_N;
+    const int KB = K / BLOCK_K;
+    const int TILE_SIZE = get_tile_size<TB>();
+
+    // parallel on NB should be enough
+    parallel_for(n_threads, NB, [&](int begin, int end) {
+        for (int n = begin; n < end; ++n) {
+            for (int k = 0; k < KB; ++k) {
+                int n0 = n * TILE_N;
+                pack_B((char *)packed_B + PACKED_INDEX(n, k, KB, TILE_SIZE), &B[n0 * KB + k], KB);
+            }
+        }
+    });
+}
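+// A minimal usage sketch (shapes assumed, and assuming the BLOCK_K template
+// parameter reconstructed above): q4_0 with N = 32, K = 1024 gives NB = 2
+// column tiles and KB = 32 depth blocks, each packed tile stored contiguously
+// as {quants, d0}:
+//   convert_B_packed_format<block_q4_0, QK4_0>(buf, B, /*N*/32, /*K*/1024, n_threads);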
+
+template <typename TA, typename TB, typename TC, int BLOCK_M, int BLOCK_N, int BLOCK_K>
+struct tinygemm_kernel_vnni {};
+
+template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
+struct tinygemm_kernel_vnni<block_q8_0, block_q4_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {
+    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
+
+        constexpr int COLS = BLOCK_N / 16;
+        const int TILE_SIZE = TILE_N * sizeof(block_q4_0);
+
+        const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);
+        const char * RESTRICT B = static_cast<const char *>(_B);
+
+        __m512i va[8];
+        __m512 vc[COLS];
+        __m512 vd1;
+
+        // sum of offsets, shared across COLS
+        //
+        // avx512-vnni does not have `_mm512_dpbssd_epi32`,
+        // need to transform ss to us:
+        //   a * (b - 8) is equivalent to b * a - 8 * a
+        //   s    u   u                   u   s   u   s
+        //
+        __m512i vcomp;
+
+        const __m512i off = _mm512_set1_epi8(8);
+        const __m512i lowMask = _mm512_set1_epi8(0xF);
+
+        auto loadc = [&](int col) {
+            vc[col] = _mm512_setzero_ps();
+        };
+        Unroll<COLS>{}(loadc);
+
+        auto compute = [&](int col, int i) {
+            // load a and compute compensation
+            if (col == 0) {
+                const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
+                vcomp = _mm512_setzero_si512();
+                for (int k = 0; k < 8; ++k) {
+                    va[k] = _mm512_set1_epi32(a_ptr[k]);
+                    vcomp = _mm512_dpbusd_epi32(vcomp, off, va[k]);
+                }
+                vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
+            }
+
+            // load b
+            __m512i vsum = _mm512_setzero_si512();
+            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
+            for (int k = 0; k < 8; k += 2) {
+                __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));
+                __m512i vb0 = _mm512_and_si512(bytes, lowMask);
+                vsum = _mm512_dpbusd_epi32(vsum, vb0, va[k + 0]);
+                __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
+                vsum = _mm512_dpbusd_epi32(vsum, vb1, va[k + 1]);
+            }
+            const int offset = TILE_N * TILE_K / 2;
+            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
+            vsum = _mm512_sub_epi32(vsum, vcomp);
+
+            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
+        };
+
+        for (int i = 0; i < KB; ++i) {
+            Unroll<COLS>{}(compute, i);
+        }
+
+        // store to C
+        auto storec = [&](int col) {
+            _mm512_storeu_ps((__m512i *)(C + 0 * ldc + col * 16), vc[col]);
+        };
+        Unroll<COLS>{}(storec);
+    }
+};
+
+template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
+struct tinygemm_kernel_vnni<block_q8_1, block_q4_1, float, BLOCK_M, BLOCK_N, BLOCK_K> {
+    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
+
+        constexpr int COLS = BLOCK_N / 16;
+        const int TILE_SIZE = TILE_N * sizeof(block_q4_1);
+
+        const block_q8_1 * RESTRICT A = static_cast<const block_q8_1 *>(_A);
+        const char * RESTRICT B = static_cast<const char *>(_B);
+
+        __m512i va[8];
+        __m512i vb[8];
+        __m512 vc[COLS];
+        __m512 vd1, vs1;
+
+        const __m512i lowMask = _mm512_set1_epi8(0xF);
+
+        auto loadc = [&](int col) {
+            vc[col] = _mm512_setzero_ps();
+        };
+        Unroll<COLS>{}(loadc);
+
+        auto compute = [&](int col, int i) {
+            // load a
+            if (col == 0) {
+                const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
+                for (int k = 0; k < 8; ++k) {
+                    va[k] = _mm512_set1_epi32(a_ptr[k]);
+                }
+                vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
+                vs1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].s));
+            }
+
+            // load b
+            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
+            for (int k = 0; k < 8; k += 2) {
+                __m512i bytes = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 32));
+                vb[k + 0] = _mm512_and_si512(bytes, lowMask);
+                vb[k + 1] = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
+            }
+            const int offset = TILE_N * TILE_K / 2;
+            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
+            const __m512 vm0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset + TILE_N * sizeof(ggml_half))));
+
+            __m512i vsum = _mm512_setzero_si512();
+            for (int k = 0; k < 8; ++k) {
+                vsum = _mm512_dpbusd_epi32(vsum, vb[k], va[k]);
+            }
+
+            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
+            vc[col] = _mm512_fmadd_ps(vm0, vs1, vc[col]);
+        };
+
+        for (int i = 0; i < KB; ++i) {
+            Unroll<COLS>{}(compute, i);
+        }
+
+        // store to C
+        auto storec = [&](int col) {
+            _mm512_storeu_ps((__m512i *)(C + 0 * ldc + col * 16), vc[col]);
+        };
+        Unroll<COLS>{}(storec);
+    }
+};
+
+template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
+struct tinygemm_kernel_vnni<block_q8_0, block_q8_0, float, BLOCK_M, BLOCK_N, BLOCK_K> {
+    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
+
+        constexpr int COLS = BLOCK_N / 16;
+        const int TILE_SIZE = TILE_N * sizeof(block_q8_0) + TILE_N * sizeof(int32_t);
+
+        const block_q8_0 * RESTRICT A = static_cast<const block_q8_0 *>(_A);
+        const char * RESTRICT B = static_cast<const char *>(_B);
+
+        __m512i va[8];
+        __m512i vb[8];
+        __m512 vc[COLS];
+        __m512 vd1;
+
+        // Notes: s8s8 igemm compensation in avx512-vnni
+        // change s8s8 to u8s8 with compensate
+        //   a * b = (a + 128) * b - 128 * b
+        //   s   s       u       s    u    s
+        //
+        // (128 * b is pre-computed when packing B to vnni formats)
+        //
+        const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
+
+        auto loadc = [&](int col) {
+            vc[col] = _mm512_setzero_ps();
+        };
+        Unroll<COLS>{}(loadc);
+
+        auto compute = [&](int col, int i) {
+            // load a and add offset 128
+            if (col == 0) {
+                const int32_t * a_ptr = reinterpret_cast<const int32_t *>(A[0 * KB + i].qs);
+                for (int k = 0; k < 8; ++k) {
+                    va[k] = _mm512_set1_epi32(a_ptr[k]);
+                    va[k] = _mm512_add_epi8(va[k], off);
+                }
+                vd1 = _mm512_set1_ps(GGML_FP16_TO_FP32(A[0 * KB + i].d));
+            }
+
+            // load b
+            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
+            for (int k = 0; k < 8; ++k) {
+                vb[k] = _mm512_loadu_si512((const __m512i *)(b_ptr + k * 64));
+            }
+            const int offset = TILE_N * TILE_K;
+            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset)));
+            const int offset2 = TILE_N * TILE_K + TILE_N * sizeof(ggml_half);
+            const __m512i vcomp = _mm512_loadu_si512((const __m512i *)(b_ptr + offset2));
+
+            __m512i vsum = _mm512_setzero_si512();
+            for (int k = 0; k < 8; ++k) {
+                vsum = _mm512_dpbusd_epi32(vsum, va[k], vb[k]);
+            }
+            vsum = _mm512_sub_epi32(vsum, vcomp);
+
+            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(vsum), _mm512_mul_ps(vd0, vd1), vc[col]);
+        };
+
+        for (int i = 0; i < KB; ++i) {
+            Unroll<COLS>{}(compute, i);
+        }
+
+        // store to C
+        auto storec = [&](int col) {
+            _mm512_storeu_ps((__m512i *)(C + 0 * ldc + col * 16), vc[col]);
+        };
+        Unroll<COLS>{}(storec);
+    }
+};
+
+template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
+struct tinygemm_kernel_vnni<block_q8_K, block_q4_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
+    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
+
+        constexpr int COLS = BLOCK_N / 16;
+        const int TILE_SIZE = TILE_N * sizeof(block_q4_K) + TILE_N * 4;
+
+        const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
+        const char * RESTRICT B = static_cast<const char *>(_B);
+
+        // a.qs:   8 groups, 32 bytes each group (m256i)
+        __m512i va[8];
+        // a.bsum: 8 groups,  2 bytes each group (m128i)
+        __m512i va_bsum;
+        __m512 vc[COLS];
+        __m512 vd1;
+
+        // packed_B:
+        const int offset_scales = (QK_K / 2) * TILE_N;
+        const int offset_mins   = (QK_K / 2) * TILE_N +  8 * TILE_N;
+        const int offset_d0     = (QK_K / 2) * TILE_N + 16 * TILE_N;
+        const int offset_dmin   = (QK_K / 2) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);
+
+        const __m512i lowMask = _mm512_set1_epi8(0xF);
+
+        auto loadc = [&](int col) {
+            vc[col] = _mm512_setzero_ps();
+        };
+        Unroll<COLS>{}(loadc);
+
+        // Notes: vnni formats in QK_K
+        //   a) quants vnni format
+        //     int8  {k/4, n, 4}, viewed as 2d {k/4, 4n}, k = 32
+        //     from {16, 32} to {8, 64}
+        //
+        //   b) min vnni format
+        //     int16 {k/2, n, 2}, viewed as 2d {k/2, 2n}, k = 8
+        //     from {16,  8} to {4, 32}
+        //
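+        // Within one k_group, _mm512_permutexvar_epi32(_mm512_set1_epi32(k), va[g])
+        // broadcasts the k-th dword of A (4 consecutive int8 quants) to all 16
+        // dword lanes, so a single dpbusd pairs it with 64 packed B values and
+        // produces partial sums for 16 output columns at once.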
+        auto compute = [&](int col, int i) {
+            // load a
+            if (col == 0) {
+                for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
+                    va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
+                }
+                const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
+                const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+                va_bsum = _mm512_castsi128_si512(q8s);
+                vd1 = _mm512_set1_ps(A[0 * KB + i].d);
+            }
+
+            // step 1: accumulate the quants
+            __m512i acc = _mm512_setzero_si512();
+            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
+            const char * b_qs  = b_ptr;
+            for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
+                __m512i vsum = _mm512_setzero_si512();
+                for (int k = 0; k < 8; k += 2) {
+                    __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);
+                    __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);
+
+                    __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);
+                    __m512i vb0 = _mm512_and_si512(bytes, lowMask);
+                    vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
+                    __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
+                    vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
+
+                    b_qs += 64;
+                }
+                // vacc += scale * (q8 @ q4)
+                const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
+                acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
+            }
+            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
+            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
+
+            // step 2: accumulate the mins
+            __m512i acc_m = _mm512_setzero_si512();
+            for (int k = 0; k < 4; ++k) {
+                __m512i vmask = _mm512_set1_epi32(k);
+                __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);
+                __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));
+                acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
+            }
+            const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));
+            vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);
+        };
+
+        for (int i = 0; i < KB; ++i) {
+            Unroll<COLS>{}(compute, i);
+        }
+
+        // store to C
+        auto storec = [&](int col) {
+            _mm512_storeu_ps((__m512i *)(C + 0 * ldc + col * 16), vc[col]);
+        };
+        Unroll<COLS>{}(storec);
+    }
+};
+
+template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
+struct tinygemm_kernel_vnni<block_q8_K, block_q5_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
+    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
+
+        constexpr int COLS = BLOCK_N / 16;
+        const int TILE_SIZE = TILE_N * sizeof(block_q5_K) + TILE_N * 4;
+
+        const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
+        const char * RESTRICT B = static_cast<const char *>(_B);
+
+        // a.qs:   8 groups, 32 bytes each group (m256i)
+        __m512i va[8];
+        // a.bsum: 8 groups,  2 bytes each group (m128i)
+        __m512i va_bsum;
+        __m512 vc[COLS];
+        __m512 vd1;
+
+        // packed_B:
+        const int offset_qh     = (QK_K / 2) * TILE_N;
+        const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N;
+        const int offset_mins   = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N +  8 * TILE_N;
+        const int offset_d0     = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N;
+        const int offset_dmin   = (QK_K / 2) * TILE_N + (QK_K / 8) * TILE_N + 16 * TILE_N + TILE_N * sizeof(ggml_half);
+
+        const __m512i lowMask = _mm512_set1_epi8(0xF);
+
+        auto loadc = [&](int col) {
+            vc[col] = _mm512_setzero_ps();
+        };
+        Unroll<COLS>{}(loadc);
+
+        // Q5_K and Q4_K share the same VNNI format; see the notes above.
+        auto compute = [&](int col, int i) {
+            // load a
+            if (col == 0) {
+                for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
+                    va[k_group] = _mm512_castsi256_si512(_mm256_loadu_si256((const __m256i *)(A[0 * KB + i].qs + k_group * 32)));
+                }
+                const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
+                const __m128i q8s = _mm_hadd_epi16(_mm256_extracti128_si256(q8sums, 0), _mm256_extracti128_si256(q8sums, 1));
+                va_bsum = _mm512_castsi128_si512(q8s);
+                vd1 = _mm512_set1_ps(A[0 * KB + i].d);
+            }
+
+            // step 1: accumulate the quants
+            __m512i acc = _mm512_setzero_si512();
+            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
+            const char * b_qs  = b_ptr;
+            const char * b_qh  = b_ptr + offset_qh;
+            for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
+                __m512i vsum = _mm512_setzero_si512();
+                __m512i hmask0 = _mm512_set1_epi8(0x1);
+                __m512i hmask1 = _mm512_set1_epi8(0x2);
+                __m512i hbits = _mm512_loadu_si512((const __m512i *)(b_qh + k_group * 64));
+                for (int k = 0; k < 8; k += 2) {
+                    __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 0), va[k_group]);
+                    __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(k + 1), va[k_group]);
+
+                    __m512i bytes = _mm512_loadu_si512((const __m512i *)b_qs);
+                    __m512i vb0 = _mm512_and_si512(bytes, lowMask);
+                    __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
+
+                    __m512i vh0 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask0), k), 4);
+                    __m512i vh1 = _mm512_slli_epi16(_mm512_srli_epi16(_mm512_and_si512(hbits, hmask1), k + 1), 4);
+
+                    hmask0 = _mm512_slli_epi16(hmask0, 2);
+                    hmask1 = _mm512_slli_epi16(hmask1, 2);
+                    vb0 = _mm512_add_epi8(vb0, vh0);
+                    vb1 = _mm512_add_epi8(vb1, vh1);
+
+                    vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
+                    vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
+
+                    b_qs += 64;
+                }
+                // vacc += scale * (q8 @ q5)
+                const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
+                acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
+            }
+            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
+            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
+
+            // step 2: accumulate the mins
+            __m512i acc_m = _mm512_setzero_si512();
+            for (int k = 0; k < 4; ++k) {
+                __m512i vmask = _mm512_set1_epi32(k);
+                __m512i va = _mm512_permutexvar_epi32(vmask, va_bsum);
+                __m512i vb = _mm512_cvtepi8_epi16(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_mins + k * 32)));
+                acc_m = _mm512_dpwssds_epi32(acc_m, va, vb);
+            }
+            const __m512 vdmin = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_dmin)));
+            vc[col] = _mm512_fnmadd_ps(_mm512_cvtepi32_ps(acc_m), _mm512_mul_ps(vdmin, vd1), vc[col]);
+        };
+
+        for (int i = 0; i < KB; ++i) {
+            Unroll<COLS>{}(compute, i);
+        }
+
+        // store to C
+        auto storec = [&](int col) {
+            _mm512_storeu_ps((__m512i *)(C + 0 * ldc + col * 16), vc[col]);
+        };
+        Unroll<COLS>{}(storec);
+    }
+};
+
+template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
+struct tinygemm_kernel_vnni<block_q8_K, block_q6_K, float, BLOCK_M, BLOCK_N, BLOCK_K> {
+    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
+
+        constexpr int COLS = BLOCK_N / 16;
+        const int TILE_SIZE = TILE_N * sizeof(block_q6_K);
+
+        const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
+        const char * RESTRICT B = static_cast<const char *>(_B);
+
+        // load the 256 bytes from A to 4 avx512 vectors
+        __m512i va[4];
+        __m512 vc[COLS];
+        __m512 vd1;
+
+        // packed_B:
+        const int offset_qh     = (QK_K / 2) * TILE_N;
+        const int offset_scales = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N;
+        const int offset_d0     = (QK_K / 2) * TILE_N + (QK_K / 4) * TILE_N + 16 * TILE_N;
+
+        // compensation
+        __m512i vcomp;
+
+        const __m512i m32s = _mm512_set1_epi32(32);
+        const __m512i lowMask = _mm512_set1_epi8(0xF);
+
+        auto loadc = [&](int col) {
+            vc[col] = _mm512_setzero_ps();
+        };
+        Unroll<COLS>{}(loadc);
+
+        auto compute = [&](int col, int i) {
+            if (col == 0) {
+                // load a
+                va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs +   0));
+                va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs +  64));
+                va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));
+                va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));
+
+                const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
+                vcomp = _mm512_mullo_epi32(_mm512_cvtepi16_epi32(q8sums), m32s);
+                vd1 = _mm512_set1_ps(A[0 * KB + i].d);
+            }
+
+            // accumulate the quants
+            __m512i acc = _mm512_setzero_si512();
+            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
+            const char * b_qs = b_ptr;
+            const char * b_qh = b_ptr + offset_qh;
+            int mask = 0;
+            for (int k_group = 0; k_group < QK_K / 16; ++k_group) {
+                int r = k_group >> 2;
+                __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
+                __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
+
+                __m512i vsum = _mm512_setzero_si512();
+                __m512i hmask = _mm512_set1_epi8(0x3);
+
+                __m512i bytes = _mm512_loadu_si512(b_qs);
+                __m512i hbits = _mm512_loadu_si512(b_qh);
+                __m512i vb0 = _mm512_and_si512(bytes, lowMask);
+                __m512i vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
+                __m512i vh0 = _mm512_slli_epi16(_mm512_and_si512(hbits, hmask), 4);
+                __m512i vh1 = _mm512_slli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 2)), 2);
+
+                vb0 = _mm512_add_epi8(vb0, vh0);
+                vb1 = _mm512_add_epi8(vb1, vh1);
+                vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
+                vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
+                b_qs += 64;
+
+                va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
+                va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
+
+                bytes = _mm512_loadu_si512(b_qs);
+                vb0 = _mm512_and_si512(bytes, lowMask);
+                vb1 = _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask);
+                vh0 =                   _mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 4));
+                vh1 = _mm512_srli_epi16(_mm512_and_si512(hbits, _mm512_slli_epi16(hmask, 6)), 2);
+                vb0 = _mm512_add_epi8(vb0, vh0);
+                vb1 = _mm512_add_epi8(vb1, vh1);
+                vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
+                vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
+                b_qs += 64;
+                b_qh += 64;
+
+                // B * A - 32 * A
+                __m512i vmask = _mm512_set1_epi32(k_group);
+                vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));
+
+                // vacc += scale * (q8 @ q6)
+                const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
+                acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
+            }
+            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
+            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
+        };
+
+        for (int i = 0; i < KB; ++i) {
+            Unroll<COLS>{}(compute, i);
+        }
+
+        // store to C
+        auto storec = [&](int col) {
+            _mm512_storeu_ps((__m512i *)(C + 0 * ldc + col * 16), vc[col]);
+        };
+        Unroll<COLS>{}(storec);
+    }
+};
+
+template <int BLOCK_M, int BLOCK_N, int BLOCK_K>
+struct tinygemm_kernel_vnni<block_q8_K, block_iq4_xs, float, BLOCK_M, BLOCK_N, BLOCK_K> {
+    static void apply(int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
+
+        constexpr int COLS = BLOCK_N / 16;
+        const int TILE_SIZE = TILE_N * sizeof(block_iq4_xs) + TILE_N * 2;
+
+        const block_q8_K * RESTRICT A = static_cast<const block_q8_K *>(_A);
+        const char * RESTRICT B = static_cast<const char *>(_B);
+
+        // load the 256 bytes from A to 4 avx512 vectors
+        __m512i va[4];
+        __m512 vc[COLS];
+        __m512 vd1;
+
+        // packed_B:
+        const int offset_scales = (QK_K / 2) * TILE_N;
+        const int offset_d0     = (QK_K / 2) * TILE_N + 8 * TILE_N;
+
+        // compensation
+        __m512i vcomp;
+
+        const __m256i m128s = _mm256_set1_epi16(128);
+        const __m512i lowMask = _mm512_set1_epi8(0xF);
+
+        const __m512i values128 = _mm512_set_epi8(
+            113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
+            113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
+            113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127,
+            113, 89, 69, 53, 38, 25, 13, 1, -10, -22, -35, -49, -65, -83, -104, -127
+        );
+        const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
+        const __m512i values256 = _mm512_add_epi8(values128, off);
+
+        auto loadc = [&](int col) {
+            vc[col] = _mm512_setzero_ps();
+        };
+        Unroll<COLS>{}(loadc);
+
+        auto compute = [&](int col, int i) {
+            if (col == 0) {
+                // load a
+                va[0] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs +   0));
+                va[1] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs +  64));
+                va[2] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 128));
+                va[3] = _mm512_loadu_si512((const __m512i *)(A[0 * KB + i].qs + 192));
+
+                // compensation: 128 * A
+                const __m256i q8sums = _mm256_loadu_si256((const __m256i *)A[0 * KB + i].bsums);
+                vcomp = _mm512_castsi256_si512(_mm256_madd_epi16(q8sums, m128s));
+                vd1 = _mm512_set1_ps(A[0 * KB + i].d);
+            }
+
+            // accumulate the quants
+            __m512i acc = _mm512_setzero_si512();
+            const char * b_ptr = B + PACKED_INDEX(col, i, KB, TILE_SIZE);
+            const char * b_qs = b_ptr;
+            int mask = 0;
+            for (int k_group = 0; k_group < QK_K / 32; ++k_group) {
+                int r = k_group >> 1;
+                __m512i vmask = _mm512_set1_epi32(k_group);
+                __m512i vsum = _mm512_setzero_si512();
+                for (int k = 0; k < 8; k += 2) {
+                    __m512i va0 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
+                    __m512i va1 = _mm512_permutexvar_epi32(_mm512_set1_epi32(mask++), va[r]);
+
+                    __m512i bytes = _mm512_loadu_si512(b_qs);
+                    __m512i vb0 = _mm512_shuffle_epi8(values256, _mm512_and_si512(bytes, lowMask));
+                    __m512i vb1 = _mm512_shuffle_epi8(values256, _mm512_and_si512(_mm512_srli_epi16(bytes, 4), lowMask));
+
+                    vsum = _mm512_dpbusd_epi32(vsum, vb0, va0);
+                    vsum = _mm512_dpbusd_epi32(vsum, vb1, va1);
+                    b_qs += 64;
+                }
+                // (B + 128) * A - 128 * A
+                vsum = _mm512_sub_epi32(vsum, _mm512_permutexvar_epi32(vmask, vcomp));
+
+                // vacc += scale * (q8 @ q4)
+                const __m512i vscale = _mm512_cvtepi8_epi32(_mm_loadu_si128((const __m128i *)(b_ptr + offset_scales + k_group * TILE_N)));
+                acc = _mm512_add_epi32(acc, _mm512_mullo_epi32(vsum, vscale));
+            }
+            const __m512 vd0 = _mm512_cvtph_ps(_mm256_loadu_si256((const __m256i *)(b_ptr + offset_d0)));
+            vc[col] = _mm512_fmadd_ps(_mm512_cvtepi32_ps(acc), _mm512_mul_ps(vd0, vd1), vc[col]);
+        };
+
+        for (int i = 0; i < KB; ++i) {
+            Unroll<COLS>{}(compute, i);
+        }
+
+        // store to C
+        auto storec = [&](int col) {
+            _mm512_storeu_ps((__m512i *)(C + 0 * ldc + col * 16), vc[col]);
+        };
+        Unroll<COLS>{}(storec);
+    }
+};
+
+#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE)                                         \
+    tinygemm_kernel_vnni<vec_dot_type, type, float, 1, NB_SIZE, blck_size>::apply( \
+        KB, (const char *)wdata + 0 * row_size_A,                                    \
+        (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE),     \
+        (float *) dst->data + 0 * N + nb_start, ldc)
+
+template <typename TA, typename TB, typename TC,
+          typename std::enable_if<!is_type_qkk<TB>::value, int>::type = 0>
+void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, TC * RESTRICT C, int ldc) {
+    using packed_B_t = packed_B_type<TB>;
+    const int TILE_SIZE = get_tile_size<TB>();
+    const bool need_unpack = do_unpack<TB>::value;
+
+    GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);
+    const TA * RESTRICT A = static_cast<const TA *>(_A);
+    const char * RESTRICT B = static_cast<const char *>(_B);
+
+    const int m0 = std::min(M, TILE_M);
+    const int m1 = std::max(M - TILE_M, 0);
+    const int lda = KB * sizeof(TA);
+    //const int ldb = KB * sizeof(TB);
+
+    static thread_local packed_B_t Tile0[TILE_N * TILE_K];
+    static thread_local packed_B_t Tile1[TILE_N * TILE_K];
+    static thread_local int8_t Tile23[TILE_M * TILE_K];
+
+    static thread_local int32_t TileC0[TILE_M * TILE_N * 4];
+    static thread_local int32_t TileC1[TILE_M * TILE_N * 4];
+
+    // double buffering C to interleave AVX512 and AMX
+    int32_t * C_cur = TileC0;
+    int32_t * C_pre = TileC1;
+
+    auto Tile4 = [&](int32_t * base) { return base; };
+    auto Tile5 = [&](int32_t * base) { return base + TILE_M * TILE_N; };
+    auto Tile6 = [&](int32_t * base) { return base + 2 * TILE_M * TILE_N; };
+    auto Tile7 = [&](int32_t * base) { return base + 3 * TILE_M * TILE_N; };
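+
+    // illustrative tile mapping for the 2x2 output block (as used below):
+    //   TMM0/TMM1 <- two B tiles along N, TMM2/TMM3 <- two A tiles along M,
+    //   TMM4 -> C[0:TILE_M, 0:TILE_N],            TMM6 -> C[0:TILE_M, TILE_N:2*TILE_N],
+    //   TMM5 -> C[TILE_M:2*TILE_M, 0:TILE_N],     TMM7 -> C[TILE_M:2*TILE_M, TILE_N:2*TILE_N].
+    // Tile4..Tile7 just offset into the int32 scratch that receives _tile_stored
+    // results; C_cur/C_pre alternate so the AVX512 accumulate of iteration i-1
+    // overlaps with the AMX compute of iteration i.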
+
+    if (M == 2 * TILE_M) {
+        // i = 0
+        const char * B_blk0 = B + PACKED_INDEX(0, 0, KB, TILE_SIZE);
+        const char * B_blk1 = B + PACKED_INDEX(1, 0, KB, TILE_SIZE);
+        if (need_unpack) {
+            unpack_B<TB>(Tile0, B_blk0);
+            _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
+        } else {
+            _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
+        }
+
+        _tile_zero(TMM4);
+        _tile_loadd(TMM2, A[0].qs, lda);
+        _tile_dpbssd(TMM4, TMM2, TMM0);
+        _tile_stored(TMM4, Tile4(C_pre), TILE_N * sizeof(int32_t));
+
+        _tile_zero(TMM5);
+        _tile_loadd(TMM3, A[TILE_M * KB + 0].qs, lda);
+        _tile_dpbssd(TMM5, TMM3, TMM0);
+        _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t));
+
+        if (need_unpack) {
+            unpack_B<TB>(Tile1, B_blk1);
+            _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
+        } else {
+            _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
+        }
+
+        _tile_zero(TMM6);
+        _tile_dpbssd(TMM6, TMM2, TMM1);
+        _tile_stored(TMM6, Tile6(C_pre), TILE_N * sizeof(int32_t));
+
+        _tile_zero(TMM7);
+        _tile_dpbssd(TMM7, TMM3, TMM1);
+        _tile_stored(TMM7, Tile7(C_pre), TILE_N * sizeof(int32_t));
+
+        for (int i = 1; i < KB; ++i) {
+            // index of previous iter
+            const int ii = i - 1;
+            const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
+            const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
+            GGML_DISPATCH_BOOL(ii > 0, is_acc, [&] {
+                if (need_unpack) {
+                    unpack_B<TB>(Tile0, B_blk0);
+                    _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
+                } else {
+                    _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
+                }
+                _tile_zero(TMM4);
+                _tile_loadd(TMM2, A[i].qs, lda);
+                acc_C<TA, TB, is_acc>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
+
+                _tile_dpbssd(TMM4, TMM2, TMM0);
+                _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));
+
+                _tile_zero(TMM5);
+                _tile_loadd(TMM3, A[TILE_M * KB + i].qs, lda);
+                acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
+
+                _tile_dpbssd(TMM5, TMM3, TMM0);
+                _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));
+
+                if (need_unpack) {
+                    unpack_B<TB>(Tile1, B_blk1);
+                    _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
+                } else {
+                    _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
+                }
+                _tile_zero(TMM6);
+                acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
+
+                _tile_dpbssd(TMM6, TMM2, TMM1);
+                _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));
+
+                _tile_zero(TMM7);
+                acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
+
+                _tile_dpbssd(TMM7, TMM3, TMM1);
+                _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));
+
+                std::swap(C_cur, C_pre);
+            });
+        }
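+        // the loop above applies the accumulation of iteration i-1 while the AMX
+        // units compute iteration i; the last iteration's tiles are still sitting
+        // in C_pre after the final swap, so they are folded into C below.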
+        // final accumulation
+        {
+            int ii = KB - 1;
+            acc_C<TA, TB, true>::apply(C, ldc, Tile4(C_pre), &A[ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
+            acc_C<TA, TB, true>::apply(C + TILE_M * ldc, ldc, Tile5(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(0, ii, KB, TILE_SIZE), TILE_M);
+            acc_C<TA, TB, true>::apply(C + TILE_N, ldc, Tile6(C_pre), &A[ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
+            acc_C<TA, TB, true>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_pre), &A[TILE_M * KB + ii], KB, B + PACKED_INDEX(1, ii, KB, TILE_SIZE), TILE_M);
+        }
+    } else {
+        for (int i = 0; i < KB; ++i) {
+            _tile_zero(TMM4);
+            _tile_zero(TMM6);
+            if (m1 != 0) {
+                _tile_zero(TMM5);
+                _tile_zero(TMM7);
+            }
+
+            const char * B_blk0 = B + PACKED_INDEX(0, i, KB, TILE_SIZE);
+            const char * B_blk1 = B + PACKED_INDEX(1, i, KB, TILE_SIZE);
+            if (need_unpack) {
+                unpack_B<TB>(Tile0, B_blk0);
+                _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
+            } else {
+                _tile_loadd(TMM0, B_blk0, TILE_N * VNNI_BLK);
+            }
+
+            if (need_unpack) {
+                unpack_B<TB>(Tile1, B_blk1);
+                _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
+            } else {
+                _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK);
+            }
+
+            if (m0 == TILE_M) {
+                _tile_loadd(TMM2, A[i].qs, lda);
+            } else {
+                unpack_A(Tile23, &A[i], KB, m0);
+                _tile_loadd(TMM2, Tile23, TILE_K);
+            }
+
+            _tile_dpbssd(TMM4, TMM2, TMM0);
+            _tile_dpbssd(TMM6, TMM2, TMM1);
+
+            _tile_stored(TMM4, Tile4(C_cur), TILE_N * sizeof(int32_t));
+            _tile_stored(TMM6, Tile6(C_cur), TILE_N * sizeof(int32_t));
+
+            GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
+                acc_C<TA, TB, is_acc>::apply(C,          ldc, Tile4(C_cur), &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);
+                acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Tile6(C_cur), &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);
+            });
+
+            if (m1 != 0) {
+                unpack_A(Tile23, &A[TILE_M * KB + i], KB, m1);
+                _tile_loadd(TMM3, Tile23, TILE_K);
+
+                _tile_dpbssd(TMM5, TMM3, TMM0);
+                _tile_dpbssd(TMM7, TMM3, TMM1);
+                _tile_stored(TMM5, Tile5(C_cur), TILE_N * sizeof(int32_t));
+                _tile_stored(TMM7, Tile7(C_cur), TILE_N * sizeof(int32_t));
+                GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
+                    acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc,          ldc, Tile5(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);
+                    acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Tile7(C_cur), &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);
+                });
+            }
+        }
+    }
+    return;
+}
+
+template <typename TA, typename TB, typename TC, typename std::enable_if<is_type_qkk<TB>::value, int>::type = 0>
+void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const void * RESTRICT _B, float * RESTRICT C, int ldc) {
+    static_assert(std::is_same<TA, block_q8_K>::value);
+    const int TILE_SIZE = get_tile_size<TB>();
+
+    GGML_ASSERT(M <= 2 * TILE_M && N == 2 * TILE_N);
+    const TA * RESTRICT A = static_cast<const TA *>(_A);
+    const char * RESTRICT B = static_cast<const char *>(_B);
+
+    const int m0 = std::min(M, TILE_M);
+    const int m1 = std::max(M - TILE_M, 0);
+    //const int lda = KB * sizeof(TA);
+
+    static thread_local int8_t Tile0[TILE_N * TILE_K];
+    static thread_local int8_t Tile1[TILE_N * TILE_K];
+    static thread_local int8_t Tile23[TILE_M * TILE_K];
+
+    // mat mul result for each group
+    static thread_local int32_t Tile4[TILE_M * TILE_N];
+    static thread_local int32_t Tile5[TILE_M * TILE_N];
+    static thread_local int32_t Tile6[TILE_M * TILE_N];
+    static thread_local int32_t Tile7[TILE_M * TILE_N];
+
+    // sum of each QK_K block, contains 8 groups, int32
+    static thread_local int32_t Sumi4[TILE_M * TILE_N];
+    static thread_local int32_t Sumi5[TILE_M * TILE_N];
+    static thread_local int32_t Sumi6[TILE_M * TILE_N];
+    static thread_local int32_t Sumi7[TILE_M * TILE_N];
+
+    const int k_group_size = std::is_same<TB, block_q6_K>::value ? 16 : 32;
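+
+    // illustrative: a QK_K (= 256) super-block stores one scale per group;
+    // q4_K/q5_K use 8 groups of 32 values while q6_K uses 16 groups of 16,
+    // hence the type-dependent k_group_size above.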
+    for (int i = 0; i < KB; ++i) {
+        // step 1: accumulate the quants across the groups of the block (QK_K / k_group_size groups, each with k_group_size values)
+        for (int k = 0; k < QK_K / k_group_size; ++k) {
+            GGML_DISPATCH_BOOL(k > 0, is_acc, [&] {
+                _tile_zero(TMM4);
+                _tile_zero(TMM6);
+
+                unpack_B<TB>(Tile0, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k);
+                _tile_loadd(TMM0, Tile0, TILE_N * VNNI_BLK);
+
+                unpack_B<TB>(Tile1, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k);
+                _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK);
+
+                unpack_A(Tile23, &A[i], KB, k, m0);
+                _tile_loadd(TMM2, Tile23, TILE_K);
+
+                _tile_dpbssd(TMM4, TMM2, TMM0);
+                _tile_dpbssd(TMM6, TMM2, TMM1);
+
+                _tile_stored(TMM4, Tile4, TILE_N * sizeof(int32_t));
+                _tile_stored(TMM6, Tile6, TILE_N * sizeof(int32_t));
+
+                scale_C<TB, is_acc>(Tile4, Sumi4, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m0);
+                scale_C<TB, is_acc>(Tile6, Sumi6, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m0);
+
+                if (m1 != 0) {
+                    _tile_zero(TMM5);
+                    _tile_zero(TMM7);
+
+                    unpack_A(Tile23, &A[TILE_M * KB + i], KB, k, m1);
+                    _tile_loadd(TMM3, Tile23, TILE_K);
+
+                    _tile_dpbssd(TMM5, TMM3, TMM0);
+                    _tile_dpbssd(TMM7, TMM3, TMM1);
+
+                    _tile_stored(TMM5, Tile5, TILE_N * sizeof(int32_t));
+                    _tile_stored(TMM7, Tile7, TILE_N * sizeof(int32_t));
+
+                    scale_C<TB, is_acc>(Tile5, Sumi5, B + PACKED_INDEX(0, i, KB, TILE_SIZE), k, m1);
+                    scale_C<TB, is_acc>(Tile7, Sumi7, B + PACKED_INDEX(1, i, KB, TILE_SIZE), k, m1);
+                }
+            });
+        }
+
+        // step 2: accumulate the mins
+        GGML_DISPATCH_BOOL(i > 0, is_acc, [&] {
+            acc_C<TA, TB, is_acc>::apply(C,          ldc, Sumi4, &A[i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m0);
+            acc_C<TA, TB, is_acc>::apply(C + TILE_N, ldc, Sumi6, &A[i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m0);
+            if (m1 != 0) {
+                acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc,          ldc, Sumi5, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(0, i, KB, TILE_SIZE), m1);
+                acc_C<TA, TB, is_acc>::apply(C + TILE_M * ldc + TILE_N, ldc, Sumi7, &A[TILE_M * KB + i], KB, B + PACKED_INDEX(1, i, KB, TILE_SIZE), m1);
+            }
+        });
+    }
+    return;
+}
+
+} // anonymous namespace
+
+// get the packed tensor size for quantized weights
+size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor) {
+    const enum ggml_type TYPE = tensor->type;
+
+    const int K = tensor->ne[0]; // ne0: in_features
+    const int N = tensor->ne[1]; // ne1: out_features
+
+    auto get_tensor_size = [&] {
+        size_t row_size_B{0};
+        GGML_DISPATCH_QTYPES(TYPE, [&] {
+            row_size_B = get_row_size<type, blck_size>(K);
+        });
+        return N * row_size_B;
+    };
+
+    if (qtype_has_amx_kernels(TYPE)) {
+        return get_tensor_size();
+    } else {
+        // for f16, bf16 we don't do packing
+        return ggml_nbytes(tensor);
+    }
+}
+
+// pack weight to vnni format
+void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
+
+    size_t alloc_size = ggml_backend_amx_get_alloc_size(tensor);
+    GGML_ASSERT(alloc_size == size);
+
+    const enum ggml_type TYPE = tensor->type;
+
+    const int K = tensor->ne[0]; // ne0: in_features
+    const int N = tensor->ne[1]; // ne1: out_features
+
+#if defined(_OPENMP)
+    // the buffer ctx is not initialized when .set_tensor is called
+    int n_threads = omp_get_num_threads();
+#else
+    int n_threads = 1;
+#endif
+
+    GGML_DISPATCH_QTYPES(TYPE, [&] {
+        convert_B_packed_format<type, blck_size>((void *)((char *)tensor->data + offset), (const type *)data, N, K, n_threads);
+    });
+}
+
+// NB: mixed dtype gemm with Advanced Matrix Extensions (Intel AMX)
+//
+// src0: weight in shape of {N, K}, quantized
+// src1: input  in shape of {M, K}, float32
+// dst:  output in shape of {M, N}, float32
+//
+// the function performs: dst = src1 @ src0.T
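+//
+// for example (illustrative): a 4096x4096 q4_K weight with a batch of 32 tokens gives
+// src0 = {4096, 4096}, src1 = {32, 4096} and dst = {32, 4096}; rows of src1 are first
+// quantized to the matching vec dot type (q8_K for the K-quants) into the work buffer
+// before the tiled kernels run.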
+//
+void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst) {
+    struct ggml_tensor * src0 = dst->src[0];
+    struct ggml_tensor * src1 = dst->src[1];
+
+    const enum ggml_type TYPE = src0->type;
+
+    const int n_threads = ctx->n_threads;
+
+    // f16 only has avx512 kernels for now;
+    // amx kernels will be added once the 6th gen xeon is released.
+    const bool is_floating_type = TYPE == GGML_TYPE_F16;
+
+    const int M = dst->ne[1];
+    const int N = dst->ne[0];
+    const int K = src0->ne[0];
+    const int ldc = dst->nb[1] / dst->nb[0];
+
+    if (is_floating_type) {
+        constexpr int BLOCK_M = 4;
+        constexpr int BLOCK_N = 6;
+        const int MB = div_up(M, BLOCK_M);
+        const int NB = div_up(N, BLOCK_N);
+
+        parallel_for(n_threads, MB * NB, [&](int begin, int end) {
+            GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] {
+                for (int i = begin; i < end; ++i) {
+                    int mb = i / NB;
+                    int nb = i % NB;
+
+                    int mb_start = mb * BLOCK_M;
+                    int mb_size = std::min(BLOCK_M, M - mb_start);
+                    int nb_start = nb * BLOCK_N;
+                    int nb_size = std::min(BLOCK_N, N - nb_start);
+
+                    switch (mb_size << 4 | nb_size) {
+                        case 0x12: LAUNCH_TINYGEMM_KERNEL_AVX(1, 2); break;
+                        case 0x14: LAUNCH_TINYGEMM_KERNEL_AVX(1, 4); break;
+                        case 0x16: LAUNCH_TINYGEMM_KERNEL_AVX(1, 6); break;
+                        case 0x22: LAUNCH_TINYGEMM_KERNEL_AVX(2, 2); break;
+                        case 0x24: LAUNCH_TINYGEMM_KERNEL_AVX(2, 4); break;
+                        case 0x26: LAUNCH_TINYGEMM_KERNEL_AVX(2, 6); break;
+                        case 0x32: LAUNCH_TINYGEMM_KERNEL_AVX(3, 2); break;
+                        case 0x34: LAUNCH_TINYGEMM_KERNEL_AVX(3, 4); break;
+                        case 0x36: LAUNCH_TINYGEMM_KERNEL_AVX(3, 6); break;
+                        case 0x42: LAUNCH_TINYGEMM_KERNEL_AVX(4, 2); break;
+                        case 0x44: LAUNCH_TINYGEMM_KERNEL_AVX(4, 4); break;
+                        case 0x46: LAUNCH_TINYGEMM_KERNEL_AVX(4, 6); break;
+                        default: fprintf(stderr, "Unexpected block size!\n");
+                    }
+                }
+            });
+        });
+        return;
+    }
+
+    // pointer to work space, used to convert A from float to quantized type
+    void * wdata = nullptr;
+
+    // TODO: performance improvement: merge quant A
+    GGML_DISPATCH_QTYPES(TYPE, [&] {
+        const size_t row_size_A = K / blck_size * sizeof(vec_dot_type);
+        const size_t desired_wsize = M * row_size_A;
+        if (ctx->work_size < desired_wsize) {
+            ctx->work_data.reset(new char[desired_wsize]);
+            ctx->work_size = desired_wsize;
+        }
+        wdata = ctx->work_data.get();
+
+        // Q4_0, Q4_1, Q8_0 have one TILE_K per block (blck_size == TILE_K)
+        // Q4_K, Q5_K, Q6_K, IQ4_XS have 8 TILE_K per block (blck_size == 8 * TILE_K)
+        GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size);
+
+        const float * A_data = static_cast<const float *>(src1->data);
+        for (int m = 0; m < M; ++m) {
+            from_float(A_data + m * K, (char *)wdata + m * row_size_A, K);
+        }
+    });
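+
+    // illustrative: for q8_K (QK_K = 256) a row of K = 4096 floats packs into
+    // 4096 / 256 = 16 blocks, i.e. row_size_A = 16 * sizeof(block_q8_K).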
+
+    if (M == 1) {
+        // MB = 1 and handle 8 tiles in each block
+        constexpr int kTilesN = 8;
+        constexpr int BLOCK_N = TILE_N * kTilesN;
+        const int NB = div_up(N, BLOCK_N);
+
+        parallel_for(n_threads, NB, [&](int begin, int end) {
+            GGML_DISPATCH_QTYPES(TYPE, [&] {
+                const int KB = K / blck_size;
+                const int TILE_SIZE = get_tile_size<type>();
+                const int row_size_A = KB * sizeof(vec_dot_type);
+                for (int i = begin; i < end; ++i) {
+                    int nb = i;
+                    int nb_start = nb * BLOCK_N;
+                    int nb_size = std::min(BLOCK_N, N - nb_start); // 128, or a 32/64/96 tail
+
+                    switch (nb_size) {
+                        //case 160: LAUNCH_TINYGEMM_KERNEL_VNNI(160); break;
+                        case 128: LAUNCH_TINYGEMM_KERNEL_VNNI(128); break;
+                        case 96: LAUNCH_TINYGEMM_KERNEL_VNNI(96); break;
+                        case 64: LAUNCH_TINYGEMM_KERNEL_VNNI(64); break;
+                        case 32: LAUNCH_TINYGEMM_KERNEL_VNNI(32); break;
+                        default: fprintf(stderr, "Unexpected n block size!\n");
+                    }
+                }
+            });
+        });
+        return;
+    }
+
+    // handle 4 tiles at a time
+    constexpr int BLOCK_M = TILE_M * 2;
+    constexpr int BLOCK_N = TILE_N * 2;
+    const int MB = div_up(M, BLOCK_M);
+    const int NB = div_up(N, BLOCK_N);
+
+    parallel_for(n_threads, MB * NB, [&](int begin, int end) {
+        // init tile config for each thread
+        ggml_tile_config_init();
+
+        GGML_DISPATCH_QTYPES(TYPE, [&] {
+            const int KB = K / blck_size;
+            const int TILE_SIZE = get_tile_size<type>();
+            const int row_size_A = KB * sizeof(vec_dot_type);
+
+            for (int i = begin; i < end; ++i) {
+                int mb = i / NB;
+                int nb = i % NB;
+
+                int mb_start = mb * BLOCK_M;
+                int mb_size = std::min(BLOCK_M, M - mb_start);
+                int nb_start = nb * BLOCK_N;
+                int nb_size = BLOCK_N;
+
+                tinygemm_kernel_amx<vec_dot_type, type, float>(
+                    mb_size, nb_size, KB,
+                    (const char *)wdata + mb_start * row_size_A,
+                    (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE),
+                    (float *) dst->data + mb_start * N + nb_start, ldc);
+            }
+        });
+    });
+}
+
+#else // if defined(__AMX_INT8__)
+
+void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst) {
+    fprintf(stderr, "GGML is not compiled with AMX support!\n");
+
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dst);
+}
+
+#endif // if defined(__AMX_INT8__)
diff --git a/ggml/src/ggml-amx/mmq.h b/ggml/src/ggml-amx/mmq.h
new file mode 100644
index 00000000000..cf092062063
--- /dev/null
+++ b/ggml/src/ggml-amx/mmq.h
@@ -0,0 +1,17 @@
+#pragma once
+#include "common.h"
+#include <stddef.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+size_t ggml_backend_amx_get_alloc_size(const struct ggml_tensor * tensor);
+
+void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
+
+void ggml_backend_amx_mul_mat(ggml_backend_amx_context * ctx, struct ggml_tensor * dst);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index ba2e26999df..fd3deae0097 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -88,6 +88,7 @@ extern "C" {
 
         void (*free)(ggml_backend_t backend);
 
+        // Will be moved to the device interface
         // buffer allocation
         ggml_backend_buffer_type_t (*get_default_buffer_type)(ggml_backend_t backend);
 
@@ -112,17 +113,9 @@ extern "C" {
 
         // IMPORTANT: these functions have been moved to the device interface and will be removed from the backend interface
         //            new backends should implement the device interface instead
-
         // These functions are being moved to the device interface
-        // check if the backend can compute an operation
         bool (*supports_op)  (ggml_backend_t backend, const struct ggml_tensor * op);
-
-        // check if the backend can use tensors allocated in a buffer type
         bool (*supports_buft)(ggml_backend_t backend, ggml_backend_buffer_type_t buft);
-
-        // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer
-        // these should be expensive operations with large batch sizes that may benefit from running on this backend
-        // even if the weight has to be copied from the CPU temporarily
         bool (*offload_op)   (ggml_backend_t backend, const struct ggml_tensor * op);
 
         // (optional) event synchronization
@@ -184,9 +177,8 @@ extern "C" {
         // check if the backend can use tensors allocated in a buffer type
         bool (*supports_buft)(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft);
 
-        // check if the backend wants to run an operation, even if the weights are allocated in a CPU buffer
-        // these should be expensive operations with large batch sizes that may benefit from running on this backend
-        // even if the weight has to be copied from the CPU temporarily
+        // (optional) check if the backend wants to run an operation, even if the weights are allocated in an incompatible buffer
+        // these should be expensive operations that may benefit from running on this backend instead of the CPU backend
         bool (*offload_op)(ggml_backend_dev_t dev, const struct ggml_tensor * op);
 
         // (optional) event synchronization
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 0551764fe3f..7d7b63a15a1 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -329,7 +329,6 @@ bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type
     if (backend->device) {
         return ggml_backend_dev_supports_buft(backend->device, buft);
     }
-
     return backend->iface.supports_buft(backend, buft);
 }
 
@@ -379,7 +378,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
         ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
     } else if (!ggml_backend_buffer_copy_tensor(src, dst)) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer));
+        GGML_LOG_DEBUG("%s: warning: slow copy from %s to %s\n", __func__, ggml_backend_buffer_name(src->buffer), ggml_backend_buffer_name(dst->buffer));
 #endif
         size_t nbytes = ggml_nbytes(src);
         void * data = malloc(nbytes);
@@ -463,6 +462,7 @@ enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
 }
 
 void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props) {
+    memset(props, 0, sizeof(*props));
     device->iface.get_props(device, props);
 }
 
@@ -479,6 +479,10 @@ ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t devic
 }
 
 ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {
+    if (device->iface.get_host_buffer_type == NULL) {
+        return NULL;
+    }
+
     return device->iface.get_host_buffer_type(device);
 }
 
@@ -495,7 +499,11 @@ bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buff
 }
 
 bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
-    return device->iface.offload_op(device, op);
+    if (device->iface.offload_op != NULL) {
+        return device->iface.offload_op(device, op);
+    }
+
+    return false;
 }
 
 // Backend (reg)
@@ -525,6 +533,38 @@ void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * na
 #include "ggml-cuda.h"
 #endif
 
+#ifdef GGML_USE_METAL
+#include "ggml-metal.h"
+#endif
+
+#ifdef GGML_USE_SYCL
+#include "ggml-sycl.h"
+#endif
+
+#ifdef GGML_USE_VULKAN
+#include "ggml-vulkan.h"
+#endif
+
+#ifdef GGML_USE_BLAS
+#include "ggml-blas.h"
+#endif
+
+#ifdef GGML_USE_RPC
+#include "ggml-rpc.h"
+#endif
+
+#ifndef __AMX_INT8__
+#undef GGML_USE_AMX
+#endif
+
+#ifdef GGML_USE_AMX
+#  include "ggml-amx.h"
+#endif
+
+#ifdef GGML_USE_CANN
+#include "ggml-cann.h"
+#endif
+
 struct ggml_backend_registry {
    std::vector<ggml_backend_reg_t> backends;
    std::vector<ggml_backend_dev_t> devices;
@@ -533,15 +573,36 @@ struct ggml_backend_registry {
 #ifdef GGML_USE_CUDA
         register_backend(ggml_backend_cuda_reg());
 #endif
+#ifdef GGML_USE_METAL
+        register_backend(ggml_backend_metal_reg());
+#endif
+#ifdef GGML_USE_SYCL
+        register_backend(ggml_backend_sycl_reg());
+#endif
+#ifdef GGML_USE_VULKAN
+        register_backend(ggml_backend_vk_reg());
+#endif
+#ifdef GGML_USE_BLAS
+        register_backend(ggml_backend_blas_reg());
+#endif
+#ifdef GGML_USE_RPC
+        register_backend(ggml_backend_rpc_reg());
+#endif
+#ifdef GGML_USE_AMX
+        register_backend(ggml_backend_amx_reg());
+#endif
+#ifdef GGML_USE_CANN
+        register_backend(ggml_backend_cann_reg());
+#endif
 
-        register_backend(ggml_backend_cpu_reg());
+        // TODO: kompute
 
-        // TODO: sycl, metal, vulkan, kompute, cann
+        register_backend(ggml_backend_cpu_reg());
     }
 
     void register_backend(ggml_backend_reg_t reg) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: registered backend %s (%zu devices)\n",
+        GGML_LOG_DEBUG("%s: registered backend %s (%zu devices)\n",
             __func__, ggml_backend_reg_name(reg), ggml_backend_reg_dev_count(reg));
 #endif
         backends.push_back(reg);
@@ -552,7 +613,7 @@ struct ggml_backend_registry {
 
     void register_device(ggml_backend_dev_t device) {
 #ifndef NDEBUG
-        fprintf(stderr, "%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
+        GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device));
 #endif
         devices.push_back(device);
     }
@@ -652,8 +713,6 @@ ggml_backend_t ggml_backend_init_best(void) {
 
 // backend CPU
 
-static const size_t TENSOR_ALIGNMENT = 32; // required for mmap as gguf only guarantees 32-byte alignment
-
 static const char * ggml_backend_cpu_buffer_get_name(ggml_backend_buffer_t buffer) {
     return "CPU";
 
@@ -672,7 +731,7 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
 }
 
 static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    free(buffer->context);
+    ggml_aligned_free(buffer->context, buffer->size);
 }
 
 static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
@@ -740,14 +799,19 @@ static const char * ggml_backend_cpu_buffer_type_get_name(ggml_backend_buffer_ty
 }
 
 static ggml_backend_buffer_t ggml_backend_cpu_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    size += TENSOR_ALIGNMENT;   // malloc may return an address that is not aligned
-    void * data = malloc(size); // TODO: use GGML_ALIGNED_MALLOC (move to ggml-impl.h)
+    auto alloc_size = size;
+    if (alloc_size == 0) {
+        alloc_size = 1;
+    }
+
+    void * data = ggml_aligned_malloc(alloc_size);
+
     if (data == NULL) {
-        fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
+        GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, alloc_size);
         return NULL;
     }
 
-    return ggml_backend_buffer_init(buft, ggml_backend_cpu_buffer_i, data, size);
+    return ggml_backend_buffer_init(buft, ggml_backend_cpu_buffer_i, data, alloc_size);
 }
 
 static size_t ggml_backend_cpu_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
@@ -806,7 +870,7 @@ static ggml_backend_buffer_t ggml_backend_cpu_hbm_buffer_type_alloc_buffer(ggml_
     void * ptr;
     int result = hbw_posix_memalign(&ptr, ggml_backend_cpu_buffer_type_get_alignment(buft), size);
     if (result != 0) {
-        fprintf(stderr, "failed to allocate HBM buffer of size %zu\n", size);
+        GGML_LOG_ERROR("failed to allocate HBM buffer of size %zu\n", size);
         return NULL;
     }
 
@@ -1118,9 +1182,10 @@ static void ggml_backend_cpu_device_get_props(ggml_backend_dev_t dev, struct ggm
     props->type        = ggml_backend_cpu_device_get_type(dev);
     ggml_backend_cpu_device_get_memory(dev, &props->memory_free, &props->memory_total);
     props->caps = {
-        /* async       */ false,
-        /* host_buffer */ false,
-        /* events      */ false,
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
     };
 }
 
@@ -1153,7 +1218,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
                 op->type != GGML_TYPE_IQ1_S   &&
                 op->type != GGML_TYPE_IQ1_M; // missing type_traits.from_float
         case GGML_OP_MUL_MAT:
-            return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_internal_get_type_traits(op->src[0]->type).vec_dot_type;
+            return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_get_type_traits(op->src[0]->type)->vec_dot_type;
         case GGML_OP_ROPE_BACK:
             return op->src[2] == NULL && (op->op_params[2] & 4) == 0;
         case GGML_OP_IM2COL_BACK:
@@ -1216,16 +1281,22 @@ static ggml_backend_dev_t ggml_backend_cpu_reg_get_device(ggml_backend_reg_t reg
     };
 
     return &ggml_backend_cpu_device;
+}
+
+static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_cpu_set_n_threads;
+    }
+    return NULL;
 
     GGML_UNUSED(reg);
-    GGML_UNUSED(index);
 }
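+
+// illustrative use of the new get_proc_address hook (assumed call pattern, not part
+// of this patch): callers can look up the thread-count setter without linking the
+// CPU backend's symbols directly:
+//
+//   auto * fn = (void (*)(ggml_backend_t, int))
+//       ggml_backend_reg_get_proc_address(ggml_backend_cpu_reg(), "ggml_backend_set_n_threads");
+//   if (fn) { fn(backend, 8); }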
 
 static const struct ggml_backend_reg_i ggml_backend_cpu_reg_i = {
     /* .get_name         = */ ggml_backend_cpu_reg_get_name,
     /* .get_device_count = */ ggml_backend_cpu_reg_get_device_count,
     /* .get_device       = */ ggml_backend_cpu_reg_get_device,
-    /* .get_proc_address = */ NULL,
+    /* .get_proc_address = */ ggml_backend_cpu_get_proc_address,
 };
 
 ggml_backend_reg_t ggml_backend_cpu_reg(void) {
@@ -1422,7 +1493,7 @@ static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, co
     }
 
 #ifndef NDEBUG
-    fprintf(stderr, "%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n",
+    GGML_LOG_DEBUG("%s: warning: no backend supports op %s with a weight with buffer type %s used in tensor %s, the weight will need to be copied\n",
         __func__, ggml_op_desc(tensor), ggml_backend_buffer_name(buffer), tensor->name);
 #endif
 
@@ -1511,13 +1582,13 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
     for (int i = 0; i < graph->n_nodes; i++) {
         if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
             ggml_backend_t split_backend = sched->backends[sched->splits[cur_split].backend_id];
-            fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
+            GGML_LOG_DEBUG("\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend),
                 sched->splits[cur_split].n_inputs);
             for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
-                fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
+                GGML_LOG_DEBUG("[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name,
                     fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
             }
-            fprintf(stderr, "\n");
+            GGML_LOG_DEBUG("\n");
             cur_split++;
         }
         struct ggml_tensor * node = graph->nodes[i];
@@ -1525,7 +1596,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
             continue;
         }
         ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
-        fprintf(stderr, "node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
+        GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
             fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
@@ -1533,10 +1604,10 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
                 continue;
             }
             ggml_backend_t src_backend = ggml_backend_sched_get_tensor_backend(sched, src);
-            fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
+            GGML_LOG_DEBUG(" %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
                 fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", GET_CAUSE(src));
         }
-        fprintf(stderr, "\n");
+        GGML_LOG_DEBUG("\n");
     }
 }
 
@@ -2050,11 +2121,11 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
         // the re-allocation may cause the split inputs to be moved to a different address
         ggml_backend_sched_synchronize(sched);
 #ifndef NDEBUG
-        fprintf(stderr, "%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
+        GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
 #endif
         ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
         if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
-            fprintf(stderr, "%s: failed to allocate graph\n", __func__);
+            GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__);
             return false;
         }
     }
@@ -2197,6 +2268,7 @@ ggml_backend_sched_t ggml_backend_sched_new(
         sched->backends[b] = backends[b];
         sched->bufts[b] = bufts ? bufts[b] : ggml_backend_get_default_buffer_type(backends[b]);
         GGML_ASSERT(ggml_backend_supports_buft(backends[b], sched->bufts[b]));
+
         if (sched->n_copies > 1) {
             for (int c = 0; c < sched->n_copies; c++) {
                 sched->events[b][c] = ggml_backend_event_new(backends[b]->device);
@@ -2448,7 +2520,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s
     struct ggml_context * ctx_unallocated = ggml_init(params);
 
     if (ctx_allocated == NULL || ctx_unallocated == NULL) {
-        fprintf(stderr, "failed to allocate context for graph copy\n");
+        GGML_LOG_ERROR("%s: failed to allocate context for graph copy\n", __func__);
         ggml_hash_set_free(&hash_set);
         free(node_copies);
         free(node_init);
@@ -2471,7 +2543,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s
     // allocate nodes
     ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx_allocated, backend);
     if (buffer == NULL) {
-        fprintf(stderr, "failed to allocate buffer for graph copy\n");
+        GGML_LOG_ERROR("%s: failed to allocate buffer for graph copy\n", __func__);
         ggml_hash_set_free(&hash_set);
         free(node_copies);
         free(node_init);
diff --git a/ggml/src/ggml-blas.cpp b/ggml/src/ggml-blas.cpp
index b850e6a8ded..7875ec86d08 100644
--- a/ggml/src/ggml-blas.cpp
+++ b/ggml/src/ggml-blas.cpp
@@ -4,6 +4,7 @@
 
#include <future>
#include <vector>
+#include <cstring>
 
 #if defined(GGML_USE_ACCELERATE)
 #   include 
@@ -26,30 +27,6 @@ struct ggml_backend_blas_context {
 #endif
 };
 
-// helper function to determine if it is better to use BLAS or not
-// for large matrices, BLAS is faster
-static bool ggml_backend_blas_use_blas(const struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0 = dst->src[0];
-    const struct ggml_tensor * src1 = dst->src[1];
-
-    const int64_t ne10 = src1->ne[0];
-
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
-
-    // TODO: find the optimal values for these
-    if (ggml_is_contiguous(src0) &&
-        ggml_is_contiguous(src1) &&
-        src1->type == GGML_TYPE_F32 &&
-        (ne0 >= 32 && ne1 >= 32 && ne10 >= 32)) {
-
-        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
-        return true;
-    }
-
-    return false;
-}
-
 static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct ggml_tensor * dst) {
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -88,8 +65,8 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
 
     // convert src0 to float
     if (type != GGML_TYPE_F32) {
-        ggml_type_traits_t type_traits = ggml_internal_get_type_traits(type);
-        ggml_to_float_t const to_float = type_traits.to_float;
+        const auto * type_traits = ggml_get_type_traits(type);
+        ggml_to_float_t const to_float = type_traits->to_float;
 
         for (int64_t i03 = 0; i03 < ne03; i03++) {
             for (int64_t i02 = 0; i02 < ne02; i02++) {
@@ -235,7 +212,7 @@ static void ggml_backend_blas_out_prod(ggml_backend_blas_context * ctx, struct g
 
 // backend interface
 
-static const char * ggml_backend_blas_name(ggml_backend_t backend) {
+static const char * ggml_backend_blas_get_name(ggml_backend_t backend) {
     return "BLAS";
 
     GGML_UNUSED(backend);
@@ -285,29 +262,8 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
     GGML_UNUSED(backend);
 }
 
-static bool ggml_backend_blas_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    const struct ggml_tensor * src0 = op->src[0];
-    const struct ggml_tensor * src1 = op->src[1];
-
-    return (op->op == GGML_OP_MUL_MAT  && ggml_backend_blas_use_blas(op)) ||
-           (op->op == GGML_OP_OUT_PROD && op->src[0]->type == GGML_TYPE_F32 &&
-                                          op->src[1]->type == GGML_TYPE_F32 &&
-                                          ggml_is_matrix(src0) &&
-                                          ggml_is_matrix(src1) &&
-                                          ggml_is_contiguous(src0) &&
-                                          (ggml_is_contiguous(src1) || ggml_is_transposed(src1)));
-
-    GGML_UNUSED(backend);
-}
-
-static bool ggml_backend_blas_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    return ggml_backend_buft_is_host(buft);
-
-    GGML_UNUSED(backend);
-}
-
 static struct ggml_backend_i blas_backend_i = {
-    /* .get_name                = */ ggml_backend_blas_name,
+    /* .get_name                = */ ggml_backend_blas_get_name,
     /* .free                    = */ ggml_backend_blas_free,
     /* .get_default_buffer_type = */ ggml_backend_blas_get_default_buffer_type,
     /* .set_tensor_async        = */ NULL,
@@ -319,8 +275,8 @@ static struct ggml_backend_i blas_backend_i = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_blas_graph_compute,
-    /* .supports_op             = */ ggml_backend_blas_supports_op,
-    /* .supports_buft           = */ ggml_backend_blas_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
     /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
@@ -337,18 +293,18 @@ ggml_backend_t ggml_backend_blas_init(void) {
     ggml_backend_t backend = new ggml_backend {
         /* .guid      = */ ggml_backend_blas_guid(),
         /* .interface = */ blas_backend_i,
-        /* .device    = */ nullptr,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_blas_reg(), 0),
         /* .context   = */ ctx,
     };
 
-#if !defined(NDEBUG) && defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
+#if defined(OPENBLAS_VERSION) && defined(GGML_USE_OPENMP)
     if (openblas_get_parallel() != OPENBLAS_OPENMP) {
-        fprintf(stderr, "%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
+        GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but OpenBLAS was compiled without OpenMP support\n", __func__);
     }
 #endif
 
-#if !defined(NDEBUG) && defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
-    fprintf(stderr, "%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
+#if defined(BLIS_ENABLE_CBLAS) && defined(GGML_USE_OPENMP) && !defined(BLIS_ENABLE_OPENMP)
+    GGML_LOG_DEBUG("%s: warning: ggml is using OpenMP, but BLIS was compiled without OpenMP support\n", __func__);
 #endif
 
     return backend;
@@ -364,3 +320,205 @@ void ggml_backend_blas_set_n_threads(ggml_backend_t backend_blas, int n_threads)
     ggml_backend_blas_context * ctx = (ggml_backend_blas_context *)backend_blas->context;
     ctx->n_threads = n_threads;
 }
+
+// device interface
+
+static const char * ggml_backend_blas_device_get_name(ggml_backend_dev_t dev) {
+    return "BLAS";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_blas_device_get_description(ggml_backend_dev_t dev) {
+    #if defined(GGML_USE_ACCELERATE)
+        return "Accelerate";
+    #elif defined(GGML_BLAS_USE_MKL)
+        return "MKL";
+    #elif defined(GGML_BLAS_USE_BLIS)
+        return "BLIS";
+    #elif defined(GGML_BLAS_USE_NVPL)
+        return "NVPL";
+    #elif defined(OPENBLAS_VERSION)
+        return "OpenBLAS";
+    #else
+        return "BLAS";
+    #endif
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    // TODO
+    *free = 0;
+    *total = 0;
+
+    GGML_UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_blas_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_CPU;
+
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_blas_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_blas_device_get_name(dev);
+    props->description = ggml_backend_blas_device_get_description(dev);
+    props->type        = ggml_backend_blas_device_get_type(dev);
+    ggml_backend_blas_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_blas_device_init(ggml_backend_dev_t dev, const char * params) {
+    return ggml_backend_blas_init();
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_blas_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_cpu_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_blas_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_blas_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    const struct ggml_tensor * src0 = op->src[0];
+    const struct ggml_tensor * src1 = op->src[1];
+
+    switch (op->op) {
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            return true;
+
+        case GGML_OP_MUL_MAT:
+        {
+            // BLAS usually is only faster for large matrices
+            const struct ggml_tensor * src0 = op->src[0];
+            const struct ggml_tensor * src1 = op->src[1];
+
+            const int64_t ne10 = src1->ne[0];
+
+            const int64_t ne0 = op->ne[0];
+            const int64_t ne1 = op->ne[1];
+
+            // TODO: find the optimal value
+            const int64_t min_batch = 32;
+
+            return ggml_is_contiguous(src0) &&
+                   ggml_is_contiguous(src1) &&
+                   src1->type == GGML_TYPE_F32 &&
+                   (ne0 >= min_batch && ne1 >= min_batch && ne10 >= min_batch) &&
+                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
+        }
+
+        case GGML_OP_OUT_PROD:
+            return op->src[0]->type == GGML_TYPE_F32 &&
+                   op->src[1]->type == GGML_TYPE_F32 &&
+                   ggml_is_matrix(src0) &&
+                   ggml_is_matrix(src1) &&
+                   ggml_is_contiguous(src0) &&
+                   (ggml_is_contiguous(src1) || ggml_is_transposed(src1)) &&
+                   (src0->type == GGML_TYPE_F32 || ggml_get_type_traits(src0->type)->to_float != NULL);
+
+        default:
+            return false;
+
+    }
+
+    GGML_UNUSED(dev);
+}
+
+static bool ggml_backend_blas_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return ggml_backend_buft_is_host(buft);
+
+    GGML_UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_blas_device_i = {
+    /* .get_name             = */ ggml_backend_blas_device_get_name,
+    /* .get_description      = */ ggml_backend_blas_device_get_description,
+    /* .get_memory           = */ ggml_backend_blas_device_get_memory,
+    /* .get_type             = */ ggml_backend_blas_device_get_type,
+    /* .get_props            = */ ggml_backend_blas_device_get_props,
+    /* .init_backend         = */ ggml_backend_blas_device_init,
+    /* .get_buffer_type      = */ ggml_backend_blas_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_blas_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_blas_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_blas_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_blas_reg_get_name(ggml_backend_reg_t reg) {
+    return "BLAS";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_blas_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_blas_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    static ggml_backend_device ggml_backend_blas_device = {
+        /* .iface   = */ ggml_backend_blas_device_i,
+        /* .reg     = */ reg,
+        /* .context = */ nullptr,
+    };
+
+    return &ggml_backend_blas_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static void * ggml_backend_blas_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_set_n_threads") == 0) {
+        return (void *)ggml_backend_blas_set_n_threads;
+    }
+    return NULL;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_blas_reg_i = {
+    /* .get_name         = */ ggml_backend_blas_reg_get_name,
+    /* .get_device_count = */ ggml_backend_blas_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_blas_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_blas_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_blas_reg(void) {
+    static struct ggml_backend_reg ggml_backend_blas_reg = {
+        /* .iface   = */ ggml_backend_blas_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_blas_reg;
+}
diff --git a/ggml/src/ggml-cann.cpp b/ggml/src/ggml-cann.cpp
index db5f8f1865d..af0fb603a7c 100644
--- a/ggml/src/ggml-cann.cpp
+++ b/ggml/src/ggml-cann.cpp
@@ -39,6 +39,8 @@
 
 #include "ggml-common.h"
 
+#define GGML_CANN_NAME "CANN"
+
 /**
  * @brief Handles CANN errors by printing an error message and aborting.
  *
@@ -851,13 +853,6 @@ static void ggml_backend_cann_buffer_set_tensor(
         void *transform_buffer = malloc(size);
         ggml_backend_cann_transform(tensor, data, transform_buffer);
 
-#ifndef NDEBUG
-        void *check_buffer = malloc(size);
-        ggml_backend_cann_transform_back(tensor, transform_buffer,
-                                         check_buffer);
-        GGML_ASSERT(memcmp(data, check_buffer, size) == 0);
-        free(check_buffer);
-#endif
         ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size,
                               transform_buffer, size,
                               ACL_MEMCPY_HOST_TO_DEVICE));
@@ -969,7 +964,7 @@ static void ggml_backend_cann_buffer_clear(
  * This structure defines function pointers to operations that can be performed
  * on a CANN buffer within the backend.
  */
-static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
+static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
     /* .get_name        = */ ggml_backend_cann_buffer_get_name,
     /* .free_buffer     = */ ggml_backend_cann_buffer_free_buffer,
     /* .get_base        = */ ggml_backend_cann_buffer_get_base,
@@ -1105,19 +1100,25 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
     GGML_UNUSED(buft);
 }
 
+static bool ggml_backend_cann_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return false;
+
+    GGML_UNUSED(buft);
+}
+
 /**
  * @brief Interface for managing CANN buffer types in the GGML backend.
  *
  * Provides function pointers for allocating, querying properties, and managing
  * memory for CANN buffer types in the GGML backend.
  */
-static ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
+static const ggml_backend_buffer_type_i ggml_backend_cann_buffer_type_interface = {
     /* .get_name         = */ ggml_backend_cann_buffer_type_name,
     /* .alloc_buffer     = */ ggml_backend_cann_buffer_type_alloc_buffer,
     /* .get_alignment    = */ ggml_backend_cann_buffer_type_get_alignment,
     /* .get_max_size     = */ NULL,  // defaults to SIZE_MAX
     /* .get_alloc_size   = */ ggml_backend_cann_buffer_type_get_alloc_size,
-    /* .is_host          = */ NULL,
+    /* .is_host          = */ ggml_backend_cann_buffer_type_is_host,
 };
 
 /**
@@ -1148,6 +1149,7 @@ ggml_backend_cann_buffer_type(int32_t device) {
         for (int32_t i = 0; i < GGML_CANN_MAX_DEVICES; i++) {
             ggml_backend_cann_buffer_types[i] = {
                 /* .iface    = */ ggml_backend_cann_buffer_type_interface,
+                /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
                 /* .context  = */
                  new ggml_backend_cann_buffer_type_context{
                     i, "CANN" + std::to_string(i)},
@@ -1263,7 +1265,7 @@ ggml_backend_buffer_type_t ggml_backend_cann_host_buffer_type() {
             /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
             /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
         },
-        /* .device   = */ nullptr,
+        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), 0),
         /* .context  = */ nullptr,
     };
 
@@ -1510,13 +1512,6 @@ static void ggml_backend_cann_set_tensor_async(ggml_backend_t backend,
         void *transform_buffer = malloc(size);
         ggml_backend_cann_transform(tensor, data, transform_buffer);
 
-#ifndef NDEBUG
-        void *check_buffer = malloc(size);
-        ggml_backend_cann_transform_back(tensor, transform_buffer,
-                                         check_buffer);
-        GGML_ASSERT(memcmp(data, check_buffer, size));
-        free(check_buffer);
-#endif
         ACL_CHECK(aclrtMemcpyAsync(
             (char *)tensor->data + offset, size, transform_buffer, size,
             ACL_MEMCPY_HOST_TO_DEVICE, cann_ctx->stream()));
@@ -1691,7 +1686,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(
  * @return bool Returns true if the operation is supported by the backend,
  *              otherwise false.
  */
-static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
+static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
                                                     const ggml_tensor* op) {
     switch (op->op) {
         case GGML_OP_UNARY:
@@ -1782,7 +1777,7 @@ static bool ggml_backend_cann_supports_op(ggml_backend_t backend,
             return false;
     }
 
-    GGML_UNUSED(backend);
+    GGML_UNUSED(dev);
 }
 
 /**
@@ -1800,31 +1795,6 @@ static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
     return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
 }
 
-/**
- * @brief Checks if the CANN backend supports a specific backend buffer type.
- *
- * This function determines whether the CANN backend supports the given backend
- * buffer type by comparing the device context of the backend and buffer type.
- * It returns true if the devices are same between the backend context and
- * buffer type context.
- *
- * @param backend Pointer to the CANN backend.
- * @param buft Pointer to the backend buffer type to check.
- * @return bool Returns true if the CANN backend supports the buffer type,
- *              otherwise false.
- */
-static bool ggml_backend_cann_supports_buft(
-    ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    if (ggml_backend_buft_is_cann(buft)) {
-        ggml_backend_cann_context * cann_ctx =
-                        (ggml_backend_cann_context *)backend->context;
-        ggml_backend_cann_buffer_type_context * buft_ctx =
-                        (ggml_backend_cann_buffer_type_context *)buft->context;
-        return buft_ctx->device == cann_ctx->device;
-    }
-    return false;
-}
-
 /**
  * @brief Determines if a tensor operation should be offloaded to the CANN
  * backend.
@@ -1839,54 +1809,14 @@ static bool ggml_backend_cann_supports_buft(
  * @return bool Returns true if the operation should be offloaded, otherwise
  * false.
  */
-static bool ggml_backend_cann_offload_op(ggml_backend_t backend,
+static bool ggml_backend_cann_offload_op(ggml_backend_dev_t dev,
                                                    const ggml_tensor* op) {
     const int min_batch_size = 32;
-    GGML_UNUSED(backend);
+    GGML_UNUSED(dev);
 
     return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
 }
 
-/**
- * @brief Creates a new event for the CANN backend.
- *
- * This function initializes a new event for the CANN backend by setting the
- * device and creating an ACL runtime event. The created event is then wrapped
- * in a ggml_backend_event structure and returned.
- *
- * @param backend Pointer to the CANN backend.
- * @return ggml_backend_event_t Returns a pointer to the new event structure.
- */
-static ggml_backend_event_t ggml_backend_cann_event_new(
-    ggml_backend_t backend) {
-    ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)backend->context;
-
-    ggml_cann_set_device(cann_ctx->device);
-
-    aclrtEvent event;
-    ACL_CHECK(aclrtCreateEvent(&event));
-
-    return new ggml_backend_event{
-        /* .backend = */ backend,
-        /* .context = */ event,
-    };
-}
-
-/**
- * @brief Frees a CANN backend event.
- *
- * This function destroys the ACL runtime event associated with the given CANN
- * backend event and then deletes the event structure itself.
- *
- * @param event Pointer to the event structure to be freed.
- */
-static void ggml_backend_cann_event_free(ggml_backend_event_t event) {
-    ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
-
-    delete event;
-}
-
 /**
  * @brief Records an event on the CANN backend stream.
  *
@@ -1895,10 +1825,9 @@ static void ggml_backend_cann_event_free(ggml_backend_event_t event) {
  *
  * @param event Pointer to the event structure to be recorded.
  */
-static void ggml_backend_cann_event_record(ggml_backend_event_t event) {
+static void ggml_backend_cann_event_record(ggml_backend_t backend, ggml_backend_event_t event) {
     ggml_backend_cann_context* cann_ctx =
-        (ggml_backend_cann_context*)event->backend->context;
-
+        (ggml_backend_cann_context*)backend->context;
     ACL_CHECK(aclrtRecordEvent((aclrtEvent)event->context, cann_ctx->stream()));
 }
 
@@ -1916,8 +1845,7 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
                                          ggml_backend_event_t event) {
     ggml_backend_cann_context* cann_ctx =
         (ggml_backend_cann_context*)backend->context;
-
-    if (ggml_backend_is_cann(event->backend)) {
+    if (ggml_backend_is_cann(backend)) {
         ACL_CHECK(aclrtStreamWaitEvent(cann_ctx->stream(),
                                        (aclrtEvent)event->context));
     } else {
@@ -1925,17 +1853,6 @@ static void ggml_backend_cann_event_wait(ggml_backend_t backend,
     }
 }
 
-/**
- * @brief Synchronizes the given event on the CANN backend.
- *
- * This function waits for the specified event to complete on the ACL runtime.
- *
- * @param event Pointer to the event structure to be synchronized.
- */
-static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
-    ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
-}
-
 /**
  * @brief Structure defining the interface for the CANN backend.
  *
@@ -1943,7 +1860,7 @@ static void ggml_backend_cann_event_synchronize(ggml_backend_event_t event) {
  * supported by the CANN backend, including name retrieval, memory
  * management, tensor operations, synchronization, and event handling.
  */
-static ggml_backend_i ggml_backend_cann_interface = {
+static const ggml_backend_i ggml_backend_cann_interface = {
     /* .get_name                = */ ggml_backend_cann_name,
     /* .free                    = */ ggml_backend_cann_free,
     /* .get_default_buffer_type = */ ggml_backend_cann_get_default_buffer_type,
@@ -1956,9 +1873,9 @@ static ggml_backend_i ggml_backend_cann_interface = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_cann_graph_compute,
-    /* .supports_op             = */ ggml_backend_cann_supports_op,
-    /* .supports_buft           = */ ggml_backend_cann_supports_buft,
-    /* .offload_op              = */ ggml_backend_cann_offload_op,
+    /* .supports_op             = */ NULL, // moved to device
+    /* .supports_buft           = */ NULL, // moved to device
+    /* .offload_op              = */ NULL, // moved to device
     /* .event_record            = */ ggml_backend_cann_event_record,
     /* .event_wait              = */ ggml_backend_cann_event_wait,
 };
@@ -1977,6 +1894,234 @@ static ggml_guid_t ggml_backend_cann_guid() {
     return &guid;
 }
 
+// backend device
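+// context stored in each ggml_backend_device, identifying the card it wraps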
+struct ggml_backend_cann_device_context {
+    int device;
+    std::string name;
+    std::string description;
+};
+
+static const char * ggml_backend_cann_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
+    return ctx->name.c_str();
+}
+
+static const char* ggml_backend_cann_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_cann_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
+    ggml_backend_cann_get_device_memory(ctx->device, free, total);
+}
+
+static enum ggml_backend_dev_type ggml_backend_cann_device_get_type(ggml_backend_dev_t dev) {
+    GGML_UNUSED(dev);
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+}
+
+static void ggml_backend_cann_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_cann_device_get_name(dev);
+    props->description = ggml_backend_cann_device_get_description(dev);
+    props->type        = ggml_backend_cann_device_get_type(dev);
+    ggml_backend_cann_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
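+    // pinned host buffers are available unless disabled via the GGML_CANN_NO_PINNED environment variable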
+    bool host_buffer = getenv("GGML_CANN_NO_PINNED") == nullptr;
+
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ host_buffer,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ true,
+    };
+}
+
+static ggml_backend_t ggml_backend_cann_device_init(ggml_backend_dev_t dev, const char * params) {
+    GGML_UNUSED(params);
+    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
+    return ggml_backend_cann_init(ctx->device);
+}
+
+/**
+ * @brief Checks if the CANN backend supports a specific backend buffer type.
+ *
+ * This function determines whether the CANN backend supports the given backend
+ * buffer type by comparing the device of the backend device context with the
+ * device of the buffer type context. It returns true if the two devices are
+ * the same.
+ *
+ * @param dev Pointer to the CANN backend device.
+ * @param buft Pointer to the backend buffer type to check.
+ * @return bool Returns true if the CANN backend supports the buffer type,
+ *              otherwise false.
+ */
+static bool ggml_backend_cann_supports_buft(
+    ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    if (ggml_backend_buft_is_cann(buft)) {
+        ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
+        ggml_backend_cann_buffer_type_context * buft_ctx =
+                        (ggml_backend_cann_buffer_type_context *)buft->context;
+        return buft_ctx->device == dev_ctx->device;
+    }
+    return false;
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cann_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_cann_device_context * ctx = (ggml_backend_cann_device_context *)dev->context;
+    return ggml_backend_cann_buffer_type(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_cann_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+    GGML_UNUSED(dev);
+    return ggml_backend_cann_host_buffer_type();
+}
+
+/**
+ * @brief Creates a new event for the CANN backend device.
+ *
+ * This function initializes a new event for the CANN backend by setting the
+ * device and creating an ACL runtime event. The created event is then wrapped
+ * in a ggml_backend_event structure and returned.
+ *
+ * @param dev Pointer to the CANN backend device.
+ * @return ggml_backend_event_t Returns a pointer to the new event structure.
+ */
+static ggml_backend_event_t ggml_backend_cann_device_event_new(
+    ggml_backend_dev_t dev) {
+    ggml_backend_cann_device_context * dev_ctx = (ggml_backend_cann_device_context *)dev->context;
+
+    ggml_cann_set_device(dev_ctx->device);
+
+    aclrtEvent event;
+    ACL_CHECK(aclrtCreateEvent(&event));
+
+    return new ggml_backend_event{
+        /* .device = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), dev_ctx->device),
+        /* .context = */ event,
+    };
+}
+
+/**
+ * @brief Frees a CANN backend event.
+ *
+ * This function destroys the ACL runtime event associated with the given CANN
+ * backend event and then deletes the event structure itself.
+ *
+ * @param dev Pointer to the CANN backend device (unused).
+ * @param event Pointer to the event structure to be freed.
+ */
+static void ggml_backend_cann_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ACL_CHECK(aclrtDestroyEvent((aclrtEvent)event->context));
+
+    delete event;
+    GGML_UNUSED(dev);
+}
+
+/**
+ * @brief Synchronizes the given event on the CANN backend.
+ *
+ * This function waits for the specified event to complete on the ACL runtime.
+ *
+ * @param dev Pointer to the CANN backend device (unused).
+ * @param event Pointer to the event structure to be synchronized.
+ */
+static void ggml_backend_cann_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) {
+    ACL_CHECK(aclrtSynchronizeEvent((aclrtEvent)event->context));
+
+    GGML_UNUSED(dev);
+}
+
+static const ggml_backend_device_i ggml_backend_cann_device_interface = {
+    /* .get_name                = */ ggml_backend_cann_device_get_name,
+    /* .get_description         = */ ggml_backend_cann_device_get_description,
+    /* .get_memory              = */ ggml_backend_cann_device_get_memory,
+    /* .get_type                = */ ggml_backend_cann_device_get_type,
+    /* .get_props               = */ ggml_backend_cann_device_get_props,
+    /* .init_backend            = */ ggml_backend_cann_device_init,    // called for every card
+    /* .get_buffer_type         = */ ggml_backend_cann_device_get_buffer_type,
+    /* .get_host_buffer_type    = */ ggml_backend_cann_device_get_host_buffer_type,
+    /* .buffer_from_host_ptr    = */ NULL, // not supported for CANN
+    /* .supports_op             = */ ggml_backend_cann_supports_op,
+    /* .supports_buft           = */ ggml_backend_cann_supports_buft,
+    /* .offload_op              = */ ggml_backend_cann_offload_op,
+    /* .event_new               = */ ggml_backend_cann_device_event_new,
+    /* .event_free              = */ ggml_backend_cann_device_event_free,
+    /* .event_synchronize       = */ ggml_backend_cann_device_event_synchronize,
+};
+
+
+// backend reg
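+// registry context: owns the list of enumerated CANN devices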
+struct ggml_backend_cann_reg_context {
+    std::vector<ggml_backend_dev_t> devices;
+};
+
+static const char * ggml_backend_cann_reg_get_name(ggml_backend_reg_t reg) {
+    GGML_UNUSED(reg);
+    return GGML_CANN_NAME;
+}
+
+static size_t ggml_backend_cann_reg_get_device_count(ggml_backend_reg_t reg) {
+    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
+    return ctx->devices.size();
+}
+
+static ggml_backend_dev_t ggml_backend_cann_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    ggml_backend_cann_reg_context * ctx = (ggml_backend_cann_reg_context *)reg->context;
+    GGML_ASSERT(index < ctx->devices.size());
+    return ctx->devices[index];
+}
+
+static void * ggml_backend_cann_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    GGML_UNUSED(reg);
+    GGML_UNUSED(name);
+    // reserved for future use
+    return nullptr;
+}
+
+static const ggml_backend_reg_i ggml_backend_cann_reg_interface = {
+    /* .get_name          = */ ggml_backend_cann_reg_get_name,
+    /* .get_device_count  = */ ggml_backend_cann_reg_get_device_count,
+    /* .get_device        = */ ggml_backend_cann_reg_get_device,
+    /* .get_proc_address  = */ ggml_backend_cann_reg_get_proc_address,
+};
+
+// backend registry, called only once for cann backend
+ggml_backend_reg_t ggml_backend_cann_reg() {
+    static ggml_backend_reg reg;
+    static bool initialized = false;
+
+    {
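+        // one-time initialization, guarded by a mutex so that concurrent callers
+        // do not race on ACL init and device enumeration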
+        static std::mutex mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!initialized) {
+            aclInit(nullptr);
+            ggml_backend_cann_reg_context * ctx = new ggml_backend_cann_reg_context;
+
+            for (int i = 0; i < ggml_cann_info().device_count; i++) {
+                ggml_backend_cann_device_context* dev_ctx = new ggml_backend_cann_device_context();
+                dev_ctx->description = aclrtGetSocName();
+                dev_ctx->device = i;
+                dev_ctx->name = GGML_CANN_NAME + std::to_string(i);
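+                // device names are the backend name plus the device index, e.g. "CANN0"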
+                ggml_cann_set_device(i);
+                ggml_backend_dev_t dev = new ggml_backend_device {
+                    /* .interface = */ ggml_backend_cann_device_interface,
+                    /* .reg       = */ &reg,
+                    /* .context   = */ dev_ctx
+                };
+                ctx->devices.push_back(dev);
+            }
+
+            reg = ggml_backend_reg {
+                /* .interface = */ ggml_backend_cann_reg_interface,
+                /* .context   = */ ctx
+            };
+        }
+
+        initialized = true;
+    }
+
+    return &reg;
+}
+
 ggml_backend_t ggml_backend_cann_init(int32_t device) {
     aclInit(nullptr);
     if (device < 0 || device >= ggml_backend_cann_get_device_count()) {
@@ -1993,7 +2138,7 @@ ggml_backend_t ggml_backend_cann_init(int32_t device) {
     ggml_backend_t cann_backend =
         new ggml_backend{/* .guid      = */ ggml_backend_cann_guid(),
                          /* .interface = */ ggml_backend_cann_interface,
-                         /* .device    = */ nullptr,
+                         /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_cann_reg(), device),
                          /* .context   = */ ctx};
 
     return cann_backend;
diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu
index 5b6f605b008..21c9f5e3829 100644
--- a/ggml/src/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda.cu
@@ -291,7 +291,7 @@ struct ggml_cuda_pool_leg : public ggml_cuda_pool {
                 return;
             }
         }
-        GGML_LOG_WARN(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
+        GGML_LOG_DEBUG(GGML_CUDA_NAME " buffer pool full, increase MAX_CUDA_BUFFERS\n");
         ggml_cuda_set_device(device);
         CUDA_CHECK(cudaFree(ptr));
         pool_size -= size;
@@ -980,7 +980,7 @@ static void * ggml_cuda_host_malloc(size_t size) {
     if (err != cudaSuccess) {
         // clear the error
         cudaGetLastError();
-        GGML_LOG_WARN("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
+        GGML_LOG_DEBUG("%s: failed to allocate %.2f MiB of pinned memory: %s\n", __func__,
                            size / 1024.0 / 1024.0, cudaGetErrorString(err));
         return nullptr;
     }
@@ -1151,8 +1151,8 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
 
     GGML_ASSERT(ggml_backend_buffer_is_cuda(src->buffer));
-    char * src_ptr = (char *) src->data;
-    char * dst_ptr = (char *) dst;
+    const char * src_ptr = (const char *) src->data;
+    char       * dst_ptr = (char       *) dst;
 
     const int64_t ne0 = src->ne[0];
     const int64_t nb0 = src->nb[0];
@@ -1162,7 +1162,7 @@ static cudaError_t ggml_cuda_cpy_tensor_2d(
     const enum ggml_type type = src->type;
     const int64_t ts = ggml_type_size(type);
     const int64_t bs = ggml_blck_size(type);
-    int64_t i1_diff = i1_high - i1_low;
+    const int64_t i1_diff = i1_high - i1_low;
 
     const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
     if (nb0 == ts && nb1 == ts*ne0/bs) {
@@ -1479,13 +1479,18 @@ static void ggml_cuda_op_mul_mat(
         if (src0_is_contiguous) {
             dev[id].src0_dd = split ? (char *) src0_extra->data_device[id] : (char *) src0->data;
         } else {
-            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), ggml_nbytes(src0));
+            // If src0 is not contiguous it will be copied to a temporary buffer.
+            // This buffer needs to be cleared entirely because multiple regions will function as padding.
+            const size_t nbytes_data    = ggml_nbytes(src0);
+            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+            dev[id].src0_dd = dev[id].src0_dd_alloc.alloc(ctx.pool(id), nbytes_data + nbytes_padding);
+            CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd, 0, nbytes_data + nbytes_padding, stream));
         }
 
-        // If src0 is on a temporary compute buffers (partial offloading) there may be some padding that needs to be cleared:
+        // If src0 is on a temporary compute buffer (partial offloading) there may be some padding that needs to be cleared:
         if (ne00 % MATRIX_ROW_PADDING != 0 && ggml_is_quantized(src0->type) && ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE && src0->view_src == nullptr) {
-            const int64_t nbytes_data    = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
-            const int64_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
+            const size_t nbytes_data    = ggml_row_size(src0->type, (dev[id].row_high - dev[id].row_low)*ne00);
+            const size_t nbytes_padding = ggml_row_size(src0->type, MATRIX_ROW_PADDING - ne00 % MATRIX_ROW_PADDING);
             CUDA_CHECK(cudaMemsetAsync(dev[id].src0_dd + nbytes_data , 0, nbytes_padding, stream));
         }
 
@@ -2406,7 +2411,7 @@ static bool ggml_backend_cuda_cpy_tensor_async(ggml_backend_t backend_src, ggml_
 
     if (cuda_ctx_src->device != buf_ctx_src->device || cuda_ctx_dst->device != buf_ctx_dst->device) {
 #ifndef NDEBUG
-        GGML_LOG_WARN("%s: backend and buffer devices do not match\n", __func__);
+        GGML_LOG_DEBUG("%s: backend and buffer devices do not match\n", __func__);
 #endif
         return false;
     }
@@ -2524,7 +2529,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
             cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
 #ifndef NDEBUG
-            GGML_LOG_WARN("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
 #endif
         }
     }
@@ -2575,14 +2580,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
             if (node->src[0] && node->src[0]->buffer && ggml_backend_buffer_is_cuda_split(node->src[0]->buffer)) {
                 use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
 #ifndef NDEBUG
-                GGML_LOG_WARN("%s: disabling CUDA graphs due to split buffer\n", __func__);
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
 #endif
             }
 
             if (node->op == GGML_OP_MUL_MAT_ID) {
                 use_cuda_graph = false; // This node type is not supported by CUDA graph capture
 #ifndef NDEBUG
-                GGML_LOG_WARN("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to mul_mat_id\n", __func__);
 #endif
             }
 
@@ -2591,7 +2596,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
                 // Changes in batch size or context size can cause changes to the grid size of some kernels.
                 use_cuda_graph = false;
 #ifndef NDEBUG
-                GGML_LOG_WARN("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
 #endif
             }
 
@@ -2603,7 +2608,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
                 if (!ptr) {
                     use_cuda_graph = false;
 #ifndef NDEBUG
-                    GGML_LOG_WARN("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
+                    GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
 #endif
                 } else {
                     if (std::find(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), ptr) == ggml_cuda_cpy_fn_ptrs.end()) {
@@ -2627,7 +2632,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
         if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
             cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
 #ifndef NDEBUG
-            GGML_LOG_WARN("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
+            GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
 #endif
         }
     }
@@ -2685,7 +2690,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
                 use_cuda_graph = false;
                 cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture = true;
 #ifndef NDEBUG
-                GGML_LOG_WARN("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
+                GGML_LOG_DEBUG("%s: disabling CUDA graphs due to failed graph capture\n", __func__);
 #endif
             } else {
                 graph_evaluated_or_captured = true; // CUDA graph has been captured
@@ -2854,7 +2859,7 @@ bool ggml_backend_cuda_register_host_buffer(void * buffer, size_t size) {
         // clear the error
         cudaGetLastError();
 
-        GGML_LOG_WARN("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
+        GGML_LOG_DEBUG("%s: failed to register %.2f MiB of pinned memory: %s\n", __func__,
                            size / 1024.0 / 1024.0, cudaGetErrorString(err));
         return false;
     }
@@ -2920,9 +2925,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
 #endif
 
     props->caps = {
-        /* async       */ true,
-        /* host_buffer */ host_buffer,
-        /* events      */ events,
+        /* .async                 = */ true,
+        /* .host_buffer           = */ host_buffer,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ events,
     };
 }
 
@@ -3140,7 +3146,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_ROPE:
             return ggml_is_contiguous(op->src[0]);
         case GGML_OP_IM2COL:
-            return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM:
         case GGML_OP_SUM_ROWS:
diff --git a/ggml/src/ggml-cuda/cpy.cuh b/ggml/src/ggml-cuda/cpy.cuh
index 7961674266e..28b06cddaa8 100644
--- a/ggml/src/ggml-cuda/cpy.cuh
+++ b/ggml/src/ggml-cuda/cpy.cuh
@@ -1,6 +1,6 @@
 #include "common.cuh"
 
-#define CUDA_CPY_BLOCK_SIZE 32
+#define CUDA_CPY_BLOCK_SIZE 64
 
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
 
diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu
index 96a5adef5b2..00e21b5d77e 100644
--- a/ggml/src/ggml-cuda/dmmv.cu
+++ b/ggml/src/ggml-cuda/dmmv.cu
@@ -416,10 +416,11 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx,
 
 static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
     const half * x = (const half *) vx;
-
+    // load 2 halves into a register in a single instruction
+    const half2 x_reg = *((half2 *) &(x[ib + iqs]));
     // automatic half -> float type cast if dfloat == float
-    v.x = x[ib + iqs + 0];
-    v.y = x[ib + iqs + 1];
+    v.x = __low2float(x_reg);
+    v.y = __high2float(x_reg);
 }
 
 static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) {
@@ -476,13 +477,28 @@ static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, cons
             // matrix multiplication
             // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
 #ifdef GGML_CUDA_F16
-            tmp += __hmul2(v, {
-                y[iybs + iqs + j/qr + 0],
-                y[iybs + iqs + j/qr + y_offset]
-            });
+            if ( y_offset == 1 ) {
+                // load 2 dfloats into register in a single instruction
+                const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
+                tmp += __hmul2(v, y_reg);
+            }
+            else {
+                tmp += __hmul2(v, {
+                        y[iybs + iqs + j/qr + 0],
+                        y[iybs + iqs + j/qr + y_offset]
+                    });
+            }
 #else
-            tmp += v.x * y[iybs + iqs + j/qr + 0];
-            tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+            if ( y_offset == 1 ) {
+                // load 2 dfloats into register in a single instruction
+                const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr]));
+                tmp += v.x * y_reg.x;
+                tmp += v.y * y_reg.y;
+            }
+            else {
+                tmp += v.x * y[iybs + iqs + j/qr + 0];
+                tmp += v.y * y[iybs + iqs + j/qr + y_offset];
+            }
 #endif // GGML_CUDA_F16
         }
     }
diff --git a/ggml/src/ggml-cuda/im2col.cu b/ggml/src/ggml-cuda/im2col.cu
index 16463ab0fb6..86a54e42bb7 100644
--- a/ggml/src/ggml-cuda/im2col.cu
+++ b/ggml/src/ggml-cuda/im2col.cu
@@ -91,9 +91,9 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const int64_t OH = is_2D ? dst->ne[2] : 1;
     const int64_t OW =         dst->ne[1];
 
-    const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
-    const int64_t batch = src1->ne[3];
-    const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
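+    // in the 1D case the batch dimension of src1 is ne[2]/nb[2] rather than ne[3]/nb[3]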
+    const size_t  delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+    const int64_t batch        = src1->ne[is_2D ? 3 : 2];
+    const size_t  batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
 
     if(dst->type == GGML_TYPE_F16) {
         im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 4935f881867..ae5c68ab351 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -8,8 +8,6 @@ void ggml_cuda_op_mul_mat_q(
 
     const int64_t ne00 = src0->ne[0];
 
-    const int64_t nb01 = src0->nb[1];
-
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
     GGML_ASSERT(ne10 % QK8_1 == 0);
@@ -17,7 +15,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t ne0 = dst->ne[0];
 
     const int64_t row_diff = row_high - row_low;
-    const int64_t stride00 = nb01 / ggml_type_size(src0->type);
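+    // derive the stride from the logical row size instead of nb01: src0 may have been
+    // copied into a contiguous temporary buffer, in which case the original nb01 no
+    // longer describes the layout of the data actually used here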
+    const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
 
     int id = ggml_cuda_get_device();
     const int compute_capability = ggml_cuda_info().devices[id].cc;
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index d3f4bad8c0a..65c4f81195b 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -19,6 +19,9 @@ extern "C" {
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
+// required for mmap as gguf only guarantees 32-byte alignment
+#define TENSOR_ALIGNMENT 32
+
 // static_assert should be a #define, but if it's not,
 // fall back to the _Static_assert C11 keyword.
 // if C99 - static_assert is noop
@@ -196,6 +199,11 @@ struct ggml_cgraph {
 
 struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph, int i0, int i1);
 
+// Memory allocation
+
+void * ggml_aligned_malloc(size_t size);
+void ggml_aligned_free(void * ptr, size_t size);
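+// note: ggml_aligned_free takes the size of the allocation, presumably so that
+// implementations backed by length-aware allocators (e.g. munmap-style APIs)
+// can release it; free-based implementations can ignore the parameter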
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 7e0b866a99b..80c08f15b29 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -20,6 +20,69 @@
 
 #define UNUSED(x) (void)(x)
 
+// globals
+
+// overload of MTLGPUFamilyMetal3 (not available in some environments)
+static const NSInteger MTLGPUFamilyMetal3_GGML = 5001;
+
+// initialized in ggml_backend_metal_reg
+static struct ggml_backend_reg    g_ggml_backend_metal_reg;
+static struct ggml_backend_device g_ggml_backend_metal_device;
+
+// information about a Metal device
+// note: assumes single GPU device - the default one
+// TODO: support multiple GPU devices
+static struct ggml_backend_metal_device_context {
+    id<MTLDevice> mtl_device;
+    int           mtl_device_ref_count;
+
+    bool support_simdgroup_reduction;
+    bool support_simdgroup_mm;
+
+    char name[128];
+} g_ggml_ctx_dev_main = {
+    /*.mtl_device                  =*/ nil,
+    /*.mtl_device_ref_count        =*/ 0,
+    /*.support_simdgroup_reduction =*/ false,
+    /*.support_simdgroup_mm        =*/ false,
+    /*.name                        =*/ "",
+};
+
+// acquire
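+// returns the shared MTLDevice, creating it and querying its feature support on
+// first use; each acquire must be balanced by a ggml_backend_metal_device_rel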
+static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
+    assert(ctx != NULL);
+
+    if (ctx->mtl_device == nil) {
+        ctx->mtl_device = MTLCreateSystemDefaultDevice();
+
+        ctx->support_simdgroup_reduction  = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+        ctx->support_simdgroup_reduction |= [ctx->mtl_device supportsFamily:MTLGPUFamilyMetal3_GGML];
+
+        ctx->support_simdgroup_mm = [ctx->mtl_device supportsFamily:MTLGPUFamilyApple7];
+
+        strncpy(ctx->name, [[ctx->mtl_device name] UTF8String], sizeof(ctx->name) - 1);
+    }
+
+    ctx->mtl_device_ref_count++;
+
+    return ctx->mtl_device;
+}
+
+// release
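+// drops one reference to the shared device and releases it when the count reaches zero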
+static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
+    assert(ctx != NULL);
+    assert(ctx->mtl_device_ref_count > 0);
+
+    ctx->mtl_device_ref_count--;
+
+    if (ctx->mtl_device_ref_count == 0) {
+        [ctx->mtl_device release];
+        ctx->mtl_device = nil;
+    }
+}
+
+// kernels
+
 struct ggml_metal_kernel {
     id<MTLComputePipelineState> pipeline;
 };
@@ -178,6 +241,8 @@
     GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
+    GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
     GGML_METAL_KERNEL_TYPE_PAD_F32,
     GGML_METAL_KERNEL_TYPE_ARANGE_F32,
@@ -209,21 +274,19 @@
     GGML_METAL_KERNEL_TYPE_SIN,
     GGML_METAL_KERNEL_TYPE_COS,
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
+    GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
 
     GGML_METAL_KERNEL_TYPE_COUNT
 };
 
 struct ggml_backend_metal_context {
-    id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
 
     dispatch_queue_t d_queue;
 
     struct ggml_metal_kernel kernels[GGML_METAL_KERNEL_TYPE_COUNT];
 
-    bool support_simdgroup_reduction;
-    bool support_simdgroup_mm;
-
     // capture state
     bool capture_next_compute;
     bool capture_started;
@@ -239,8 +302,6 @@
     struct ggml_cgraph * gf;
 
     // the callback given to the thread pool
-    // TODO: ideally, this should be created once, utilizing the command buffer state above
-    //       for some reason, doing it like this leads to a crash
     void (^encode_async)(size_t ith);
 
     // n_cb command buffers + 1 used by the main thread
@@ -282,7 +343,7 @@ @implementation GGMLMetalClass
     return data;
 }
 
-static struct ggml_backend_metal_context * ggml_metal_init(void) {
+static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t dev) {
     GGML_LOG_INFO("%s: allocating\n", __func__);
 
 #if TARGET_OS_OSX && !GGML_METAL_NDEBUG
@@ -294,14 +355,14 @@ @implementation GGMLMetalClass
     [devices release]; // since it was created by a *Copy* C method
 #endif
 
-    // Pick and show default Metal device
-    id<MTLDevice> device = MTLCreateSystemDefaultDevice();
+    // init context
+    struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
+    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
+
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
-    // Configure context
-    struct ggml_backend_metal_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_context));
-    ctx->device = device;
-    ctx->queue  = [ctx->device newCommandQueue];
+    ctx->queue  = [device newCommandQueue];
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
     id<MTLLibrary> metal_library;
@@ -334,7 +395,7 @@ @implementation GGMLMetalClass
             NSURL * libURL = [NSURL fileURLWithPath:path_lib];
             GGML_LOG_INFO("%s: loading '%s'\n", __func__, [path_lib UTF8String]);
 
-            metal_library = [ctx->device newLibraryWithURL:libURL error:&error];
+            metal_library = [device newLibraryWithURL:libURL error:&error];
             if (error) {
                 GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
                 return NULL;
@@ -384,7 +445,7 @@ @implementation GGMLMetalClass
 
                 //[options setFastMathEnabled:false];
 
-                metal_library = [ctx->device newLibraryWithSource:src options:options error:&error];
+                metal_library = [device newLibraryWithSource:src options:options error:&error];
                 if (error) {
                     GGML_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
                     return NULL;
@@ -394,44 +455,37 @@ @implementation GGMLMetalClass
     }
 
     // print MTL GPU family:
-    GGML_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
-
-    const NSInteger MTLGPUFamilyMetal3 = 5001;
+    GGML_LOG_INFO("%s: GPU name:   %s\n", __func__, [[device name] UTF8String]);
 
     // determine max supported GPU family
     // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
     // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
     {
         for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
-            if ([ctx->device supportsFamily:i]) {
+            if ([device supportsFamily:i]) {
                 GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d  (%d)\n", __func__, i - (int) MTLGPUFamilyApple1 + 1, i);
                 break;
             }
         }
 
         for (int i = MTLGPUFamilyCommon1 + 5; i >= MTLGPUFamilyCommon1; --i) {
-            if ([ctx->device supportsFamily:i]) {
+            if ([device supportsFamily:i]) {
                 GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyCommon%d (%d)\n", __func__, i - (int) MTLGPUFamilyCommon1 + 1, i);
                 break;
             }
         }
 
-        for (int i = MTLGPUFamilyMetal3 + 5; i >= MTLGPUFamilyMetal3; --i) {
-            if ([ctx->device supportsFamily:i]) {
-                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d  (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3 + 3, i);
+        for (int i = MTLGPUFamilyMetal3_GGML + 5; i >= MTLGPUFamilyMetal3_GGML; --i) {
+            if ([device supportsFamily:i]) {
+                GGML_LOG_INFO("%s: GPU family: MTLGPUFamilyMetal%d  (%d)\n", __func__, i - (int) MTLGPUFamilyMetal3_GGML + 3, i);
                 break;
             }
         }
     }
 
-    ctx->support_simdgroup_reduction  = [ctx->device supportsFamily:MTLGPUFamilyApple7];
-    ctx->support_simdgroup_reduction |= [ctx->device supportsFamily:MTLGPUFamilyMetal3];
-
-    ctx->support_simdgroup_mm = [ctx->device supportsFamily:MTLGPUFamilyApple7];
-
-    GGML_LOG_INFO("%s: simdgroup reduction support   = %s\n",       __func__, ctx->support_simdgroup_reduction ? "true" : "false");
-    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n",       __func__, ctx->support_simdgroup_mm ? "true" : "false");
-    GGML_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup reduction support   = %s\n", __func__, ctx_dev->support_simdgroup_reduction ? "true" : "false");
+    GGML_LOG_INFO("%s: simdgroup matrix mul. support = %s\n", __func__, ctx_dev->support_simdgroup_mm ? "true" : "false");
+    GGML_LOG_INFO("%s: hasUnifiedMemory              = %s\n", __func__, ctx_dev->mtl_device.hasUnifiedMemory ? "true" : "false");
 
     ctx->capture_next_compute = false;
     ctx->capture_started = false;
@@ -445,13 +499,7 @@ @implementation GGMLMetalClass
 
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
     if (@available(macOS 10.12, iOS 16.0, *)) {
-        GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1e6);
-    }
-#elif TARGET_OS_OSX
-    if (ctx->device.maxTransferRate != 0) {
-        GGML_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1e6);
-    } else {
-        GGML_LOG_INFO("%s: maxTransferRate               = built-in GPU\n", __func__);
+        GGML_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, device.recommendedMaxWorkingSetSize / 1e6);
     }
 #endif
 
@@ -472,7 +520,7 @@ @implementation GGMLMetalClass
         if (supported) { \
             struct ggml_metal_kernel * kernel = &ctx->kernels[e]; \
             id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
-            kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
+            kernel->pipeline = [device newComputePipelineStateWithFunction:metal_function error:&error]; \
             [metal_function release]; \
             if (error) { \
                 GGML_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
@@ -483,6 +531,9 @@ @implementation GGMLMetalClass
             GGML_LOG_WARN("%s: skipping %-40s (not supported)\n", __func__, "kernel_"#name); \
         }
 
+        const bool support_simdgroup_mm        = ctx_dev->support_simdgroup_mm;
+        const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+
         // simd_sum and simd_max requires MTLGPUFamilyApple7
 
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD,                           add,                            true);
@@ -509,10 +560,10 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK_4,                  gelu_quick_4,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU,                          silu,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU_4,                        silu_4,                         true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                soft_max_f32_4,                 ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16,                  soft_max_f16,                   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F16_4,                soft_max_f16_4,                 support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32,                  soft_max_f32,                   support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_F32_4,                soft_max_f32_4,                 support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF,                 diag_mask_inf,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8,               diag_mask_inf_8,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32,                  get_rows_f32,                   true);
@@ -537,107 +588,109 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,               get_rows_iq4_nl,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,               get_rows_iq4_xs,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                  get_rows_i32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                      rms_norm,                       support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                    group_norm,                     support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM,                          norm,                           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,                  ssm_conv_f32,                   true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32,                  ssm_scan_f32,                   true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,            mul_mv_iq3_xxs_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,              mul_mv_iq3_s_f32,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,              mul_mv_iq2_s_f32,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,              mul_mv_iq1_s_f32,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,              mul_mv_iq1_m_f32,               ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,             mul_mv_iq4_nl_f32,              ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,             mul_mv_iq4_xs_f32,              ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,         mul_mv_id_iq3_xxs_f32,          ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,           mul_mv_id_iq3_s_f32,            ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,           mul_mv_id_iq2_s_f32,            ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,           mul_mv_id_iq1_s_f32,            ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,           mul_mv_id_iq1_m_f32,            ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,          mul_mv_id_iq4_nl_f32,           ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           ctx->support_simdgroup_reduction);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,            mul_mm_iq3_xxs_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,              mul_mm_iq3_s_f32,               ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,              mul_mm_iq2_s_f32,               ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,              mul_mm_iq1_s_f32,               ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,              mul_mm_iq1_m_f32,               ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,             mul_mm_iq4_nl_f32,              ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,         mul_mm_id_iq3_xxs_f32,          ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,           mul_mm_id_iq3_s_f32,            ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,           mul_mm_id_iq2_s_f32,            ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,           mul_mm_id_iq1_s_f32,            ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           ctx->support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32,                mul_mv_f32_f32,                 support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16,                mul_mv_f16_f16,                 support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32,                mul_mv_f16_f32,                 support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW,           mul_mv_f16_f32_1row,            support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4,             mul_mv_f16_f32_l4,              support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32,               mul_mv_q4_0_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32,               mul_mv_q4_1_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,               mul_mv_q5_0_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,               mul_mv_q5_1_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,               mul_mv_q8_0_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,               mul_mv_q2_K_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,               mul_mv_q3_K_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,               mul_mv_q4_K_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32,               mul_mv_q5_K_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32,               mul_mv_q6_K_f32,                support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32,            mul_mv_iq2_xxs_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32,             mul_mv_iq2_xs_f32,              support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32,            mul_mv_iq3_xxs_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_S_F32,              mul_mv_iq3_s_f32,               support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_S_F32,              mul_mv_iq2_s_f32,               support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_S_F32,              mul_mv_iq1_s_f32,               support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ1_M_F32,              mul_mv_iq1_m_f32,               support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_NL_F32,             mul_mv_iq4_nl_f32,              support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ4_XS_F32,             mul_mv_iq4_xs_f32,              support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32,             mul_mv_id_f32_f32,              support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16,             mul_mv_id_f16_f16,              support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32,             mul_mv_id_f16_f32,              support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW,        mul_mv_id_f16_f32_1row,         support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4,          mul_mv_id_f16_f32_l4,           support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32,            mul_mv_id_q4_0_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32,            mul_mv_id_q4_1_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32,            mul_mv_id_q5_0_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32,            mul_mv_id_q5_1_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32,            mul_mv_id_q8_0_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32,            mul_mv_id_q2_K_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32,            mul_mv_id_q3_K_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32,            mul_mv_id_q4_K_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32,            mul_mv_id_q5_K_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32,            mul_mv_id_q6_K_f32,             support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32,         mul_mv_id_iq2_xxs_f32,          support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32,          mul_mv_id_iq2_xs_f32,           support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32,         mul_mv_id_iq3_xxs_f32,          support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_S_F32,           mul_mv_id_iq3_s_f32,            support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_S_F32,           mul_mv_id_iq2_s_f32,            support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_S_F32,           mul_mv_id_iq1_s_f32,            support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ1_M_F32,           mul_mv_id_iq1_m_f32,            support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_NL_F32,          mul_mv_id_iq4_nl_f32,           support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ4_XS_F32,          mul_mv_id_iq4_xs_f32,           support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32,                mul_mm_f32_f32,                 support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32,                mul_mm_f16_f32,                 support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32,               mul_mm_q4_0_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32,               mul_mm_q4_1_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32,               mul_mm_q5_0_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32,               mul_mm_q5_1_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32,               mul_mm_q8_0_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32,               mul_mm_q2_K_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32,               mul_mm_q3_K_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32,               mul_mm_q4_K_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32,               mul_mm_q5_K_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32,               mul_mm_q6_K_f32,                support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32,            mul_mm_iq2_xxs_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32,             mul_mm_iq2_xs_f32,              support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32,            mul_mm_iq3_xxs_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_S_F32,              mul_mm_iq3_s_f32,               support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_S_F32,              mul_mm_iq2_s_f32,               support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_S_F32,              mul_mm_iq1_s_f32,               support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,              mul_mm_iq1_m_f32,               support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,             mul_mm_iq4_nl_f32,              support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,             mul_mm_iq4_xs_f32,              support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32,             mul_mm_id_f32_f32,              support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32,             mul_mm_id_f16_f32,              support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32,            mul_mm_id_q4_0_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32,            mul_mm_id_q4_1_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32,            mul_mm_id_q5_0_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32,            mul_mm_id_q5_1_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32,            mul_mm_id_q8_0_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32,            mul_mm_id_q2_K_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32,            mul_mm_id_q3_K_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32,            mul_mm_id_q4_K_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32,            mul_mm_id_q5_K_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32,            mul_mm_id_q6_K_f32,             support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32,         mul_mm_id_iq2_xxs_f32,          support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32,          mul_mm_id_iq2_xs_f32,           support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32,         mul_mm_id_iq3_xxs_f32,          support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_S_F32,           mul_mm_id_iq3_s_f32,            support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_S_F32,           mul_mm_id_iq2_s_f32,            support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_S_F32,           mul_mm_id_iq1_s_f32,            support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ1_M_F32,           mul_mm_id_iq1_m_f32,            support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_NL_F32,          mul_mm_id_iq4_nl_f32,           support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32,                 rope_norm_f32,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,                 rope_norm_f16,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,                 rope_neox_f32,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,                 rope_neox_f16,                  true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,                im2col_ext_f16,                 true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32,                im2col_ext_f32,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32,                       pad_f32,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32,        timestep_embedding_f32,         true);
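For readers skimming the table above: `GGML_METAL_ADD_KERNEL(e, name, supported)` only registers a pipeline when the device reports the required feature; otherwise the slot stays `nil` and the op is later rejected by `ggml_metal_supports_op`. A minimal sketch of the pattern (not the verbatim macro, which also logs and handles pipeline-creation errors; `ggml_metal_compile_kernel` is a hypothetical helper standing in for the real creation code):

    // Hedged sketch: compile the Metal function only if `supported` is true.
    // The "kernel_" prefix for Metal function names is an assumption here.
    #define GGML_METAL_ADD_KERNEL(e, name, supported)                                       \
        if (supported) {                                                                    \
            ctx->kernels[e].pipeline = ggml_metal_compile_kernel(device, "kernel_" #name);  \
        } /* else: pipeline stays nil and supports_op rejects the op */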
@@ -645,14 +698,14 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,           argsort_f32_i32_asc,            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,          argsort_f32_i32_desc,           true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,                leaky_relu_f32,                 true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        ctx->support_simdgroup_mm);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        ctx->support_simdgroup_mm);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    ctx->support_simdgroup_reduction);
-      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,        flash_attn_ext_f16_h64,         support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,        flash_attn_ext_f16_h80,         support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,        flash_attn_ext_f16_h96,         support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,       flash_attn_ext_f16_h112,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,       flash_attn_ext_f16_h128,        support_simdgroup_mm);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,       flash_attn_ext_f16_h256,        support_simdgroup_mm);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,   flash_attn_ext_vec_f16_h128,    support_simdgroup_reduction);
+      //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,   flash_attn_ext_vec_f16_h256,    support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                   cpy_f16_f16,                    true);
@@ -669,6 +722,8 @@ @implementation GGMLMetalClass
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,               pool_2d_avg_f32,                true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,               pool_2d_max_f32,                true);
     }
 
     [metal_library release];
@@ -683,8 +738,9 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
         [ctx->kernels[i].pipeline release];
     }
 
+    Block_release(ctx->encode_async);
+
     [ctx->queue release];
-    [ctx->device release];
 
     dispatch_release(ctx->d_queue);
 
@@ -742,13 +798,16 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     return nil;
 }
 
-static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx, const struct ggml_tensor * op) {
+static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_context * ctx_dev, const struct ggml_tensor * op) {
     for (size_t i = 0, n = 3; i < n; ++i) {
         if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
             return false;
         }
     }
 
+    const bool support_simdgroup_mm        = ctx_dev->support_simdgroup_mm;
+    const bool support_simdgroup_reduction = ctx_dev->support_simdgroup_reduction;
+
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -786,15 +845,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_SOFT_MAX:
         case GGML_OP_RMS_NORM:
         case GGML_OP_GROUP_NORM:
-            return ctx->support_simdgroup_reduction;
+            return support_simdgroup_reduction;
         case GGML_OP_NORM:
         case GGML_OP_ROPE:
             return true;
         case GGML_OP_IM2COL:
             return op->src[0]->type == GGML_TYPE_F16;
         case GGML_OP_POOL_1D:
-        case GGML_OP_POOL_2D:
             return false;
+        case GGML_OP_POOL_2D:
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_ARANGE:
@@ -812,13 +871,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
             if (op->src[0]->ne[0] == 256) {
                 return false;
             }
-            return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
+            return support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
             return true;
         case GGML_OP_MUL_MAT:
         case GGML_OP_MUL_MAT_ID:
-            return ctx->support_simdgroup_reduction &&
+            return support_simdgroup_reduction &&
                 (op->src[0]->type != GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F32);
         case GGML_OP_CPY:
         case GGML_OP_DUP:
@@ -862,9 +921,12 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
 }
 
 static void ggml_metal_encode_node(
-     struct ggml_backend_metal_context * ctx,
+                        ggml_backend_t   backend,
                                    int   idx,
           id<MTLComputeCommandEncoder>   encoder) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
     struct ggml_cgraph * gf = ctx->gf;
 
     struct ggml_tensor * node = ggml_graph_node(gf, idx);
@@ -894,7 +956,7 @@ static void ggml_metal_encode_node(
             } break;
     }
 
-    if (!ggml_metal_supports_op(ctx, dst)) {
+    if (!ggml_metal_supports_op(ctx_dev, dst)) {
         GGML_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
         GGML_ABORT("unsupported op");
     }
@@ -953,19 +1015,23 @@ static void ggml_metal_encode_node(
     id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
     id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
-    //GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
-    //if (src0) {
-    //    GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
-    //            ggml_is_contiguous(src0), src0->name);
-    //}
-    //if (src1) {
-    //    GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
-    //            ggml_is_contiguous(src1), src1->name);
-    //}
-    //if (dst) {
-    //    GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld], 1, %s\n",  __func__, ggml_type_name(dstt),  ne0,  ne1,  ne2,
-    //            dst->name);
-    //}
+#if 0
+    GGML_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+    if (src0) {
+        GGML_LOG_INFO("%s: src0 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, ne03, nb00, nb01, nb02, nb03,
+                ggml_is_contiguous(src0), src0->name);
+    }
+    if (src1) {
+        GGML_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13,
+                ggml_is_contiguous(src1), src1->name);
+    }
+    if (dst) {
+        GGML_LOG_INFO("%s: dst  - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3,
+                dst->name);
+    }
+#endif
+
+    id<MTLDevice> device = ctx_dev->mtl_device;
 
     switch (dst->op) {
         case GGML_OP_CONCAT:
@@ -1675,7 +1741,7 @@ static void ggml_metal_encode_node(
                 // the numbers below are measured on M2 Ultra for 7B and 13B models
                 // these numbers do not translate to other devices or model sizes
                 // TODO: need to find a better approach
-                        if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+                        if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
                             switch (src0t) {
                                 case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
                                 case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
@@ -1695,7 +1761,7 @@ static void ggml_metal_encode_node(
 
                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                        if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                        if ([device supportsFamily:MTLGPUFamilyApple7] &&
                                 !ggml_is_transposed(src0) &&
                                 !ggml_is_transposed(src1) &&
                                 src1t == GGML_TYPE_F32 &&
@@ -1746,14 +1812,16 @@ static void ggml_metal_encode_node(
                             [encoder setBytes:&ne02    length:sizeof(ne02) atIndex:4];
                             [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
                             [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
-                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
-                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8];
-                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9];
-                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
-                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
-                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
-                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:13];
-                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:14];
+                            [encoder setBytes:&nb03    length:sizeof(nb03) atIndex:7];
+                            [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:8];
+                            [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:9];
+                            [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:10];
+                            [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:11];
+                            [encoder setBytes:&nb13    length:sizeof(nb13) atIndex:12];
+                            [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:13];
+                            [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:14];
+                            [encoder setBytes:&r2      length:sizeof(r2)   atIndex:15];
+                            [encoder setBytes:&r3      length:sizeof(r3)   atIndex:16];
                             [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                             [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
@@ -1922,20 +1990,22 @@ static void ggml_metal_encode_node(
                             [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
                             [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
                             [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
-                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:9];
-                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:10];
-                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:11];
-                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:12];
-                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:13];
-                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
-                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:15];
-                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
-                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:17];
-                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:18];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15];
+                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:17];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:18];
+                            [encoder setBytes:&r2   length:sizeof(r2)   atIndex:19];
+                            [encoder setBytes:&r3   length:sizeof(r3)   atIndex:20];
 
                             if (src0t == GGML_TYPE_Q4_0  || src0t == GGML_TYPE_Q4_1  || src0t == GGML_TYPE_Q5_0 ||
-                                    src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
-                                    src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+                                src0t == GGML_TYPE_Q5_1  || src0t == GGML_TYPE_Q8_0  || src0t == GGML_TYPE_Q2_K ||
+                                src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
@@ -1984,13 +2054,16 @@ static void ggml_metal_encode_node(
 
                 GGML_ASSERT(src1t == GGML_TYPE_F32);
 
+                GGML_ASSERT(ne03 == 1);
+                GGML_ASSERT(ne13 == 1);
+
                 // find the break-even point where the matrix-matrix kernel becomes more efficient compared
                 // to the matrix-vector kernel
                 // ne20 = n_used_experts
                 // ne21 = n_rows
                 const int dst_rows = ne20*ne21;
                 const int dst_rows_min = n_as;
-                const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
+                const int dst_rows_max = (device.maxThreadgroupMemoryLength - 32 - 8192)/4;
 
                 // max size of the rowids array in the kernel shared buffer
                 GGML_ASSERT(dst_rows <= dst_rows_max);
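As a concrete instance of this budget (the 8192-byte tile and 32-byte slack come straight from the expression above; the 32 KiB figure is merely a typical `maxThreadgroupMemoryLength`, used here for illustration):

    // dst_rows_max reserves 8192 B for the matrix tile plus 32 B of slack and
    // spends the rest of threadgroup memory on 4-byte row ids, e.g. with 32 KiB:
    //   (32768 - 32 - 8192) / 4 = 6136 row ids
    static inline int dst_rows_max_for(size_t max_tg_mem) {
        return (int) ((max_tg_mem - 32 - 8192) / 4);
    }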
@@ -2001,7 +2074,7 @@ static void ggml_metal_encode_node(
                 // TODO: for now, always use mat-vec kernels until we figure out how to improve the
                 //       indirect matrix multiplication
                 // !!!
-                if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                if ([device supportsFamily:MTLGPUFamilyApple7] &&
                         ne00 % 32 == 0 && ne00 >= 64 &&
                         dst_rows > dst_rows_min) {
 
@@ -2489,6 +2562,8 @@ static void ggml_metal_encode_node(
             } break;
         case GGML_OP_IM2COL:
             {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(ggml_is_contiguous(src1));
                 GGML_ASSERT(src0->type == GGML_TYPE_F16);
                 GGML_ASSERT(src1->type == GGML_TYPE_F32);
                 GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
@@ -2518,30 +2593,54 @@ static void ggml_metal_encode_node(
                 const int32_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4;
                 const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
 
-                id<MTLComputePipelineState> pipeline = nil;
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
+
+                const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
 
                 switch (dst->type) {
-                    case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline; break;
-                    case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline; break;
+                    case GGML_TYPE_F32: {
+                        pipeline = (is_gt_mttpt ?
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F32].pipeline
+                                    :
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline);
+                    } break;
+                    case GGML_TYPE_F16: {
+                        pipeline = (is_gt_mttpt ?
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16].pipeline
+                                    :
+                                    ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F16].pipeline);
+                    } break;
                     default: GGML_ABORT("fatal error");
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src1 offset:offs_src1        atIndex:0];
-                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
-                [encoder setBytes:&ofs0    length:sizeof( int32_t) atIndex:2];
-                [encoder setBytes:&ofs1    length:sizeof( int32_t) atIndex:3];
-                [encoder setBytes:&IW      length:sizeof( int32_t) atIndex:4];
-                [encoder setBytes:&IH      length:sizeof( int32_t) atIndex:5];
-                [encoder setBytes:&CHW     length:sizeof( int32_t) atIndex:6];
-                [encoder setBytes:&s0      length:sizeof( int32_t) atIndex:7];
-                [encoder setBytes:&s1      length:sizeof( int32_t) atIndex:8];
-                [encoder setBytes:&p0      length:sizeof( int32_t) atIndex:9];
-                [encoder setBytes:&p1      length:sizeof( int32_t) atIndex:10];
-                [encoder setBytes:&d0      length:sizeof( int32_t) atIndex:11];
-                [encoder setBytes:&d1      length:sizeof( int32_t) atIndex:12];
-
-                [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                [encoder setBuffer:id_src1 offset:offs_src1       atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
+                [encoder setBytes:&ofs0    length:sizeof(int32_t) atIndex:2];
+                [encoder setBytes:&ofs1    length:sizeof(int32_t) atIndex:3];
+                [encoder setBytes:&IW      length:sizeof(int32_t) atIndex:4];
+                [encoder setBytes:&IH      length:sizeof(int32_t) atIndex:5];
+                [encoder setBytes:&CHW     length:sizeof(int32_t) atIndex:6];
+                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:7];
+                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:8];
+                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:9];
+                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:10];
+                [encoder setBytes:&d0      length:sizeof(int32_t) atIndex:11];
+                [encoder setBytes:&d1      length:sizeof(int32_t) atIndex:12];
+
+                if (is_gt_mttpt) {
+                    [encoder setBytes:&N   length:sizeof(int32_t) atIndex:13];
+                    [encoder setBytes:&KH  length:sizeof(int32_t) atIndex:14];
+                    [encoder setBytes:&KW  length:sizeof(int32_t) atIndex:15];
+
+                    const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
+
+                    const int64_t  quotient  = N / n_threads + (N % n_threads > 0 ? 1 : 0);
+
+                    [encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
+                } else {
+                    [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
+                }
             } break;
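The `is_gt_mttpt` path above is just a grid reshape: when a single threadgroup cannot hold `N * KH * KW` threads, the `N` dimension is chopped into ceiling-divided chunks and the chunk count is folded into the x grid together with `CHW`. A hedged restatement of the arithmetic:

    // `quotient` above is a ceiling division in disguise; both forms compute
    // ceil(N / n_threads):
    const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t) N);
    const int64_t  n_chunks  = (N + n_threads - 1) / n_threads;
    // dispatch: grid (n_chunks * CHW, OH, OW), threadgroup (n_threads, 1, 1)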
         case GGML_OP_UPSCALE:
             {
@@ -2840,7 +2939,7 @@ static void ggml_metal_encode_node(
 
                     while (true) {
                         const size_t smem = nqptg*(ne00 + 2*nsgmax*(ncpsg + nqptg))*(sizeof(float)/2);
-                        if (smem > ctx->device.maxThreadgroupMemoryLength) {
+                        if (smem > device.maxThreadgroupMemoryLength) {
                             break;
                         }
                         nsgmax *= 2;
@@ -2852,8 +2951,8 @@ static void ggml_metal_encode_node(
 
                     const size_t smem = nqptg*(ne00 + 2*nsg*(ncpsg + nqptg))*(sizeof(float)/2);
 
-                    //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-                    GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+                    //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
 
                     [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
 
@@ -2878,8 +2977,8 @@ static void ggml_metal_encode_node(
 
                     const size_t smem = (nqptg*(ne00 + 2*nsg*(ncpsg + nqptg)) + nsg*ne00)*(sizeof(float)/2);
 
-                    //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
-                    GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+                    //printf("smem: %zu, max: %zu\n", smem, device.maxThreadgroupMemoryLength);
+                    GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
                     [encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
 
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
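For a feel of the sizes involved (a worked example, not measured data; `nqptg = 8` and `ncpsg = 32` are assumed to be the values this kernel family typically uses, and `sizeof(float)/2 = 2` reflects the half-precision accumulators):

    // ne00 = 128 (head size), nsg = 4 simdgroups:
    //   smem = 8*(128 + 2*4*(32 + 8))*2
    //        = 8*(128 + 320)*2
    //        = 7168 bytes, well under a typical 32 KiB maxThreadgroupMemoryLength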
@@ -2945,6 +3044,64 @@ static void ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
+        case GGML_OP_POOL_2D:
+            {
+                GGML_ASSERT(ggml_is_contiguous(src0));
+                GGML_ASSERT(src0t == GGML_TYPE_F32 && src0t == dstt);
+
+                const int32_t * opts = dst->op_params;
+                enum ggml_op_pool op = opts[0];
+
+                id<MTLComputePipelineState> pipeline = nil;
+                switch (src0t) {
+                    case GGML_TYPE_F32: {
+                        switch(op) {
+                            case GGML_OP_POOL_AVG:
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32].pipeline; break;
+                            case GGML_OP_POOL_MAX:
+                                pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32].pipeline; break;
+                            default: GGML_ASSERT(false && "not implemented");
+                        }
+                    } break;
+                    default: GGML_ASSERT(false && "not implemented");
+                }
+
+                const int32_t k0 = opts[1];
+                const int32_t k1 = opts[2];
+                const int32_t s0 = opts[3];
+                const int32_t s1 = opts[4];
+                const int32_t p0 = opts[5];
+                const int32_t p1 = opts[6];
+
+                const int64_t IH = src0->ne[1];
+                const int64_t IW = src0->ne[0];
+
+                const int64_t N  = dst->ne[3];
+                const int64_t OC = dst->ne[2];
+                const int64_t OH = dst->ne[1];
+                const int64_t OW = dst->ne[0];
+
+                const int64_t parallel_elements = N * OC * OH * OW;
+                const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements);
+                const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0       atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst        atIndex:1];
+                [encoder setBytes:&k0      length:sizeof(int32_t) atIndex:2];
+                [encoder setBytes:&k1      length:sizeof(int32_t) atIndex:3];
+                [encoder setBytes:&s0      length:sizeof(int32_t) atIndex:4];
+                [encoder setBytes:&s1      length:sizeof(int32_t) atIndex:5];
+                [encoder setBytes:&p0      length:sizeof(int32_t) atIndex:6];
+                [encoder setBytes:&p1      length:sizeof(int32_t) atIndex:7];
+                [encoder setBytes:&IH      length:sizeof(int64_t) atIndex:8];
+                [encoder setBytes:&IW      length:sizeof(int64_t) atIndex:9];
+                [encoder setBytes:&OH      length:sizeof(int64_t) atIndex:10];
+                [encoder setBytes:&OW      length:sizeof(int64_t) atIndex:11];
+                [encoder setBytes:&parallel_elements length:sizeof(int64_t) atIndex:12];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
+            } break;
        default:
             {
                 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
@@ -2954,8 +3111,11 @@ static void ggml_metal_encode_node(
 }
 
 static enum ggml_status ggml_metal_graph_compute(
-        struct ggml_backend_metal_context * ctx,
-                       struct ggml_cgraph * gf) {
+            ggml_backend_t   backend,
+        struct ggml_cgraph * gf) {
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
     // number of nodes encoded by the main thread (empirically determined)
     const int n_main = 128;
 
@@ -2983,7 +3143,7 @@ static enum ggml_status ggml_metal_graph_compute(
 
             if (!ctx->capture_started) {
                 // create capture scope
-                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx->device];
+                ctx->capture_scope = [[MTLCaptureManager sharedCaptureManager] newCaptureScopeWithDevice:ctx_dev->mtl_device];
 
                 MTLCaptureDescriptor * descriptor = [MTLCaptureDescriptor new];
                 descriptor.captureObject = ctx->capture_scope;
@@ -3000,46 +3160,6 @@ static enum ggml_status ggml_metal_graph_compute(
             }
         }
 
-        // TODO: how to avoid this allocation? I tried initializing it in ggml_backend_metal_set_n_cb but it crashes.
-        ctx->encode_async = ^(size_t iter) {
-            const int cb_idx = iter;
-            const int n_cb_l = ctx->n_cb;
-
-            const int n_nodes_0 = ctx->n_nodes_0;
-            const int n_nodes_1 = ctx->n_nodes_1;
-
-            const int n_nodes_per_cb = ctx->n_nodes_per_cb;
-
-            id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
-            id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
-
-            int node_start = 0;
-            int node_end   = n_nodes_0;
-
-            if (cb_idx < n_cb_l) {
-                node_start = n_nodes_0 + (                                         (cb_idx + 0) * n_nodes_per_cb);
-                node_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
-            }
-
-            for (int idx = node_start; idx < node_end; ++idx) {
-                if (should_capture) {
-                    [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(gf, idx)) encoding:NSUTF8StringEncoding]];
-                }
-
-                ggml_metal_encode_node(ctx, idx, encoder);
-
-                if (should_capture) {
-                    [encoder popDebugGroup];
-                }
-            }
-
-            [encoder endEncoding];
-
-            if (cb_idx < 2 || ctx->abort_callback == NULL) {
-                [command_buffer commit];
-            }
-        };
-
         // the main thread commits the first few commands immediately
         // command_buffer[n_cb]
         {
@@ -3127,31 +3247,6 @@ static enum ggml_status ggml_metal_graph_compute(
 
 // backend interface
 
-// default buffer
-static id<MTLDevice> g_backend_device = nil;
-static int g_backend_device_ref_count = 0; // TODO: make thread-safe
-
-static id<MTLDevice> ggml_backend_metal_get_device(void) {
-    if (g_backend_device == nil) {
-        g_backend_device = MTLCreateSystemDefaultDevice();
-    }
-
-    g_backend_device_ref_count++;
-
-    return g_backend_device;
-}
-
-static void ggml_backend_metal_free_device(void) {
-    assert(g_backend_device_ref_count > 0);
-
-    g_backend_device_ref_count--;
-
-    if (g_backend_device_ref_count == 0) {
-        [g_backend_device release];
-        g_backend_device = nil;
-    }
-}
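The `ggml_backend_metal_device_acq`/`_rel` pair that replaces these globals keeps the same ref-counting idea but moves the state into the per-device context. A minimal sketch, assuming the context carries `mtl_device` and `mtl_device_ref_count` fields (the former appears at the call sites in this diff; the latter's name is inferred, not verbatim):

    static id<MTLDevice> ggml_backend_metal_device_acq(struct ggml_backend_metal_device_context * ctx) {
        if (ctx->mtl_device_ref_count == 0) {
            // first user creates the device; later callers share it
            ctx->mtl_device = MTLCreateSystemDefaultDevice();
        }
        ctx->mtl_device_ref_count++;
        return ctx->mtl_device;
    }

    static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_context * ctx) {
        assert(ctx->mtl_device_ref_count > 0);
        if (--ctx->mtl_device_ref_count == 0) {
            [ctx->mtl_device release];
            ctx->mtl_device = nil;
        }
    }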
-
 static const char * ggml_backend_metal_buffer_get_name(ggml_backend_buffer_t buffer) {
     return "Metal";
 
@@ -3164,7 +3259,7 @@ static void ggml_backend_metal_buffer_free_buffer(ggml_backend_buffer_t buffer)
     for (int i = 0; i < ctx->n_buffers; i++) {
         [ctx->buffers[i].metal release];
     }
-    ggml_backend_metal_free_device();
+    ggml_backend_metal_device_rel(buffer->buft->device->context);
 
     if (ctx->owned) {
 #if TARGET_OS_OSX
@@ -3267,7 +3362,7 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
         size_aligned += (size_page - (size_aligned % size_page));
     }
 
-    id<MTLDevice> device = ggml_backend_metal_get_device();
+    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
 
     ctx->all_data = ggml_metal_host_malloc(size_aligned);
     ctx->all_size = size_aligned;
@@ -3281,16 +3376,16 @@ static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_ba
 
         if (size_aligned > 0) {
             ctx->buffers[0].metal = [device newBufferWithBytesNoCopy:ctx->all_data
-                            length:size_aligned
-                            options:MTLResourceStorageModeShared
-                            deallocator:nil];
+                                            length:size_aligned
+                                            options:MTLResourceStorageModeShared
+                                            deallocator:nil];
         }
     }
 
     if (size_aligned > 0 && (ctx->all_data == NULL || ctx->buffers[0].metal == nil)) {
         GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
         free(ctx);
-        ggml_backend_metal_free_device();
+        ggml_backend_metal_device_rel(buft->device->context);
         return NULL;
     }
 
@@ -3305,9 +3400,9 @@ static size_t ggml_backend_metal_buffer_type_get_alignment(ggml_backend_buffer_t
 }
 
 static size_t ggml_backend_metal_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    id<MTLDevice> device = ggml_backend_metal_get_device();
-    size_t max_size = device.maxBufferLength;
-    ggml_backend_metal_free_device();
+    id<MTLDevice> device = ggml_backend_metal_device_acq(buft->device->context);
+    const size_t max_size = device.maxBufferLength;
+    ggml_backend_metal_device_rel(buft->device->context);
 
     return max_size;
 
@@ -3330,15 +3425,14 @@ ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void) {
             /* .get_alloc_size   = */ NULL, // defaults to ggml_nbytes
             /* .is_host          = */ ggml_backend_metal_buffer_type_is_host,
         },
-        /* .device  = */ NULL,
+        /* .device  = */ &g_ggml_backend_metal_device,
         /* .context = */ NULL,
     };
 
     return &ggml_backend_buffer_type_metal;
 }
 
-// buffer from ptr
-
+// TODO: obsoleted by ggml_backend_metal_device_buffer_from_ptr
 ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size) {
     struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
 
@@ -3361,7 +3455,7 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
         size_aligned += (size_page - (size_aligned % size_page));
     }
 
-    id<MTLDevice> device = ggml_backend_metal_get_device();
+    id<MTLDevice> device = ggml_backend_metal_device_acq(&g_ggml_ctx_dev_main);
 
     // the buffer fits into the max buffer size allowed by the device
     if (size_aligned <= device.maxBufferLength) {
@@ -3426,8 +3520,12 @@ ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t siz
 }
 
 static void ggml_backend_metal_free(ggml_backend_t backend) {
-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+    struct ggml_backend_metal_context        * ctx     = backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
+
+    ggml_backend_metal_device_rel(ctx_dev);
     ggml_metal_free(ctx);
+
     free(backend);
 }
 
@@ -3438,21 +3536,7 @@ static ggml_backend_buffer_type_t ggml_backend_metal_get_default_buffer_type(ggm
 }
 
 static enum ggml_status ggml_backend_metal_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
-    struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    return ggml_metal_graph_compute(metal_ctx, cgraph);
-}
-
-static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
-    struct ggml_backend_metal_context * metal_ctx = (struct ggml_backend_metal_context *)backend->context;
-
-    return ggml_metal_supports_op(metal_ctx, op);
-}
-
-static bool ggml_backend_metal_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
-
-    UNUSED(backend);
+    return ggml_metal_graph_compute(backend, cgraph);
 }
 
 static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
@@ -3468,10 +3552,50 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         }
     }
 
-    // TODO: setting encode_async here causes crash during the next ggml_metal_graph_compute call. why?
-    //ctx->encode_async = ^(size_t iter) {
-    //    ...
-    //};
+    if (ctx->encode_async) {
+        Block_release(ctx->encode_async);
+    }
+
+    ctx->encode_async = Block_copy(^(size_t iter) {
+        const int cb_idx = iter;
+        const int n_cb_l = ctx->n_cb;
+
+        const int n_nodes_0 = ctx->n_nodes_0;
+        const int n_nodes_1 = ctx->n_nodes_1;
+
+        const int n_nodes_per_cb = ctx->n_nodes_per_cb;
+
+        id<MTLCommandBuffer> command_buffer  = ctx->command_buffers[cb_idx];
+        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+
+        int node_start = 0;
+        int node_end   = n_nodes_0;
+
+        if (cb_idx < n_cb_l) {
+            node_start = n_nodes_0 + (                                         (cb_idx + 0) * n_nodes_per_cb);
+            node_end   = n_nodes_0 + (MIN((cb_idx == n_cb_l - 1) ? n_nodes_1 : (cb_idx + 1) * n_nodes_per_cb, n_nodes_1));
+        }
+
+        const bool should_capture = ctx->capture_next_compute;
+
+        for (int idx = node_start; idx < node_end; ++idx) {
+            if (should_capture) {
+                [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
+            }
+
+            ggml_metal_encode_node(backend, idx, encoder);
+
+            if (should_capture) {
+                [encoder popDebugGroup];
+            }
+        }
+
+        [encoder endEncoding];
+
+        if (cb_idx < 2 || ctx->abort_callback == NULL) {
+            [command_buffer commit];
+        }
+    });
 }
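The index arithmetic inside the block is easier to read with the clamping unfolded; a hedged, equivalent form (equivalence assumes `n_nodes_per_cb` is the ceiling of `n_nodes_1 / n_cb`, which is how the graph is split):

    // Buffer n_cb (committed first by the main thread) covers nodes [0, n_nodes_0);
    // worker buffer i covers an n_nodes_per_cb-sized slice of the remaining n_nodes_1:
    if (cb_idx < n_cb) {
        node_start = n_nodes_0 +      cb_idx      * n_nodes_per_cb;
        node_end   = n_nodes_0 + MIN((cb_idx + 1) * n_nodes_per_cb, n_nodes_1);
    }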
 
 static struct ggml_backend_i ggml_backend_metal_i = {
@@ -3487,8 +3611,8 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_metal_graph_compute,
-    /* .supports_op             = */ ggml_backend_metal_supports_op,
-    /* .supports_buft           = */ ggml_backend_metal_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
     /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
@@ -3499,8 +3623,11 @@ static ggml_guid_t ggml_backend_metal_guid(void) {
     return &guid;
 }
 
+// TODO: remove in the future
 ggml_backend_t ggml_backend_metal_init(void) {
-    struct ggml_backend_metal_context * ctx = ggml_metal_init();
+    ggml_backend_dev_t dev = ggml_backend_reg_dev_get(ggml_backend_metal_reg(), 0);
+
+    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
     if (ctx == NULL) {
         GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
         return NULL;
@@ -3511,7 +3638,7 @@ ggml_backend_t ggml_backend_metal_init(void) {
     *backend = (struct ggml_backend) {
         /* .guid      = */ ggml_backend_metal_guid(),
         /* .interface = */ ggml_backend_metal_i,
-        /* .device    = */ NULL,
+        /* .device    = */ dev,
         /* .context   = */ ctx,
     };
 
@@ -3536,9 +3663,9 @@ void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_ca
 bool ggml_backend_metal_supports_family(ggml_backend_t backend, int family) {
     GGML_ASSERT(ggml_backend_is_metal(backend));
 
-    struct ggml_backend_metal_context * ctx = (struct ggml_backend_metal_context *)backend->context;
+    struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
 
-    return [ctx->device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
+    return [ctx_dev->mtl_device supportsFamily:(MTLGPUFamilyApple1 + family - 1)];
 }
 
 void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
@@ -3548,11 +3675,246 @@ void ggml_backend_metal_capture_next_compute(ggml_backend_t backend) {
     ctx->capture_next_compute = true;
 }
 
-ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); // silence warning
+// backend device
+
+static const char * ggml_backend_metal_device_get_name(ggml_backend_dev_t dev) {
+    return "Metal";
+
+    GGML_UNUSED(dev);
+}
+
+static const char * ggml_backend_metal_device_get_description(ggml_backend_dev_t dev) {
+    // acq/rel just to populate ctx->name in case it hasn't been done yet
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+    ggml_backend_metal_device_acq(ctx_dev);
+    ggml_backend_metal_device_rel(ctx_dev);
+
+    return ctx_dev->name;
+}
+
+static void ggml_backend_metal_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    if (@available(macOS 10.12, iOS 16.0, *)) {
+        struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+        id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+        *total = device.recommendedMaxWorkingSetSize;
+        *free  = *total - device.currentAllocatedSize;
+
+        ggml_backend_metal_device_rel(ctx_dev);
+    } else {
+        *free = 1;
+        *total = 1;
+    }
+}
+
+static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backend_dev_t dev) {
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
 
-ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data) {
-    return ggml_backend_metal_init();
+    GGML_UNUSED(dev);
+}
+
+static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_metal_device_get_name(dev);
+    props->description = ggml_backend_metal_device_get_description(dev);
+    props->type        = ggml_backend_metal_device_get_type(dev);
+    ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = (struct ggml_backend_dev_caps) {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ true,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_metal_device_init(ggml_backend_dev_t dev, const char * params) {
+    struct ggml_backend_metal_context * ctx = ggml_metal_init(dev);
+    if (ctx == NULL) {
+        GGML_LOG_ERROR("%s: error: failed to allocate context\n", __func__);
+        return NULL;
+    }
+
+    ggml_backend_t backend = malloc(sizeof(struct ggml_backend));
+
+    *backend = (struct ggml_backend) {
+        /* .guid      = */ ggml_backend_metal_guid(),
+        /* .interface = */ ggml_backend_metal_i,
+        /* .device    = */ dev,
+        /* .context   = */ ctx,
+    };
+
+    ggml_backend_metal_set_n_cb(backend, 1);
+
+    return backend;
 
     GGML_UNUSED(params);
-    GGML_UNUSED(user_data);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_metal_device_get_buffer_type(ggml_backend_dev_t dev) {
+    return ggml_backend_metal_buffer_type();
+
+    GGML_UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_metal_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    struct ggml_backend_metal_buffer_context * ctx = calloc(1, sizeof(struct ggml_backend_metal_buffer_context));
+
+    ctx->all_data = ptr;
+    ctx->all_size = size;
+    ctx->owned = false;
+    ctx->n_buffers = 0;
+
+    const size_t size_page = sysconf(_SC_PAGESIZE);
+
+    // page-align the data ptr
+    {
+        const uintptr_t offs = (uintptr_t) ptr % size_page;
+        ptr  = (void *) ((char *) ptr - offs);
+        size += offs;
+    }
+
+    size_t size_aligned = size;
+    if ((size_aligned % size_page) != 0) {
+        size_aligned += (size_page - (size_aligned % size_page));
+    }
+
+    struct ggml_backend_metal_device_context * ctx_dev = (struct ggml_backend_metal_device_context *)dev->context;
+    id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
+
+    // the buffer fits into the max buffer size allowed by the device
+    if (size_aligned <= device.maxBufferLength) {
+        ctx->buffers[ctx->n_buffers].data  = ptr;
+        ctx->buffers[ctx->n_buffers].size  = size;
+        ctx->buffers[ctx->n_buffers].metal = nil;
+
+        if (size_aligned > 0) {
+            ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:ptr length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+            if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_aligned / 1024.0 / 1024.0);
+                return false;
+            }
+        }
+
+        ggml_backend_metal_log_allocated_size(device, size_aligned);
+
+        ++ctx->n_buffers;
+    } else {
+        // this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
+        // one of the views
+        const size_t size_ovlp = ((max_tensor_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
+        const size_t size_step = device.maxBufferLength - size_ovlp;
+        const size_t size_view = device.maxBufferLength;
+
+        for (size_t i = 0; i < size; i += size_step) {
+            const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
+
+            ctx->buffers[ctx->n_buffers].data  = (void *) ((uint8_t *) ptr + i);
+            ctx->buffers[ctx->n_buffers].size  = size_step_aligned;
+            ctx->buffers[ctx->n_buffers].metal = nil;
+
+            if (size_step_aligned > 0) {
+                ctx->buffers[ctx->n_buffers].metal = [device newBufferWithBytesNoCopy:(void *) ((uint8_t *) ptr + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
+
+                if (ctx->buffers[ctx->n_buffers].metal == nil) {
+                    GGML_LOG_ERROR("%s: error: failed to allocate buffer, size = %8.2f MiB\n", __func__, size_step_aligned / 1024.0 / 1024.0);
+                    return false;
+                }
+            }
+
+            ggml_backend_metal_log_allocated_size(device, size_step_aligned);
+
+            if (i + size_step < size) {
+                GGML_LOG_INFO("\n");
+            }
+
+            ++ctx->n_buffers;
+        }
+    }
+
+    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
+}
+
+static bool ggml_backend_metal_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    struct ggml_backend_metal_device_context * ctx_dev = dev->context;
+
+    return ggml_metal_supports_op(ctx_dev, op);
+}
+
+static bool ggml_backend_metal_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    return buft->iface.get_name == ggml_backend_metal_buffer_type_get_name;
+
+    UNUSED(dev);
+}
+
+static bool ggml_backend_metal_device_offload_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    return false;
+
+    GGML_UNUSED(dev);
+    GGML_UNUSED(op);
+}
+
+static struct ggml_backend_device_i ggml_backend_metal_device_i = {
+    /* .get_name             = */ ggml_backend_metal_device_get_name,
+    /* .get_description      = */ ggml_backend_metal_device_get_description,
+    /* .get_memory           = */ ggml_backend_metal_device_get_memory,
+    /* .get_type             = */ ggml_backend_metal_device_get_type,
+    /* .get_props            = */ ggml_backend_metal_device_get_props,
+    /* .init_backend         = */ ggml_backend_metal_device_init,
+    /* .get_buffer_type      = */ ggml_backend_metal_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_metal_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_metal_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_metal_device_supports_buft,
+    /* .offload_op           = */ ggml_backend_metal_device_offload_op,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend registry
+
+static const char * ggml_backend_metal_reg_get_name(ggml_backend_reg_t reg) {
+    return "Metal";
+
+    GGML_UNUSED(reg);
+}
+
+static size_t ggml_backend_metal_reg_device_count(ggml_backend_reg_t reg) {
+    return 1;
+
+    GGML_UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_metal_reg_device_get(ggml_backend_reg_t reg, size_t index) {
+    GGML_ASSERT(index == 0);
+
+    return &g_ggml_backend_metal_device;
+
+    GGML_UNUSED(reg);
+    GGML_UNUSED(index);
+}
+
+static struct ggml_backend_reg_i ggml_backend_metal_reg_i = {
+    /* .get_name         = */ ggml_backend_metal_reg_get_name,
+    /* .device_count     = */ ggml_backend_metal_reg_device_count,
+    /* .device_get       = */ ggml_backend_metal_reg_device_get,
+    /* .get_proc_address = */ NULL,
+};
+
+ggml_backend_reg_t ggml_backend_metal_reg(void) {
+    // TODO: make this thread-safe somehow?
+    {
+        g_ggml_backend_metal_reg = (struct ggml_backend_reg) {
+            /* .iface   = */ ggml_backend_metal_reg_i,
+            /* .context = */ NULL,
+        };
+
+        g_ggml_backend_metal_device = (struct ggml_backend_device) {
+            /* .iface   = */ ggml_backend_metal_device_i,
+            /* .reg     = */ &g_ggml_backend_metal_reg,
+            /* .context = */ &g_ggml_ctx_dev_main,
+        };
+    }
+
+    return &g_ggml_backend_metal_reg;
 }
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index 2b200032394..defde6246f1 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -777,10 +777,10 @@ kernel void kernel_ssm_conv_f32(
     const int64_t i3 = tgpig.z;
 
     const int64_t nc  = ne10;
-    const int64_t ncs = ne00;
-    const int64_t nr  = ne01;
-    const int64_t n_t = ne1;
-    const int64_t n_s = ne2;
+  //const int64_t ncs = ne00;
+  //const int64_t nr  = ne01;
+  //const int64_t n_t = ne1;
+  //const int64_t n_s = ne2;
 
     device const float * s = (device const float *) ((device const char *) src0 + ir*nb01 + i2*nb00 + i3*nb02);
     device const float * c = (device const float *) ((device const char *) src1 + ir*nb11);
@@ -834,9 +834,9 @@ kernel void kernel_ssm_scan_f32(
     const int64_t i3 = tgpig.y;
 
     const int64_t nc  = d_state;
-    const int64_t nr  = d_inner;
+  //const int64_t nr  = d_inner;
     const int64_t n_t = n_seq_tokens;
-    const int64_t n_s = n_seqs;
+  //const int64_t n_s = n_seqs;
 
     for (int64_t i2 = 0; i2 < n_t; ++i2) {
         device const float * s0 = (device const float *) ((device const char *) src0 + ir*nb01 + i3*nb02);
@@ -1064,17 +1064,18 @@ kernel void kernel_group_norm(
 inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
 
-    float2 acc = 0.f;
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
 
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
 
-    for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
-                + yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
-                + yl[i + 9] * (qs[i / 2] & 0xF000);
+    for (int i = 0; i < 8; i += 2) {
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (sumy * -8.f + acc[0] + acc[1]);
+
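+    // note: yl[i+1] is pre-scaled by 1/256 and yl[i+8]/yl[i+9] by 1/16 and 1/4096
+    // by the caller, so masking a nibble in place (e.g. qs & 0x0F00) and multiplying
+    // already weights it correctly without any shifts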
+    return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
 }
 
 // function to calculate the inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1085,17 +1086,18 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
     float d = qb_curr->d;
     float m = qb_curr->m;
 
-    float2 acc = 0.f;
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
 
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+    device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
 
     for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
-                + yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
-                + yl[i + 9] * (qs[i / 2] & 0xF000);
+        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
+        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
+        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
+        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
     }
-    return d * (acc[0] + acc[1]) + sumy * m;
+
+    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
 }
 
 // function to calculate the inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1105,18 +1107,19 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
 inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
 
-    float2 acc = 0.f;
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
 
     device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 3 + il/2);
            const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
 
     for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
-                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
-        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
-                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
+        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
     }
-    return d * (sumy * -16.f + acc[0] + acc[1]);
+
+    return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
 }
 
 // function to calculate the inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1127,18 +1130,19 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
     float d = qb_curr->d;
     float m = qb_curr->m;
 
-    float2 acc = 0.f;
+    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
 
     device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 4 + il/2);
            const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
 
     for (int i = 0; i < 8; i+=2) {
-        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
-                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
-        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
-                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010));
+        acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
+        acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
     }
-    return d * (acc[0] + acc[1]) + sumy * m;
+
+    return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
 }
 
 // putting them in the kernel causes a significant performance penalty
@@ -1156,14 +1160,22 @@ void mul_vec_q_n_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
                    uint      r3,
         threadgroup int8_t * shared_values,
-                   uint3 tgpig, uint tiisg, uint sgitg) {
+                     uint3   tgpig,
+                     uint    tiisg,
+                     uint    sgitg) {
     const int nb = ne00/QK4_0;
 
     const int r0 = tgpig.x;
@@ -1175,10 +1187,19 @@ void mul_vec_q_n_f32_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+  //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+  //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0);
+    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
 
-    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
-    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+    // pointers to src0 rows
+    device const block_q_type * ax[nr];
+    for (int row = 0; row < nr; ++row) {
+        const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+        ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
+    }
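+    // indexing by byte strides (nb01/nb02/nb03) instead of first_row*nb block
+    // offsets lets the kernel read src0 rows that are not contiguous in memory
+    // (e.g. tensor views), which the old offset0 arithmetic implicitly assumed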
 
     float yl[16]; // src1 vector cache
     float sumf[nr] = {0.f};
@@ -1190,19 +1211,22 @@ void mul_vec_q_n_f32_impl(
 
     // each thread in a SIMD group deals with half a block.
     for (int ib = ix; ib < nb; ib += nw/2) {
-        float sumy = 0;
+        float sumy[2] = { 0.f, 0.f };
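+        // sumy accumulates the plain sum of the 16 src1 values handled per block;
+        // block_q_n_dot_y uses it to fold in the quant type's constant offset
+        // (e.g. the -8 per weight of Q4_0) without unpacking the nibbles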
+
+#pragma unroll
         for (int i = 0; i < 8; i += 2) {
-            sumy += yb[i] + yb[i+1];
-            yl[i+0] = yb[i+ 0];
-            yl[i+1] = yb[i+ 1]/256.f;
+            sumy[0]  += yb[i +  0] + yb[i +  1];
+            yl[i + 0] = yb[i +  0];
+            yl[i + 1] = yb[i +  1]/256.f;
 
-            sumy += yb[i+16] + yb[i+17];
-            yl[i+8] = yb[i+16]/16.f;
-            yl[i+9] = yb[i+17]/4096.f;
+            sumy[1]  += yb[i + 16] + yb[i + 17];
+            yl[i + 8] = yb[i + 16]/16.f;
+            yl[i + 9] = yb[i + 17]/4096.f;
         }
 
+#pragma unroll
         for (int row = 0; row < nr; row++) {
-            sumf[row] += block_q_n_dot_y(x+ib+row*nb, sumy, yl, il);
+            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
         }
 
         yb += QK4_0 * 16;
@@ -1226,12 +1250,14 @@ kernel void kernel_mul_mv_q4_0_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1239,7 +1265,7 @@ kernel void kernel_mul_mv_q4_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q4_1_f32(
@@ -1252,12 +1278,14 @@ kernel void kernel_mul_mv_q4_1_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1265,7 +1293,7 @@ kernel void kernel_mul_mv_q4_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+     mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_0_f32(
@@ -1278,12 +1306,14 @@ kernel void kernel_mul_mv_q5_0_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1291,7 +1321,7 @@ kernel void kernel_mul_mv_q5_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 kernel void kernel_mul_mv_q5_1_f32(
@@ -1304,12 +1334,14 @@ kernel void kernel_mul_mv_q5_1_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1317,7 +1349,7 @@ kernel void kernel_mul_mv_q5_1_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+    mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 
@@ -1330,8 +1362,14 @@ void kernel_mul_mv_q8_0_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -1354,10 +1392,19 @@ void kernel_mul_mv_q8_0_f32_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+  //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+  //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0);
+    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
+
+    // pointers to src0 rows
+    device const block_q8_0 * ax[nr];
+    for (int row = 0; row < nr; ++row) {
+        const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+
+        ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
+    }
 
     float yl[NB_Q8_0];
     float sumf[nr]={0.f};
@@ -1374,12 +1421,12 @@ void kernel_mul_mv_q8_0_f32_impl(
         }
 
         for (int row = 0; row < nr; row++) {
-            device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
+            device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il;
             float sumq = 0.f;
             for (int iq = 0; iq < NB_Q8_0; ++iq) {
                 sumq += qs[iq] * yl[iq];
             }
-            sumf[row] += sumq*x[ib+row*nb].d;
+            sumf[row] += sumq*ax[row][ib].d;
         }
 
         yb += NB_Q8_0 * nw;
@@ -1404,12 +1451,14 @@ kernel void kernel_mul_mv_q8_0_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1417,7 +1466,7 @@ kernel void kernel_mul_mv_q8_0_f32(
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
-    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+    kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
 }
 
 #define N_MV_T_T 4
@@ -1433,12 +1482,14 @@ void kernel_mul_mv_impl(
                   uint64_t   nb00,
                   uint64_t   nb01,
                   uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne11,
                    int64_t   ne12,
                   uint64_t   nb10,
                   uint64_t   nb11,
                   uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -1452,7 +1503,7 @@ void kernel_mul_mv_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
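+    // nb03 is the true byte stride of dim 3; the removed nb02*ne02 expression
+    // matches it only when src0 is contiguous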
 
     device const T0 * x = (device const T0 *) (src0 + offset0);
 
@@ -1463,7 +1514,9 @@ void kernel_mul_mv_impl(
                 break;
             }
 
-            device const T1 * y = (device const T1 *) (src1 + r1*nb11 + im*nb12);
+            const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            device const T1 * y = (device const T1 *) (src1 + offset1);
 
             float sumf = 0;
             for (int i = tiisg; i < ne00; i += 32) {
@@ -1483,7 +1536,9 @@ void kernel_mul_mv_impl(
                 break;
             }
 
-            device const T1  * y  = (device const T1  *) (src1 + r1*nb11 + im*nb12);
+            const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+            device const T1  * y  = (device const T1  *) (src1 + offset1);
             device const T14 * y4 = (device const T14 *) y;
 
             float sumf = 0;
@@ -1511,12 +1566,14 @@ kernel void kernel_mul_mv(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1533,12 +1590,14 @@ kernel void kernel_mul_mv(
         nb00,
         nb01,
         nb02,
+        nb03,
         ne10,
         ne11,
         ne12,
         nb10,
         nb11,
         nb12,
+        nb13,
         ne0,
         ne1,
         r2,
@@ -1564,12 +1623,14 @@ kernel void kernel_mul_mv_1row(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1584,10 +1645,11 @@ kernel void kernel_mul_mv_1row(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
     device const T     * x = (device const T     *) (src0 + offset0);
-    device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+    device const float * y = (device const float *) (src1 + offset1);
 
     float sumf = 0;
     if (ne00 < 128) {
@@ -1631,12 +1693,14 @@ kernel void kernel_mul_mv_l4(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -1651,12 +1715,14 @@ kernel void kernel_mul_mv_l4(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02;
+    const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
 
     device const T4 * x4 = (device const T4 *) (src0 + offset0);
 
     for (int r1 = 0; r1 < nrows; ++r1) {
-        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+        const uint offset1 = r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+        device const float4 * y4 = (device const float4 *) (src1 + offset1);
 
         float sumf = 0;
         for (int i = tiisg; i < ne00/4; i += 32) {
@@ -1933,6 +1999,85 @@ kernel void kernel_im2col(
 template [[host_name("kernel_im2col_f32")]] kernel im2col_t kernel_im2col;
 template [[host_name("kernel_im2col_f16")]] kernel im2col_t kernel_im2col;
 
+typedef void (im2col_ext_t)(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        constant   int32_t & N,
+        constant   int32_t & KH,
+        constant   int32_t & KW,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]);
+
+template <typename T>
+kernel void kernel_im2col_ext(
+        device const float * x,
+        device        char * dst,
+        constant   int32_t & ofs0,
+        constant   int32_t & ofs1,
+        constant   int32_t & IW,
+        constant   int32_t & IH,
+        constant   int32_t & CHW,
+        constant   int32_t & s0,
+        constant   int32_t & s1,
+        constant   int32_t & p0,
+        constant   int32_t & p1,
+        constant   int32_t & d0,
+        constant   int32_t & d1,
+        constant   int32_t & N,
+        constant   int32_t & KH,
+        constant   int32_t & KW,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3  tgpg[[threadgroups_per_grid]],      // tgpg[0] = D x IC x KH x KW, CHW = IC x KH x KW
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {  // [M, 1, 1]
+    const int32_t KHW = KH * KW;             // KHW == ntg[1] * ntg[2], KW == ntg[2]
+
+    const int32_t d = tgpig[0] / CHW;
+    const int32_t chw = tgpig[0] % CHW;
+    const int32_t tgpig_0 = chw / KHW;  // 0 ~ (IC - 1)
+    const int32_t HW = tgpig[0] % KHW;
+
+    const int32_t tpitg_0 = (d * ntg[0]) + tpitg[0];
+    if (tpitg_0 >= N) {
+        return;
+    }
+
+    const int32_t tpitg_1 = HW / KW;
+    const int32_t tpitg_2 = HW % KW;
+
+    const int32_t iiw = tgpig[2] * s0 + tpitg_2 * d0 - p0;
+    const int32_t iih = tgpig[1] * s1 + tpitg_1 * d1 - p1;
+
+    const int32_t offset_dst =
+        (tpitg_0 * tgpg[1] * tgpg[2] + tgpig[1] * tgpg[2] + tgpig[2]) * CHW +
+        (tgpig_0 * KHW + tpitg_1 * KW + tpitg_2);
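+
+    // i.e. dst is an [N*OH*OW, IC*KH*KW] matrix: the row is the output pixel
+    // (tpitg_0 = batch, tgpig[1]/tgpig[2] = oh/ow) and the column is the channel
+    // and kernel position (assuming tgpg[1]/tgpg[2] equal OH/OW in the dispatch)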
+
+    device T * pdst = (device T *) (dst);
+
+    if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
+        pdst[offset_dst] = 0.0f;
+    } else {
+        const int32_t offset_src = tpitg_0 * ofs0 + tgpig_0 * ofs1;
+        pdst[offset_dst] = x[offset_src + iih * IW + iiw];
+    }
+}
+
+template [[host_name("kernel_im2col_ext_f32")]] kernel im2col_ext_t kernel_im2col_ext;
+template [[host_name("kernel_im2col_ext_f16")]] kernel im2col_ext_t kernel_im2col_ext;
+
 kernel void kernel_upscale_f32(
     device  const char * src0,
     device        char * dst,
@@ -3337,8 +3482,14 @@ void kernel_mul_mv_q2_K_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -3354,21 +3505,19 @@ void kernel_mul_mv_q2_K_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0);
+    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
 
-    const int step = sizeof(block_q2_K) * nb;
-
     const int ix = tiisg/8;  // 0...3
     const int it = tiisg%8;  // 0...7
     const int iq = it/4;     // 0 or 1
@@ -3413,9 +3562,9 @@ void kernel_mul_mv_q2_K_f32_impl(
                                  (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
                          dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
 
-            qs += step/2;
-            sc += step;
-            dh += step/2;
+            qs += nb01/2;
+            sc += nb01;
+            dh += nb01/2;
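+            // nb01 is the byte stride between src0 rows: /2 converts it to an
+            // element count for the 16-bit qs/dh pointers, while the byte-sized
+            // sc pointer advances by the full stride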
         }
 
         y4 += 4 * QK_K;
@@ -3440,12 +3589,14 @@ kernel void kernel_mul_mv_q2_K_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -3454,7 +3605,7 @@ kernel void kernel_mul_mv_q2_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_q3_K_f32_impl(
@@ -3464,8 +3615,14 @@ void kernel_mul_mv_q3_K_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -3486,10 +3643,11 @@ void kernel_mul_mv_q3_K_f32_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0);
+    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
 
     float yl[32];
 
@@ -3529,8 +3687,6 @@ void kernel_mul_mv_q3_K_f32_impl(
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
 
-    const int step = sizeof(block_q3_K) * nb / 2;
-
     device const float * y1 = yy + ix*QK_K + y_offset;
 
     uint32_t scales32, aux32;
@@ -3540,7 +3696,6 @@ void kernel_mul_mv_q3_K_f32_impl(
     float sumf1[2] = {0.f};
     float sumf2[2] = {0.f};
     for (int i = ix; i < nb; i += 4) {
-
         for (int l = 0; l < 8; ++l) {
             yl[l+ 0] = y1[l+ 0];
             yl[l+ 8] = y1[l+16];
@@ -3554,7 +3709,6 @@ void kernel_mul_mv_q3_K_f32_impl(
         device const half * dh = &x[i].d;
 
         for (int row = 0; row < 2; ++row) {
-
             const float d_all = (float)dh[0];
 
             scales16[0] = a[4];
@@ -3594,15 +3748,13 @@ void kernel_mul_mv_q3_K_f32_impl(
             sumf1[row] += d1 * (scales[1] - 32);
             sumf2[row] += d2 * (scales[3] - 32);
 
-            q  += step;
-            h  += step;
-            a  += step;
-            dh += step;
-
+            q  += nb01/2;
+            h  += nb01/2;
+            a  += nb01/2;
+            dh += nb01/2;
         }
 
         y1 += 4 * QK_K;
-
     }
 
     for (int row = 0; row < 2; ++row) {
@@ -3627,12 +3779,14 @@ kernel void kernel_mul_mv_q3_K_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -3641,7 +3795,7 @@ kernel void kernel_mul_mv_q3_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_q4_K_f32_impl(
@@ -3651,8 +3805,14 @@ void kernel_mul_mv_q4_K_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -3677,29 +3837,26 @@ void kernel_mul_mv_q4_K_f32_impl(
     const int im = tgpig.z;
     //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
     const int first_row = r0 * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
-    device const float      * y = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0);
+    device const float      * y = (device const float      *) ((device char *) src1 + offset1);
 
     float yl[16];
     float yh[16];
     float sumf[N_DST]={0.f}, all_sum;
 
-    const int step = sizeof(block_q4_K) * nb / 2;
-
     device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
 
     uint16_t sc16[4];
     thread const uint8_t * sc8 = (thread const uint8_t *)sc16;
 
     for (int ib = ix; ib < nb; ib += 4) {
-
         float4 sumy = {0.f, 0.f, 0.f, 0.f};
         for (int i = 0; i < 8; ++i) {
             yl[i+0] = y4[i+  0]; sumy[0] += yl[i+0];
@@ -3713,7 +3870,6 @@ void kernel_mul_mv_q4_K_f32_impl(
         device const half     * dh = &x[ib].d;
 
         for (int row = 0; row < N_DST; row++) {
-
             sc16[0] = sc[0] & kmask1;
             sc16[1] = sc[2] & kmask1;
             sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
@@ -3742,9 +3898,9 @@ void kernel_mul_mv_q4_K_f32_impl(
                                  (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
-            q1 += step;
-            sc += step;
-            dh += step;
+            q1 += nb01/2;
+            sc += nb01/2;
+            dh += nb01/2;
         }
 
         y4 += 4 * QK_K;
@@ -3769,12 +3925,14 @@ kernel void kernel_mul_mv_q4_K_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -3783,7 +3941,7 @@ kernel void kernel_mul_mv_q4_K_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_q5_K_f32_impl(
@@ -3793,8 +3951,14 @@ void kernel_mul_mv_q5_K_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -3815,15 +3979,14 @@ void kernel_mul_mv_q5_K_f32_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0);
+    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
 
     float sumf[2]={0.f};
 
-    const int step = sizeof(block_q5_K) * nb;
-
     float yl[16], yh[16];
 
     const uint16_t kmask1 = 0x3f3f;
@@ -3851,7 +4014,6 @@ void kernel_mul_mv_q5_K_f32_impl(
     device const float * y1 = yy + ix*QK_K + y_offset;
 
     for (int i = ix; i < nb; i += 4) {
-
         device const uint8_t * q1 = x[i].qs + q_offset;
         device const uint8_t * qh = x[i].qh + l0;
         device const half * dh = &x[i].d;
@@ -3867,7 +4029,6 @@ void kernel_mul_mv_q5_K_f32_impl(
         }
 
         for (int row = 0; row < 2; ++row) {
-
             device const uint8_t * q2 = q1 + 64;
 
             sc16[0] = a[0] & kmask1;
@@ -3896,15 +4057,13 @@ void kernel_mul_mv_q5_K_f32_impl(
                                  sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
-            q1 += step;
-            qh += step;
-            dh += step/2;
-            a  += step/2;
-
+            q1 += nb01;
+            qh += nb01;
+            dh += nb01/2;
+            a  += nb01/2;
         }
 
         y1 += 4 * QK_K;
-
     }
 
     for (int row = 0; row < 2; ++row) {
@@ -3926,12 +4085,14 @@ kernel void kernel_mul_mv_q5_K_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -3940,7 +4101,7 @@ kernel void kernel_mul_mv_q5_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_q6_K_f32_impl(
@@ -3950,8 +4111,14 @@ void kernel_mul_mv_q6_K_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -3977,10 +4144,11 @@ void kernel_mul_mv_q6_K_f32_impl(
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =  r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
-    device const float     * yy = (device const float      *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0);
+    device const float     * yy = (device const float      *) ((device char *) src1 + offset1);
 
     float sumf = 0;
 
@@ -4036,12 +4204,14 @@ kernel void kernel_mul_mv_q6_K_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4050,7 +4220,7 @@ kernel void kernel_mul_mv_q6_K_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 // ======================= "True" 2-bit
@@ -4062,8 +4232,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4079,15 +4255,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_iq2_xxs * x = (device const block_iq2_xxs *) src0 + ib_row + offset0;
-    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0);
+    device const float         * y = (device const float         *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4140,8 +4316,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
             }
             sumf[row] += d * sum;
 
-            dh += nb*sizeof(block_iq2_xxs)/2;
-            q2 += nb*sizeof(block_iq2_xxs)/2;
+            dh += nb01/2;
+            q2 += nb01/2;
         }
 
         y4 += 32 * 32;
@@ -4166,12 +4342,14 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4181,7 +4359,7 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_iq2_xs_f32_impl(
@@ -4191,8 +4369,14 @@ void kernel_mul_mv_iq2_xs_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4208,15 +4392,15 @@ void kernel_mul_mv_iq2_xs_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_iq2_xs * x = (device const block_iq2_xs *) src0 + ib_row + offset0;
-    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0);
+    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4278,9 +4462,9 @@ void kernel_mul_mv_iq2_xs_f32_impl(
             }
             sumf[row] += d1 * sum1 + d2 * sum2;
 
-            dh += nb*sizeof(block_iq2_xs)/2;
-            q2 += nb*sizeof(block_iq2_xs)/2;
-            sc += nb*sizeof(block_iq2_xs);
+            dh += nb01/2;
+            q2 += nb01/2;
+            sc += nb01;
         }
 
         y4 += 32 * 32;
@@ -4305,12 +4489,14 @@ kernel void kernel_mul_mv_iq2_xs_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4320,7 +4506,7 @@ kernel void kernel_mul_mv_iq2_xs_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_iq3_xxs_f32_impl(
@@ -4330,8 +4516,14 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4347,15 +4539,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_iq3_xxs * x = (device const block_iq3_xxs *) src0 + ib_row + offset0;
-    device const float         * y = (device const float         *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0);
+    device const float         * y = (device const float         *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4410,9 +4602,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
             }
             sumf[row] += d * (sum[0] + sum[1]);
 
-            dh  += nb*sizeof(block_iq3_xxs)/2;
-            q3  += nb*sizeof(block_iq3_xxs);
-            gas += nb*sizeof(block_iq3_xxs)/2;
+            dh  += nb01/2;
+            q3  += nb01;
+            gas += nb01/2;
         }
 
         y4 += 32 * 32;
@@ -4437,12 +4629,14 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4452,7 +4646,7 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_iq3_s_f32_impl(
@@ -4462,8 +4656,14 @@ void kernel_mul_mv_iq3_s_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4479,15 +4679,15 @@ void kernel_mul_mv_iq3_s_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_iq3_s * x = (device const block_iq3_s *) src0 + ib_row + offset0;
-    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0);
+    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4540,11 +4740,11 @@ void kernel_mul_mv_iq3_s_f32_impl(
             }
             sumf[row] += d * (sum[0] + sum[1]);
 
-            dh  += nb*sizeof(block_iq3_s)/2;
-            qs  += nb*sizeof(block_iq3_s);
-            qh  += nb*sizeof(block_iq3_s);
-            sc  += nb*sizeof(block_iq3_s);
-            signs += nb*sizeof(block_iq3_s);
+            dh    += nb01/2;
+            qs    += nb01;
+            qh    += nb01;
+            sc    += nb01;
+            signs += nb01;
         }
 
         y4 += 32 * 32;
@@ -4569,12 +4769,14 @@ kernel void kernel_mul_mv_iq3_s_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4584,7 +4786,7 @@ kernel void kernel_mul_mv_iq3_s_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_iq2_s_f32_impl(
@@ -4594,8 +4796,14 @@ void kernel_mul_mv_iq2_s_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4611,15 +4819,15 @@ void kernel_mul_mv_iq2_s_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
 
-    device const block_iq2_s * x = (device const block_iq2_s *) src0 + ib_row + offset0;
-    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+    device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0);
+    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4673,11 +4881,11 @@ void kernel_mul_mv_iq2_s_f32_impl(
             }
             sumf[row] += d1 * sum[0] + d2 * sum[1];
 
-            dh  += nb*sizeof(block_iq2_s)/2;
-            qs  += nb*sizeof(block_iq2_s);
-            qh  += nb*sizeof(block_iq2_s);
-            sc  += nb*sizeof(block_iq2_s);
-            signs += nb*sizeof(block_iq2_s);
+            dh    += nb01/2;
+            qs    += nb01;
+            qh    += nb01;
+            sc    += nb01;
+            signs += nb01;
         }
 
         y4 += 32 * 32;
@@ -4702,12 +4910,14 @@ kernel void kernel_mul_mv_iq2_s_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -4717,7 +4927,7 @@ kernel void kernel_mul_mv_iq2_s_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 void kernel_mul_mv_iq1_s_f32_impl(
@@ -4727,8 +4937,14 @@ void kernel_mul_mv_iq1_s_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4744,14 +4960,15 @@ void kernel_mul_mv_iq1_s_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-    device const block_iq1_s * x = (device const block_iq1_s *) src0 + ib_row + offset0;
-    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0);
+    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4794,9 +5011,9 @@ void kernel_mul_mv_iq1_s_f32_impl(
             }
             sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
 
-            dh += nb*sizeof(block_iq1_s)/2;
-            qs += nb*sizeof(block_iq1_s);
-            qh += nb*sizeof(block_iq1_s)/2;
+            dh += nb01/2;
+            qs += nb01;
+            qh += nb01/2;
         }
 
         y4 += 32 * 32;
@@ -4817,8 +5034,14 @@ void kernel_mul_mv_iq1_m_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4834,14 +5057,15 @@ void kernel_mul_mv_iq1_m_f32_impl(
     const int im = tgpig.z;
 
     const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-    device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
-    device const float       * y = (device const float       *) src1 + r1*ne10 + im*ne00*ne1;
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0);
+    device const float       * y = (device const float       *) ((device char *) src1 + offset1);
 
     float yl[32];
     float sumf[N_DST]={0.f}, all_sum;
@@ -4893,9 +5117,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
             sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
                                              (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
 
-            sc += nb*sizeof(block_iq1_m)/2;
-            qs += nb*sizeof(block_iq1_m);
-            qh += nb*sizeof(block_iq1_m);
+            sc += nb01/2;
+            qs += nb01;
+            qh += nb01;
         }
 
         y4 += 32 * 32;
@@ -4916,8 +5140,14 @@ void kernel_mul_mv_iq4_nl_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -4933,14 +5163,15 @@ void kernel_mul_mv_iq4_nl_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
     const int first_row = (r0 * 2 + sgitg) * 2;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-    device const block_iq4_nl * x = (device const block_iq4_nl *) src0 + ib_row + offset0;
-    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0);
+    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
 
     const int ix = tiisg/2;  // 0...15
     const int it = tiisg%2;  // 0 or 1
@@ -5010,8 +5241,14 @@ void kernel_mul_mv_iq4_xs_f32_impl(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -5027,14 +5264,15 @@ void kernel_mul_mv_iq4_xs_f32_impl(
     const int r1 = tgpig.y;
     const int im = tgpig.z;
     const int first_row = (r0 * 2 + sgitg) * 2;
-    const int ib_row = first_row * nb;
 
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
-    device const block_iq4_xs * x = (device const block_iq4_xs *) src0 + ib_row + offset0;
-    device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
+    const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+    const uint offset1 =        r1*nb11 + (i12   )*nb12 + (i13   )*nb13;
+
+    device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0);
+    device const float        * y = (device const float        *) ((device char *) src1 + offset1);
 
     const int ix = tiisg/16;  // 0 or 1
     const int it = tiisg%16;  // 0...15
@@ -5109,12 +5347,14 @@ kernel void kernel_mul_mv_iq1_s_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -5123,7 +5363,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 [[host_name("kernel_mul_mv_iq1_m_f32")]]
@@ -5137,12 +5377,14 @@ kernel void kernel_mul_mv_iq1_m_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -5151,7 +5393,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
         uint  tiisg[[thread_index_in_simdgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
 }
 
 [[host_name("kernel_mul_mv_iq4_nl_f32")]]
@@ -5165,12 +5407,14 @@ kernel void kernel_mul_mv_iq4_nl_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -5180,7 +5424,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 [[host_name("kernel_mul_mv_iq4_xs_f32")]]
@@ -5194,12 +5438,14 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         constant  uint64_t & nb00,
         constant  uint64_t & nb01,
         constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
         constant   int64_t & ne10,
         constant   int64_t & ne11,
         constant   int64_t & ne12,
         constant  uint64_t & nb10,
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   uint    & r2,
@@ -5209,7 +5455,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
 
-    kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+    kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
 }
 
 //============================= templates and their specializations =============================
@@ -5754,10 +6000,12 @@ kernel void kernel_mul_mm(device const  uchar * src0,
                           constant    int64_t & ne02,
                           constant   uint64_t & nb01,
                           constant   uint64_t & nb02,
+                          constant   uint64_t & nb03,
                           constant    int64_t & ne12,
                           constant   uint64_t & nb10,
                           constant   uint64_t & nb11,
                           constant   uint64_t & nb12,
+                          constant   uint64_t & nb13,
                           constant    int64_t & ne0,
                           constant    int64_t & ne1,
                           constant       uint & r2,
@@ -5794,12 +6042,13 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     const uint i12 = im%ne12;
     const uint i13 = im/ne12;
 
-    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
+    uint   offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
     ushort offset1 = il/nl;
 
     device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
     device const float   * y = (device const float   *)(src1
-        + nb12 * im
+        + nb13 * i13
+        + nb12 * i12
         + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
@@ -6178,12 +6427,14 @@ typedef void (kernel_mul_mv_impl_t)(
                   uint64_t   nb00,
                   uint64_t   nb01,
                   uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne11,
                    int64_t   ne12,
                   uint64_t   nb10,
                   uint64_t   nb11,
                   uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -6198,8 +6449,14 @@ typedef void (kernel_mul_mv2_impl_t)(
                    int64_t   ne00,
                    int64_t   ne01,
                    int64_t   ne02,
+                  uint64_t   nb01,
+                  uint64_t   nb02,
+                  uint64_t   nb03,
                    int64_t   ne10,
                    int64_t   ne12,
+                  uint64_t   nb11,
+                  uint64_t   nb12,
+                  uint64_t   nb13,
                    int64_t   ne0,
                    int64_t   ne1,
                    uint      r2,
@@ -6220,6 +6477,7 @@ void mmv_fn(
                     uint64_t   nb00,
                     uint64_t   nb01,
                     uint64_t   nb02,
+                    uint64_t   nb03,
                      int64_t   ne10,
                      int64_t   ne11,
                      int64_t   ne12,
@@ -6227,6 +6485,7 @@ void mmv_fn(
                     uint64_t   nb10,
                     uint64_t   nb11,
                     uint64_t   nb12,
+                    uint64_t   nb13,
                      int64_t   ne0,
                      int64_t   ne1,
                     uint64_t   nb1,
@@ -6237,7 +6496,7 @@ void mmv_fn(
         uint                   tiitg,
         uint                   tiisg,
         uint                   sgitg) {
-    impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
+    impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg);
 }
 
 template<kernel_mul_mv2_impl_t impl_fn>
@@ -6251,6 +6510,7 @@ void mmv_fn(
                     uint64_t   nb00,
                     uint64_t   nb01,
                     uint64_t   nb02,
+                    uint64_t   nb03,
                      int64_t   ne10,
                      int64_t   ne11,
                      int64_t   ne12,
@@ -6258,6 +6518,7 @@ void mmv_fn(
                     uint64_t   nb10,
                     uint64_t   nb11,
                     uint64_t   nb12,
+                    uint64_t   nb13,
                      int64_t   ne0,
                      int64_t   ne1,
                     uint64_t   nb1,
@@ -6268,7 +6529,7 @@ void mmv_fn(
         uint                   tiitg,
         uint                   tiisg,
         uint                   sgitg) {
-    impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
+    impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
 }
 
 typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
@@ -6317,8 +6578,8 @@ kernel void kernel_mul_mv_id(
     const int64_t i2 = i12;
 
     device const char * src0_cur = src0s + i02*nb02;
-    device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
-    device      float * dst_cur  = dst + i1*ne0 + i2*ne1*ne0;
+    device const char * src1_cur = src1  + i11*nb11 + i12*nb12;
+    device      float *  dst_cur = dst   + i1*ne0   + i2*ne1*ne0;
 
     impl_fn(
         /* src0 */ src0_cur,
@@ -6326,19 +6587,21 @@ kernel void kernel_mul_mv_id(
         /* dst  */ dst_cur,
         /* ne00 */ ne00,
         /* ne01 */ ne01,
-        /* ne02 */ 1,//ne02,
+        /* ne02 */ 1, // ne02,
         /* nb00 */ nb00,
         /* nb01 */ nb01,
         /* nb02 */ nb02,
+        /* nb03 */ nb02, // ne02 == 1
         /* ne10 */ ne10,
-        /* ne11 */ 1,//ne11,
-        /* ne12 */ 1,//ne12,
-        /* ne13 */ 1,//ne13,
+        /* ne11 */ 1, // ne11,
+        /* ne12 */ 1, // ne12,
+        /* ne13 */ 1, // ne13,
         /* nb10 */ nb10,
         /* nb11 */ nb11,
         /* nb12 */ nb12,
+        /* nb13 */ nb12, // ne12 == 1
         /* ne0  */ ne0,
-        /* ne1  */ 1,//ne1,
+        /* ne1  */ 1, // ne1,
         /* nb1  */ nb1,
         /* r2   */ 1,
         /* r3   */ 1,
@@ -6372,3 +6635,102 @@ template [[host_name("kernel_mul_mv_id_iq3_s_f32")]]   kernel kernel_mul_mv_id_t
 template [[host_name("kernel_mul_mv_id_iq2_s_f32")]]   kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq2_s_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq4_nl_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_nl_f32_impl>>;
 template [[host_name("kernel_mul_mv_id_iq4_xs_f32")]]  kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_iq4_xs_f32_impl>>;
+
+kernel void kernel_pool_2d_max_f32(
+        device  const float * src0,
+        device        float * dst,
+        constant    int32_t & k0,
+        constant    int32_t & k1,
+        constant    int32_t & s0,
+        constant    int32_t & s1,
+        constant    int32_t & p0,
+        constant    int32_t & p1,
+        constant    int64_t & IH,
+        constant    int64_t & IW,
+        constant    int64_t & OH,
+        constant    int64_t & OW,
+        constant    int64_t & parallel_elements,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= parallel_elements) {
+        return;
+    }
+
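+    // decompose the flat thread index into channel (nc) and output coordinates (cur_oh, cur_ow)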
+    const int idx = gid;
+    const int I_HW = IH * IW;
+    const int O_HW = OH * OW;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / OW;
+    const int cur_ow = idx % O_HW % OW;
+
+    device const float * i_ptr = src0 + nc * I_HW;
+    device       float * o_ptr = dst  + nc * O_HW;
+
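+    // clamp the pooling window to the input bounds so that out-of-range (padded) cells are skipped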
+    const int start_h = cur_oh * s1 - p1;
+    const int bh = MAX(0,  start_h);
+    const int eh = MIN(IH, start_h + k1);
+    const int start_w = cur_ow * s0 - p0;
+    const int bw = MAX(0,  start_w);
+    const int ew = MIN(IW, start_w + k0);
+
+    float res = -INFINITY;
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+            res = MAX(res, i_ptr[i * IW + j]);
+        }
+    }
+
+    o_ptr[cur_oh * OW + cur_ow] = res;
+}
+
+kernel void kernel_pool_2d_avg_f32(
+        device  const float * src0,
+        device        float * dst,
+        constant    int32_t & k0,
+        constant    int32_t & k1,
+        constant    int32_t & s0,
+        constant    int32_t & s1,
+        constant    int32_t & p0,
+        constant    int32_t & p1,
+        constant    int64_t & IH,
+        constant    int64_t & IW,
+        constant    int64_t & OH,
+        constant    int64_t & OW,
+        constant    int64_t & parallel_elements,
+        uint        gid[[thread_position_in_grid]]) {
+
+    if (gid >= parallel_elements) {
+        return;
+    }
+
+    const int idx = gid;
+    const int I_HW = IH * IW;
+    const int O_HW = OH * OW;
+    const int nc = idx / O_HW;
+    const int cur_oh = idx % O_HW / OW;
+    const int cur_ow = idx % O_HW % OW;
+
+    device const float * i_ptr = src0 + nc * I_HW;
+    device       float * o_ptr = dst  + nc * O_HW;
+
+    const int start_h = cur_oh * s1 - p1;
+    const int bh = MAX(0,  start_h);
+    const int eh = MIN(IH, start_h + k1);
+    const int start_w = cur_ow * s0 - p0;
+    const int bw = MAX(0,  start_w);
+    const int ew = MIN(IW, start_w + k0);
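+    // divide by the full kernel area (padded cells count as zeros); the commented-out
+    // scale below would average over the clamped window only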
+    // const float scale = 1. / ((eh - bh) * (ew - bw));
+    const float scale = 1. / (k0 * k1);
+
+    float res = 0;
+
+    for (int i = bh; i < eh; i += 1) {
+        for (int j = bw; j < ew; j += 1) {
+            float cur = i_ptr[i * IW + j];
+            res += cur * scale;
+        }
+    }
+
+    o_ptr[cur_oh * OW + cur_ow] = res;
+}
diff --git a/ggml/src/ggml-rpc.cpp b/ggml/src/ggml-rpc.cpp
index ab7298cbae0..0e936b3437e 100644
--- a/ggml/src/ggml-rpc.cpp
+++ b/ggml/src/ggml-rpc.cpp
@@ -25,7 +25,7 @@
 #  include 
 #  include 
 #endif
-#include 
+#include 
 
 #define UNUSED GGML_UNUSED
 
@@ -57,8 +57,9 @@ struct socket_t {
     }
 };
 
-// ggml_tensor is serialized into rpc_tensor
+// all RPC structures must be packed
 #pragma pack(push, 1)
+// ggml_tensor is serialized into rpc_tensor
 struct rpc_tensor {
     uint64_t id;
     uint32_t type;
@@ -76,7 +77,6 @@ struct rpc_tensor {
 
     char padding[4];
 };
-#pragma pack(pop)
 
 static_assert(sizeof(rpc_tensor) % 8 == 0, "rpc_tensor size must be multiple of 8");
 
@@ -96,6 +96,65 @@ enum rpc_cmd {
     RPC_CMD_COUNT,
 };
 
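+// fixed-layout request/response structs, one per RPC command where applicable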
+struct rpc_msg_alloc_buffer_req {
+    uint64_t size;
+};
+
+struct rpc_msg_alloc_buffer_rsp {
+    uint64_t remote_ptr;
+    uint64_t remote_size;
+};
+
+struct rpc_msg_get_alignment_rsp {
+    uint64_t alignment;
+};
+
+struct rpc_msg_get_max_size_rsp {
+    uint64_t max_size;
+};
+
+struct rpc_msg_buffer_get_base_req {
+    uint64_t remote_ptr;
+};
+
+struct rpc_msg_buffer_get_base_rsp {
+    uint64_t base_ptr;
+};
+
+struct rpc_msg_free_buffer_req {
+    uint64_t remote_ptr;
+};
+
+struct rpc_msg_buffer_clear_req {
+    uint64_t remote_ptr;
+    uint8_t value;
+};
+
+struct rpc_msg_get_tensor_req {
+    rpc_tensor tensor;
+    uint64_t offset;
+    uint64_t size;
+};
+
+struct rpc_msg_copy_tensor_req {
+    rpc_tensor src;
+    rpc_tensor dst;
+};
+
+struct rpc_msg_copy_tensor_rsp {
+    uint8_t result;
+};
+
+struct rpc_msg_graph_compute_rsp {
+    uint8_t result;
+};
+
+struct rpc_msg_get_device_memory_rsp {
+    uint64_t free_mem;
+    uint64_t total_mem;
+};
+#pragma pack(pop)
+
 // RPC data structures
 
 static ggml_guid_t ggml_backend_rpc_guid() {
@@ -240,6 +299,38 @@ static bool recv_data(sockfd_t sockfd, void * data, size_t size) {
     return true;
 }
 
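+// wire framing of a message: | msg_size (8 bytes) | msg (msg_size bytes) |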
+static bool send_msg(sockfd_t sockfd, const void * msg, size_t msg_size) {
+    if (!send_data(sockfd, &msg_size, sizeof(msg_size))) {
+        return false;
+    }
+    return send_data(sockfd, msg, msg_size);
+}
+
+static bool recv_msg(sockfd_t sockfd, void * msg, size_t msg_size) {
+    uint64_t size;
+    if (!recv_data(sockfd, &size, sizeof(size))) {
+        return false;
+    }
+    if (size != msg_size) {
+        return false;
+    }
+    return recv_data(sockfd, msg, msg_size);
+}
+
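+// variable-size overload: the payload length is taken from the frame header
+// (used for requests without a fixed size, e.g. set_tensor)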
+static bool recv_msg(sockfd_t sockfd, std::vector<uint8_t> & input) {
+    uint64_t size;
+    if (!recv_data(sockfd, &size, sizeof(size))) {
+        return false;
+    }
+    try {
+        input.resize(size);
+    } catch (const std::bad_alloc & e) {
+        fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", size);
+        return false;
+    }
+    return recv_data(sockfd, input.data(), size);
+}
+
 static bool parse_endpoint(const std::string & endpoint, std::string & host, int & port) {
     size_t pos = endpoint.find(':');
     if (pos == std::string::npos) {
@@ -252,28 +343,27 @@ static bool parse_endpoint(const std::string & endpoint, std::string & host, int
 
 // RPC request : | rpc_cmd (1 byte) | request_size (8 bytes) | request_data (request_size bytes) |
 // RPC response: | response_size (8 bytes) | response_data (response_size bytes) |
-static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+static bool send_rpc_cmd(const std::shared_ptr<socket_t> & sock, enum rpc_cmd cmd, const void * input, size_t input_size, void * output, size_t output_size) {
     uint8_t cmd_byte = cmd;
     if (!send_data(sock->fd, &cmd_byte, sizeof(cmd_byte))) {
         return false;
     }
-    uint64_t input_size = input.size();
     if (!send_data(sock->fd, &input_size, sizeof(input_size))) {
         return false;
     }
-    if (!send_data(sock->fd, input.data(), input.size())) {
+    if (!send_data(sock->fd, input, input_size)) {
         return false;
     }
-    uint64_t output_size;
-    if (!recv_data(sock->fd, &output_size, sizeof(output_size))) {
+    // TODO: currently the output_size is always known, do we need support for commands with variable output size?
+    // even if we do, we can skip sending output_size from the server for commands with known output size
+    uint64_t out_size;
+    if (!recv_data(sock->fd, &out_size, sizeof(out_size))) {
         return false;
     }
-    if (output_size == 0) {
-        output.clear();
-        return true;
+    if (out_size != output_size) {
+        return false;
     }
-    output.resize(output_size);
-    if (!recv_data(sock->fd, output.data(), output_size)) {
+    if (!recv_data(sock->fd, output, output_size)) {
         return false;
     }
     return true;
@@ -326,14 +416,9 @@ static const char * ggml_backend_rpc_buffer_get_name(ggml_backend_buffer_t buffe
 
 static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | remote_ptr (8 bytes) |
-    std::vector<uint8_t> input(sizeof(uint64_t), 0);
-    uint64_t remote_ptr = ctx->remote_ptr;
-    memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, input, output);
+    rpc_msg_free_buffer_req request = {ctx->remote_ptr};
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0);
     GGML_ASSERT(status);
-    GGML_ASSERT(output.empty());
     delete ctx;
 }
 
@@ -342,20 +427,13 @@ static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) {
     if (ctx->base_cache.find(buffer) != ctx->base_cache.end()) {
         return ctx->base_cache[buffer];
     }
-    // input serialization format: | remote_ptr (8 bytes) |
-    std::vector<uint8_t> input(sizeof(uint64_t), 0);
-    uint64_t remote_ptr = ctx->remote_ptr;
-    memcpy(input.data(), &remote_ptr, sizeof(remote_ptr));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, input, output);
+    rpc_msg_buffer_get_base_req request = {ctx->remote_ptr};
+    rpc_msg_buffer_get_base_rsp response;
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_GET_BASE, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == sizeof(uint64_t));
-    // output serialization format: | base_ptr (8 bytes) |
-    uint64_t base_ptr;
-    memcpy(&base_ptr, output.data(), sizeof(base_ptr));
-    void * base = reinterpret_cast<void *>(base_ptr);
-    ctx->base_cache[buffer] = base;
-    return base;
+    void * base_ptr = reinterpret_cast<void *>(response.base_ptr);
+    ctx->base_cache[buffer] = base_ptr;
+    return base_ptr;
 }
 
 static rpc_tensor serialize_tensor(const ggml_tensor * tensor) {
@@ -405,26 +483,18 @@ static void ggml_backend_rpc_buffer_set_tensor(ggml_backend_buffer_t buffer, ggm
     memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
     memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
     memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), data, size);
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input, output);
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_SET_TENSOR, input.data(), input.size(), nullptr, 0);
     GGML_ASSERT(status);
 }
 
 static void ggml_backend_rpc_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
-    int input_size = sizeof(rpc_tensor) + 2*sizeof(uint64_t);
-    std::vector<uint8_t> input(input_size, 0);
-    rpc_tensor rpc_tensor = serialize_tensor(tensor);
-    memcpy(input.data(), &rpc_tensor, sizeof(rpc_tensor));
-    memcpy(input.data() + sizeof(rpc_tensor), &offset, sizeof(offset));
-    memcpy(input.data() + sizeof(rpc_tensor) + sizeof(offset), &size, sizeof(size));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, input, output);
+    rpc_msg_get_tensor_req request;
+    request.tensor = serialize_tensor(tensor);
+    request.offset = offset;
+    request.size = size;
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_GET_TENSOR, &request, sizeof(request), data, size);
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == size);
-    // output serialization format: | data (size bytes) |
-    memcpy(data, output.data(), size);
 }
 
 static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) {
@@ -437,30 +507,19 @@ static bool ggml_backend_rpc_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
         return false;
     }
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // input serialization format: | rpc_tensor src | rpc_tensor dst |
-    int input_size = 2*sizeof(rpc_tensor);
-    std::vector<uint8_t> input(input_size, 0);
-    rpc_tensor rpc_src = serialize_tensor(src);
-    rpc_tensor rpc_dst = serialize_tensor(dst);
-    memcpy(input.data(), &rpc_src, sizeof(rpc_src));
-    memcpy(input.data() + sizeof(rpc_src), &rpc_dst, sizeof(rpc_dst));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, input, output);
+    rpc_msg_copy_tensor_req request;
+    request.src = serialize_tensor(src);
+    request.dst = serialize_tensor(dst);
+    rpc_msg_copy_tensor_rsp response;
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_COPY_TENSOR, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-    // output serialization format: | result (1 byte) |
-    GGML_ASSERT(output.size() == 1);
-    return output[0];
+    return response.result;
 }
 
 static void ggml_backend_rpc_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
     ggml_backend_rpc_buffer_context * ctx = (ggml_backend_rpc_buffer_context *)buffer->context;
-    // serialization format: | bufptr (8 bytes) | value (1 byte) |
-    int input_size = sizeof(uint64_t) + sizeof(uint8_t);
-    std::vector<uint8_t> input(input_size, 0);
-    memcpy(input.data(), &ctx->remote_ptr, sizeof(ctx->remote_ptr));
-    memcpy(input.data() + sizeof(ctx->remote_ptr), &value, sizeof(value));
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, input, output);
+    rpc_msg_buffer_clear_req request = {ctx->remote_ptr, value};
+    bool status = send_rpc_cmd(ctx->sock, RPC_CMD_BUFFER_CLEAR, &request, sizeof(request), nullptr, 0);
     GGML_ASSERT(status);
 }
 
@@ -484,25 +543,16 @@ static const char * ggml_backend_rpc_buffer_type_name(ggml_backend_buffer_type_t
 
 static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
     ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    // input serialization format: | size (8 bytes) |
-    int input_size = sizeof(uint64_t);
-    std::vector<uint8_t> input(input_size, 0);
-    memcpy(input.data(), &size, sizeof(size));
-    std::vector<uint8_t> output;
+    rpc_msg_alloc_buffer_req request = {size};
+    rpc_msg_alloc_buffer_rsp response;
     auto sock = get_socket(buft_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, input, output);
+    bool status = send_rpc_cmd(sock, RPC_CMD_ALLOC_BUFFER, &request, sizeof(request), &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
-    // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
-    uint64_t remote_ptr;
-    memcpy(&remote_ptr, output.data(), sizeof(remote_ptr));
-    size_t remote_size;
-    memcpy(&remote_size, output.data() + sizeof(uint64_t), sizeof(remote_size));
-    if (remote_ptr != 0) {
+    if (response.remote_ptr != 0) {
         ggml_backend_buffer_t buffer = ggml_backend_buffer_init(buft,
             ggml_backend_rpc_buffer_interface,
-            new ggml_backend_rpc_buffer_context{sock, {}, remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
-            remote_size);
+            new ggml_backend_rpc_buffer_context{sock, {}, response.remote_ptr, "RPC[" + std::string(buft_ctx->endpoint) + "]"},
+            response.remote_size);
         return buffer;
     } else {
         return nullptr;
@@ -510,16 +560,10 @@ static ggml_backend_buffer_t ggml_backend_rpc_buffer_type_alloc_buffer(ggml_back
 }
 
 static size_t get_alignment(const std::shared_ptr<socket_t> & sock) {
-    // input serialization format: | 0 bytes |
-    std::vector<uint8_t> input;
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, input, output);
+    rpc_msg_get_alignment_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_ALIGNMENT, nullptr, 0, &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == sizeof(uint64_t));
-    // output serialization format: | alignment (8 bytes) |
-    uint64_t alignment;
-    memcpy(&alignment, output.data(), sizeof(alignment));
-    return alignment;
+    return response.alignment;
 }
 
 static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
@@ -528,16 +572,10 @@ static size_t ggml_backend_rpc_buffer_type_get_alignment(ggml_backend_buffer_typ
 }
 
 static size_t get_max_size(const std::shared_ptr<socket_t> & sock) {
-    // input serialization format: | 0 bytes |
-    std::vector<uint8_t> input;
-    std::vector<uint8_t> output;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, input, output);
+    rpc_msg_get_max_size_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_MAX_SIZE, nullptr, 0, &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == sizeof(uint64_t));
-    // output serialization format: | max_size (8 bytes) |
-    uint64_t max_size;
-    memcpy(&max_size, output.data(), sizeof(max_size));
-    return max_size;
+    return response.max_size;
 }
 
 static size_t ggml_backend_rpc_get_max_size(ggml_backend_buffer_type_t buft) {
@@ -622,28 +660,11 @@ static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, g
     ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
     std::vector<uint8_t> input;
     serialize_graph(cgraph, input);
-    std::vector<uint8_t> output;
+    rpc_msg_graph_compute_rsp response;
     auto sock = get_socket(rpc_ctx->endpoint);
-    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input, output);
+    bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == 1);
-    return (enum ggml_status)output[0];
-}
-
-static bool ggml_backend_rpc_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
-    UNUSED(backend);
-    UNUSED(op);
-    //TODO: call the remote backend and cache the results
-    return true;
-}
-
-static bool ggml_backend_rpc_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
-        return false;
-    }
-    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
-    ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
-    return buft_ctx->endpoint == rpc_ctx->endpoint;
+    return (enum ggml_status)response.result;
 }
 
 static ggml_backend_i ggml_backend_rpc_interface = {
@@ -659,8 +680,8 @@ static ggml_backend_i ggml_backend_rpc_interface = {
     /* .graph_plan_update       = */ NULL,
     /* .graph_plan_compute      = */ NULL,
     /* .graph_compute           = */ ggml_backend_rpc_graph_compute,
-    /* .supports_op             = */ ggml_backend_rpc_supports_op,
-    /* .supports_buft           = */ ggml_backend_rpc_supports_buft,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
     /* .offload_op              = */ NULL,
     /* .event_record            = */ NULL,
     /* .event_wait              = */ NULL,
@@ -691,7 +712,7 @@ GGML_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * en
 
     ggml_backend_buffer_type_t buft = new ggml_backend_buffer_type {
         /* .iface   = */ ggml_backend_rpc_buffer_type_interface,
-        /* .device  = */ nullptr,
+        /* .device  = */ ggml_backend_rpc_add_device(endpoint),
         /* .context = */ buft_ctx
     };
     buft_map[endpoint] = buft;
@@ -707,7 +728,7 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint) {
     ggml_backend_t backend = new ggml_backend {
         /* .guid      = */ ggml_backend_rpc_guid(),
         /* .interface = */ ggml_backend_rpc_interface,
-        /* .device    = */ nullptr,
+        /* .device    = */ ggml_backend_rpc_add_device(endpoint),
         /* .context   = */ ctx
     };
     return backend;
@@ -718,19 +739,11 @@ GGML_API bool ggml_backend_is_rpc(ggml_backend_t backend) {
 }
 
 static void get_device_memory(const std::shared_ptr<socket_t> & sock, size_t * free, size_t * total) {
-    // input serialization format: | 0 bytes |
-    std::vector input;
-    std::vector output;
-    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, input, output);
+    rpc_msg_get_device_memory_rsp response;
+    bool status = send_rpc_cmd(sock, RPC_CMD_GET_DEVICE_MEMORY, nullptr, 0, &response, sizeof(response));
     GGML_ASSERT(status);
-    GGML_ASSERT(output.size() == 2*sizeof(uint64_t));
-    // output serialization format: | free (8 bytes) | total (8 bytes) |
-    uint64_t free_mem;
-    memcpy(&free_mem, output.data(), sizeof(free_mem));
-    uint64_t total_mem;
-    memcpy(&total_mem, output.data() + sizeof(uint64_t), sizeof(total_mem));
-    *free = free_mem;
-    *total = total_mem;
+    *free = response.free_mem;
+    *total = response.total_mem;
 }
 
 GGML_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total) {
@@ -750,16 +763,16 @@ class rpc_server {
     rpc_server(ggml_backend_t backend) : backend(backend) {}
     ~rpc_server();
 
-    bool alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
-    void get_alignment(std::vector<uint8_t> & output);
-    void get_max_size(std::vector<uint8_t> & output);
-    bool buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
-    bool free_buffer(const std::vector<uint8_t> & input);
-    bool buffer_clear(const std::vector<uint8_t> & input);
+    void alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response);
+    void get_alignment(rpc_msg_get_alignment_rsp & response);
+    void get_max_size(rpc_msg_get_max_size_rsp & response);
+    bool buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response);
+    bool free_buffer(const rpc_msg_free_buffer_req & request);
+    bool buffer_clear(const rpc_msg_buffer_clear_req & request);
     bool set_tensor(const std::vector<uint8_t> & input);
-    bool get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
-    bool copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
-    bool graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output);
+    bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
+    bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
+    bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
 
 private:
     ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@@ -773,80 +786,50 @@ class rpc_server {
     std::unordered_set<ggml_backend_buffer_t> buffers;
 };
 
-bool rpc_server::alloc_buffer(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
-    // input serialization format: | size (8 bytes) |
-    if (input.size() != sizeof(uint64_t)) {
-        return false;
-    }
-    uint64_t size;
-    memcpy(&size, input.data(), sizeof(size));
+void rpc_server::alloc_buffer(const rpc_msg_alloc_buffer_req & request, rpc_msg_alloc_buffer_rsp & response) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
-    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, size);
-    uint64_t remote_ptr = 0;
-    uint64_t remote_size = 0;
+    ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(buft, request.size);
+    response.remote_ptr = 0;
+    response.remote_size = 0;
     if (buffer != nullptr) {
-        remote_ptr = reinterpret_cast<uint64_t>(buffer);
-        remote_size = buffer->size;
-        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, size, remote_ptr, remote_size);
+        response.remote_ptr = reinterpret_cast<uint64_t>(buffer);
+        response.remote_size = buffer->size;
+        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> remote_ptr: %" PRIx64 ", remote_size: %" PRIu64 "\n", __func__, request.size, response.remote_ptr, response.remote_size);
         buffers.insert(buffer);
     } else {
-        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, size);
+        GGML_PRINT_DEBUG("[%s] size: %" PRIu64 " -> failed\n", __func__, request.size);
     }
-    // output serialization format: | remote_ptr (8 bytes) | remote_size (8 bytes) |
-    output.resize(2*sizeof(uint64_t), 0);
-    memcpy(output.data(), &remote_ptr, sizeof(remote_ptr));
-    memcpy(output.data() + sizeof(uint64_t), &remote_size, sizeof(remote_size));
-    return true;
 }
 
-void rpc_server::get_alignment(std::vector<uint8_t> & output) {
+void rpc_server::get_alignment(rpc_msg_get_alignment_rsp & response) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     size_t alignment = ggml_backend_buft_get_alignment(buft);
     GGML_PRINT_DEBUG("[%s] alignment: %lu\n", __func__, alignment);
-    // output serialization format: | alignment (8 bytes) |
-    output.resize(sizeof(uint64_t), 0);
-    memcpy(output.data(), &alignment, sizeof(alignment));
+    response.alignment = alignment;
 }
 
-void rpc_server::get_max_size(std::vector<uint8_t> & output) {
+void rpc_server::get_max_size(rpc_msg_get_max_size_rsp & response) {
     ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(backend);
     size_t max_size = ggml_backend_buft_get_max_size(buft);
     GGML_PRINT_DEBUG("[%s] max_size: %lu\n", __func__, max_size);
-    // output serialization format: | max_size (8 bytes) |
-    output.resize(sizeof(uint64_t), 0);
-    memcpy(output.data(), &max_size, sizeof(max_size));
+    response.max_size = max_size;
 }
 
-bool rpc_server::buffer_get_base(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
-    // input serialization format: | remote_ptr (8 bytes) |
-    if (input.size() != sizeof(uint64_t)) {
-        return false;
-    }
-    uint64_t remote_ptr;
-    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
-    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+bool rpc_server::buffer_get_base(const rpc_msg_buffer_get_base_req & request, rpc_msg_buffer_get_base_rsp & response) {
+    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
+    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
     if (buffers.find(buffer) == buffers.end()) {
         GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
         return false;
     }
     void * base = ggml_backend_buffer_get_base(buffer);
-    // output serialization format: | base_ptr (8 bytes) |
-    uint64_t base_ptr = reinterpret_cast(base);
-    output.resize(sizeof(uint64_t), 0);
-    memcpy(output.data(), &base_ptr, sizeof(base_ptr));
+    response.base_ptr = reinterpret_cast(base);
     return true;
 }
 
-bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
-    // input serialization format: | remote_ptr (8 bytes) |
-    if (input.size() != sizeof(uint64_t)) {
-        return false;
-    }
-    uint64_t remote_ptr;
-    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
-    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, remote_ptr);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+bool rpc_server::free_buffer(const rpc_msg_free_buffer_req & request) {
+    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 "\n", __func__, request.remote_ptr);
+    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
     if (buffers.find(buffer) == buffers.end()) {
         GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
         return false;
@@ -856,22 +839,14 @@ bool rpc_server::free_buffer(const std::vector<uint8_t> & input) {
     return true;
 }
 
-bool rpc_server::buffer_clear(const std::vector<uint8_t> & input) {
-    // input serialization format: | remote_ptr (8 bytes) | value (1 byte) |
-    if (input.size() != sizeof(uint64_t) + sizeof(uint8_t)) {
-        return false;
-    }
-    uint64_t remote_ptr;
-    memcpy(&remote_ptr, input.data(), sizeof(remote_ptr));
-    uint8_t value;
-    memcpy(&value, input.data() + sizeof(uint64_t), sizeof(value));
-    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, remote_ptr, value);
-    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(remote_ptr);
+bool rpc_server::buffer_clear(const rpc_msg_buffer_clear_req & request) {
+    GGML_PRINT_DEBUG("[%s] remote_ptr: %" PRIx64 ", value: %u\n", __func__, request.remote_ptr, request.value);
+    ggml_backend_buffer_t buffer = reinterpret_cast<ggml_backend_buffer_t>(request.remote_ptr);
     if (buffers.find(buffer) == buffers.end()) {
         GGML_PRINT_DEBUG("[%s] buffer not found\n", __func__);
         return false;
     }
-    ggml_backend_buffer_clear(buffer, value);
+    ggml_backend_buffer_clear(buffer, request.value);
     return true;
 }
 
@@ -946,74 +921,55 @@ bool rpc_server::set_tensor(const std::vector<uint8_t> & input) {
     return true;
 }
 
-bool rpc_server::get_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
-    // serialization format: | rpc_tensor | offset (8 bytes) | size (8 bytes) |
-    if (input.size() != sizeof(rpc_tensor) + 2*sizeof(uint64_t)) {
-        return false;
-    }
-    const rpc_tensor * in_tensor = (const rpc_tensor *)input.data();
-    uint64_t offset;
-    memcpy(&offset, input.data() + sizeof(rpc_tensor), sizeof(offset));
-    uint64_t size;
-    memcpy(&size, input.data() + sizeof(rpc_tensor) + sizeof(offset), sizeof(size));
-
+bool rpc_server::get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response) {
     struct ggml_init_params params {
         /*.mem_size   =*/ ggml_tensor_overhead(),
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ true,
     };
     struct ggml_context * ctx = ggml_init(params);
-    ggml_tensor * tensor = deserialize_tensor(ctx, in_tensor);
+    ggml_tensor * tensor = deserialize_tensor(ctx, &request.tensor);
     if (tensor == nullptr) {
         GGML_PRINT_DEBUG("[%s] error deserializing tensor\n", __func__);
         ggml_free(ctx);
         return false;
     }
-    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, offset, size);
+    GGML_PRINT_DEBUG("[%s] buffer: %p, data: %p, offset: %" PRIu64 ", size: %" PRIu64 "\n", __func__, (void*)tensor->buffer, tensor->data, request.offset, request.size);
 
     // sanitize tensor->data
     {
         const size_t p0 = (size_t) ggml_backend_buffer_get_base(tensor->buffer);
         const size_t p1 = p0 + ggml_backend_buffer_get_size(tensor->buffer);
 
-        if (in_tensor->data + offset < p0 || in_tensor->data + offset >= p1 || size > (p1 - in_tensor->data - offset)) {
-            GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
+        if (request.tensor.data + request.offset < p0 ||
+            request.tensor.data + request.offset >= p1 ||
+            request.size > (p1 - request.tensor.data - request.offset)) {
+                GGML_ABORT("[%s] tensor->data out of bounds\n", __func__);
         }
     }
 
-    // output serialization format: | data (size bytes) |
-    output.resize(size, 0);
-    ggml_backend_tensor_get(tensor, output.data(), offset, size);
+    response.resize(request.size, 0);
+    ggml_backend_tensor_get(tensor, response.data(), request.offset, request.size);
     ggml_free(ctx);
     return true;
 }
 
-bool rpc_server::copy_tensor(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
-    // serialization format: | rpc_tensor src | rpc_tensor dst |
-    if (input.size() != 2*sizeof(rpc_tensor)) {
-        return false;
-    }
-    const rpc_tensor * rpc_src = (const rpc_tensor *)input.data();
-    const rpc_tensor * rpc_dst = (const rpc_tensor *)(input.data() + sizeof(rpc_src));
-
+bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response) {
     struct ggml_init_params params {
         /*.mem_size   =*/ 2*ggml_tensor_overhead(),
         /*.mem_buffer =*/ NULL,
         /*.no_alloc   =*/ true,
     };
     struct ggml_context * ctx = ggml_init(params);
-    ggml_tensor * src = deserialize_tensor(ctx, rpc_src);
-    ggml_tensor * dst = deserialize_tensor(ctx, rpc_dst);
+    ggml_tensor * src = deserialize_tensor(ctx, &request.src);
+    ggml_tensor * dst = deserialize_tensor(ctx, &request.dst);
     if (src == nullptr || dst == nullptr) {
         GGML_PRINT_DEBUG("[%s] error deserializing tensors\n", __func__);
         ggml_free(ctx);
         return false;
     }
     GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
-    bool result = ggml_backend_buffer_copy_tensor(src, dst);
-    // output serialization format: | result (1 byte) |
-    output.resize(1, 0);
-    output[0] = result;
+    response.result = ggml_backend_buffer_copy_tensor(src, dst);
     ggml_free(ctx);
     return true;
 }
@@ -1042,7 +998,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
     return result;
 }
 
-bool rpc_server::graph_compute(const std::vector<uint8_t> & input, std::vector<uint8_t> & output) {
+bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
     // serialization format:
     // | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
     if (input.size() < sizeof(uint32_t)) {
@@ -1082,9 +1038,7 @@ bool rpc_server::graph_compute(const std::vector & input, std::vectornodes[i] = create_node(id, ctx, tensor_ptrs, tensor_map);
     }
     ggml_status status = ggml_backend_graph_compute(backend, graph);
-    // output serialization format: | status (1 byte) |
-    output.resize(1, 0);
-    output[0] = status;
+    response.result = status;
     ggml_free(ctx);
     return true;
 }
@@ -1107,89 +1061,157 @@ static void rpc_serve_client(ggml_backend_t backend, sockfd_t sockfd, size_t fre
             fprintf(stderr, "Unknown command: %d\n", cmd);
             break;
         }
-        std::vector<uint8_t> input;
-        std::vector<uint8_t> output;
-        uint64_t input_size;
-        if (!recv_data(sockfd, &input_size, sizeof(input_size))) {
-            break;
-        }
-        try {
-            input.resize(input_size);
-        } catch (const std::bad_alloc & e) {
-            fprintf(stderr, "Failed to allocate input buffer of size %" PRIu64 "\n", input_size);
-            break;
-        }
-        if (!recv_data(sockfd, input.data(), input_size)) {
-            break;
-        }
-        bool ok = true;
         switch (cmd) {
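+            // each command deserializes its own request and serializes its own response;
+            // any framing or handler error closes the connection (return)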
             case RPC_CMD_ALLOC_BUFFER: {
-                ok = server.alloc_buffer(input, output);
+                rpc_msg_alloc_buffer_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_alloc_buffer_rsp response;
+                server.alloc_buffer(request, response);
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_GET_ALIGNMENT: {
-                server.get_alignment(output);
+                if (!recv_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
+                rpc_msg_get_alignment_rsp response;
+                server.get_alignment(response);
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_GET_MAX_SIZE: {
-                server.get_max_size(output);
+                if (!recv_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
+                rpc_msg_get_max_size_rsp response;
+                server.get_max_size(response);
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_BUFFER_GET_BASE: {
-                ok = server.buffer_get_base(input, output);
+                rpc_msg_buffer_get_base_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_buffer_get_base_rsp response;
+                if (!server.buffer_get_base(request, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_FREE_BUFFER: {
-                ok = server.free_buffer(input);
+                rpc_msg_free_buffer_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.free_buffer(request)) {
+                    return;
+                }
+                if (!send_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_BUFFER_CLEAR: {
-                ok = server.buffer_clear(input);
+                rpc_msg_buffer_clear_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                if (!server.buffer_clear(request)) {
+                    return;
+                }
+                if (!send_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_SET_TENSOR: {
-                ok = server.set_tensor(input);
+                std::vector<uint8_t> input;
+                if (!recv_msg(sockfd, input)) {
+                    return;
+                }
+                if (!server.set_tensor(input)) {
+                    return;
+                }
+                if (!send_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_GET_TENSOR: {
-                ok = server.get_tensor(input, output);
+                rpc_msg_get_tensor_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                std::vector<uint8_t> response;
+                if (!server.get_tensor(request, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, response.data(), response.size())) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_COPY_TENSOR: {
-                ok = server.copy_tensor(input, output);
+                rpc_msg_copy_tensor_req request;
+                if (!recv_msg(sockfd, &request, sizeof(request))) {
+                    return;
+                }
+                rpc_msg_copy_tensor_rsp response;
+                if (!server.copy_tensor(request, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_GRAPH_COMPUTE: {
-                ok = server.graph_compute(input, output);
+                std::vector<uint8_t> input;
+                if (!recv_msg(sockfd, input)) {
+                    return;
+                }
+                rpc_msg_graph_compute_rsp response;
+                if (!server.graph_compute(input, response)) {
+                    return;
+                }
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             case RPC_CMD_GET_DEVICE_MEMORY: {
-                // output serialization format: | free (8 bytes) | total (8 bytes) |
-                output.resize(2*sizeof(uint64_t), 0);
-                memcpy(output.data(), &free_mem, sizeof(free_mem));
-                memcpy(output.data() + sizeof(uint64_t), &total_mem, sizeof(total_mem));
+                if (!recv_msg(sockfd, nullptr, 0)) {
+                    return;
+                }
+                rpc_msg_get_device_memory_rsp response;
+                response.free_mem = free_mem;
+                response.total_mem = total_mem;
+                if (!send_msg(sockfd, &response, sizeof(response))) {
+                    return;
+                }
                 break;
             }
             default: {
                 fprintf(stderr, "Unknown command: %d\n", cmd);
-                ok = false;
+                return;
             }
         }
-        if (!ok) {
-            break;
-        }
-        uint64_t output_size = output.size();
-        if (!send_data(sockfd, &output_size, sizeof(output_size))) {
-            break;
-        }
-        if (!send_data(sockfd, output.data(), output_size)) {
-            break;
-        }
     }
 }
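
With the loop rewritten this way, every command follows the same fixed framing: one command byte, one fixed-size request, one fixed-size response. A hedged client-side sketch of that call sequence, assuming send_msg/recv_msg transfer exactly the given number of bytes and return false on socket failure:

    // Client-side sketch of the per-command framing used by the serve loop
    // above; Req/Rsp stand for any pair of the fixed-size rpc_msg_* structs.
    template <typename Req, typename Rsp>
    static bool rpc_call(sockfd_t sockfd, uint8_t cmd, const Req & req, Rsp & rsp) {
        if (!send_msg(sockfd, &cmd, sizeof(cmd))) {  // 1. command byte
            return false;
        }
        if (!send_msg(sockfd, &req, sizeof(req))) {  // 2. fixed-size request
            return false;
        }
        return recv_msg(sockfd, &rsp, sizeof(rsp));  // 3. fixed-size response
    }
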
 
-void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
+void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, size_t free_mem, size_t total_mem) {
     std::string host;
     int port;
     if (!parse_endpoint(endpoint, host, port)) {
@@ -1226,3 +1248,179 @@ void start_rpc_server(ggml_backend_t backend, const char * endpoint, size_t free
     WSACleanup();
 #endif
 }
+
+// device interface
+
+struct ggml_backend_rpc_device_context {
+    std::string endpoint;
+    std::string name;
+};
+
+static const char * ggml_backend_rpc_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_rpc_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ctx->name.c_str();
+}
+
+static void ggml_backend_rpc_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    ggml_backend_rpc_get_device_memory(ctx->endpoint.c_str(), free, total);
+
+    UNUSED(dev);
+}
+
+static enum ggml_backend_dev_type ggml_backend_rpc_device_get_type(ggml_backend_dev_t dev) {
+    // TODO: obtain value from the server
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+
+    UNUSED(dev);
+}
+
+static void ggml_backend_rpc_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_rpc_device_get_name(dev);
+    props->description = ggml_backend_rpc_device_get_description(dev);
+    props->type        = ggml_backend_rpc_device_get_type(dev);
+    ggml_backend_rpc_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* .async                 = */ false,
+        /* .host_buffer           = */ false,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_rpc_device_init(ggml_backend_dev_t dev, const char * params) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ggml_backend_rpc_init(ctx->endpoint.c_str());
+
+    UNUSED(params);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_rpc_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_rpc_device_context * ctx = (ggml_backend_rpc_device_context *)dev->context;
+
+    return ggml_backend_rpc_buffer_type(ctx->endpoint.c_str());
+
+    UNUSED(dev);
+}
+
+static ggml_backend_buffer_t ggml_backend_rpc_device_buffer_from_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    return ggml_backend_cpu_buffer_from_ptr(ptr, size);
+
+    UNUSED(dev);
+    UNUSED(max_tensor_size);
+}
+
+static bool ggml_backend_rpc_device_supports_op(ggml_backend_dev_t dev, const struct ggml_tensor * op) {
+    UNUSED(dev);
+    UNUSED(op);
+    //TODO: call the remote backend and cache the results
+    return true;
+}
+
+static bool ggml_backend_rpc_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    if (!buft || buft->iface.get_name != ggml_backend_rpc_buffer_type_name) {
+        return false;
+    }
+    ggml_backend_rpc_buffer_type_context * buft_ctx = (ggml_backend_rpc_buffer_type_context *)buft->context;
+    ggml_backend_rpc_device_context * dev_ctx = (ggml_backend_rpc_device_context *)dev->context;
+    return buft_ctx->endpoint == dev_ctx->endpoint;
+}
+
+static const struct ggml_backend_device_i ggml_backend_rpc_device_i = {
+    /* .get_name             = */ ggml_backend_rpc_device_get_name,
+    /* .get_description      = */ ggml_backend_rpc_device_get_description,
+    /* .get_memory           = */ ggml_backend_rpc_device_get_memory,
+    /* .get_type             = */ ggml_backend_rpc_device_get_type,
+    /* .get_props            = */ ggml_backend_rpc_device_get_props,
+    /* .init_backend         = */ ggml_backend_rpc_device_init,
+    /* .get_buffer_type      = */ ggml_backend_rpc_device_get_buffer_type,
+    /* .get_host_buffer_type = */ NULL,
+    /* .buffer_from_host_ptr = */ ggml_backend_rpc_device_buffer_from_ptr,
+    /* .supports_op          = */ ggml_backend_rpc_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_rpc_device_supports_buft,
+    /* .offload_op           = */ NULL,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
+
+// backend reg interface
+
+static const char * ggml_backend_rpc_reg_get_name(ggml_backend_reg_t reg) {
+    return "RPC";
+
+    UNUSED(reg);
+}
+
+static size_t ggml_backend_rpc_reg_get_device_count(ggml_backend_reg_t reg) {
+    return 0;
+
+    UNUSED(reg);
+}
+
+static ggml_backend_dev_t ggml_backend_rpc_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    GGML_ABORT("The RPC backend does not have enumerated devices - use ggml_backend_add_device instead");
+
+    UNUSED(reg);
+    UNUSED(index);
+}
+
+static void * ggml_backend_rpc_get_proc_address(ggml_backend_reg_t reg, const char * name) {
+    if (std::strcmp(name, "ggml_backend_rpc_add_device") == 0) {
+        return (void *)ggml_backend_rpc_add_device;
+    }
+    return NULL;
+
+    UNUSED(reg);
+}
+
+static const struct ggml_backend_reg_i ggml_backend_rpc_reg_i = {
+    /* .get_name         = */ ggml_backend_rpc_reg_get_name,
+    /* .get_device_count = */ ggml_backend_rpc_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_rpc_reg_get_device,
+    /* .get_proc_address = */ ggml_backend_rpc_get_proc_address,
+};
+
+ggml_backend_reg_t ggml_backend_rpc_reg(void) {
+    static struct ggml_backend_reg ggml_backend_rpc_reg = {
+        /* .iface   = */ ggml_backend_rpc_reg_i,
+        /* .context = */ NULL,
+    };
+
+    return &ggml_backend_rpc_reg;
+}
+
+ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint) {
+    static std::unordered_map<std::string, ggml_backend_dev_t> dev_map;
+
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
+
+    if (dev_map.find(endpoint) != dev_map.end()) {
+        return dev_map[endpoint];
+    }
+
+    ggml_backend_rpc_device_context * ctx = new ggml_backend_rpc_device_context {
+        /* .endpoint = */ endpoint,
+        /* .name     = */ "RPC[" + std::string(endpoint) + "]",
+    };
+
+    ggml_backend_dev_t dev = new ggml_backend_device {
+        /* .iface   = */ ggml_backend_rpc_device_i,
+        /* .reg     = */ ggml_backend_rpc_reg(),
+        /* .context = */ ctx,
+    };
+
+    dev_map[endpoint] = dev;
+
+    return dev;
+}
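
Since the registry reports zero enumerated devices, RPC devices are attached explicitly through ggml_backend_rpc_add_device. A usage sketch, assuming the generic ggml_backend_dev_memory / ggml_backend_dev_init wrappers from the device API; the endpoint string is illustrative:

    // Usage sketch: attach a remote endpoint as a device, query its memory,
    // and create a backend on it. Error handling kept minimal.
    #include <cstdio>

    static ggml_backend_t connect_rpc(const char * endpoint) {
        ggml_backend_dev_t dev = ggml_backend_rpc_add_device(endpoint);
        if (dev == nullptr) {
            fprintf(stderr, "failed to add RPC device for %s\n", endpoint);
            return nullptr;
        }
        size_t free_mem = 0, total_mem = 0;
        ggml_backend_dev_memory(dev, &free_mem, &total_mem);  // round-trips to the server
        fprintf(stderr, "%s: %zu/%zu bytes free\n", endpoint, free_mem, total_mem);
        return ggml_backend_dev_init(dev, /*params =*/ nullptr);
    }
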
diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp
index 4d3f1c5ce04..4d91ee46086 100644
--- a/ggml/src/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl.cpp
@@ -40,3070 +40,2877 @@
 #include "ggml-sycl/presets.hpp"
 #include "ggml-sycl/gemm.hpp"
 
-bool   ggml_sycl_loaded(void);
-void   ggml_sycl_free_data(struct ggml_tensor * tensor);
-void   ggml_sycl_copy_to_device(struct ggml_tensor * tensor);
-void   ggml_sycl_set_main_device(int main_device);
-void   ggml_sycl_set_mul_mat_q(bool mul_mat_q);
-void   ggml_sycl_get_device_description(int device, char * description, size_t description_size);
-bool   ggml_backend_is_sycl(ggml_backend_t backend);
-int    ggml_backend_sycl_get_device(ggml_backend_t backend);
-static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
-static inline int get_sycl_env(const char *env_name, int default_val);
-
+static bool g_sycl_loaded = false;
 
-void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
-                    const void *ptr_src, size_t size) {
-    char *host_buf = (char *)malloc(size);
-    q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
-    q_dst.memcpy((char *)ptr_dst, host_buf, size).wait();
-    free(host_buf);
-}
+static ggml_sycl_device_info ggml_sycl_init() {
+    ggml_sycl_device_info info = {};
 
-typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
-typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
-typedef void (*ggml_sycl_op_mul_mat_t)(
-    ggml_backend_sycl_context & ctx,
-    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
-    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
-    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
-    const int64_t src1_ncols, const int64_t src1_padded_row_size,
-    const queue_ptr &stream);
-typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                       const ggml_tensor *src1,
-                                       ggml_tensor *dst, const float *src0_dd,
-                                       const float *src1_dd, float *dst_dd,
-                                       const queue_ptr &main_stream);
+    info.device_count = dpct::dev_mgr::instance().device_count();
+    if (info.device_count == 0) {
+        fprintf(stderr, "%s: failed to initialize " GGML_SYCL_NAME ": %s\n", __func__);
+        return info;
+    }
 
-static __dpct_inline__ float op_repeat(const float a, const float b) {
-    return b;
-    GGML_UNUSED(a);
-}
+    GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);
 
-static __dpct_inline__ float op_add(const float a, const float b) {
-    return a + b;
-}
+    int64_t total_vram = 0;
+#if defined(GGML_SYCL_FORCE_MMQ)
+    fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ:   yes\n", __func__);
+#else
+    fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ:   no\n", __func__);
+#endif
+#if defined(SYCL_USE_XMX)
+    fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
+#else
+    fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
+#endif
+    fprintf(stderr, "%s: found %d " GGML_SYCL_NAME " devices:\n", __func__, info.device_count);
 
-static __dpct_inline__ float op_mul(const float a, const float b) {
-    return a * b;
-}
+    for (int i = 0; i < info.device_count; ++i) {
+        info.devices[i].vmm = 0;
+        dpct::device_info prop;
+        SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
+            prop, dpct::dev_mgr::instance().get_device(i))));
 
-static __dpct_inline__ float op_div(const float a, const float b) {
-    return a / b;
-}
+        info.default_tensor_split[i] = total_vram;
+        total_vram += prop.get_global_mem_size();
 
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1,  int s2,  int s3,
-        /*int s10,*/ int s11, int s12, int s13,
-        const sycl::nd_item<3> &item_ct1) {
-    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1));
-    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) /
-                   ne3;
-    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                    item_ct1.get_local_id(0)) %
-                   ne3;
+        info.devices[i].cc =
+            100 * prop.get_major_version() + 10 * prop.get_minor_version();
 
-    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
+        info.max_work_group_sizes[i] = prop.get_max_work_group_size();
     }
 
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
-
-    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i_src0;
-
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
-
-    for (int i0 = i0s; i0 < ne0;
-         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
-        const int i10 = i0 % ne10;
-        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    for (int id = 0; id < info.device_count; ++id) {
+        info.default_tensor_split[id] /= total_vram;
     }
+    return info;
 }
 
-template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
-static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
-        int ne0, int ne1, int ne2, int ne3,
-        int ne10, int ne11, int ne12, int ne13,
-        /*int s0, */ int s1,  int s2,  int s3,
-        /*int s10,*/ int s11, int s12, int s13,
-        const sycl::nd_item<3> &item_ct1) {
-
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+const ggml_sycl_device_info & ggml_sycl_info() {
+    static ggml_sycl_device_info info = ggml_sycl_init();
+    return info;
+}
 
-    const int i3 = i/(ne2*ne1*ne0);
-    const int i2 = (i/(ne1*ne0)) % ne2;
-    const int i1 = (i/ne0) % ne1;
-    const int i0 = i % ne0;
+void print_device_detail(int id, sycl::device &device, std::string device_type) {
 
-    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
-        return;
-    }
+    dpct::device_info prop;
+    SYCL_CHECK(CHECK_TRY_ERROR(
+        dpct::get_device_info(prop, device)));
 
-    const int i11 = i1 % ne11;
-    const int i12 = i2 % ne12;
-    const int i13 = i3 % ne13;
+    std::string version;
+    version += std::to_string(prop.get_major_version());
+    version += ".";
+    version += std::to_string(prop.get_minor_version());
 
-    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
-    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
-    const size_t i_dst  = i_src0;
+    device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
+    std::string name = std::string(prop.get_name());
+    name = std::regex_replace(name, std::regex("\\(R\\)"), "");
+    name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
 
-    const src0_t * src0_row = src0 + i_src0;
-    const src1_t * src1_row = src1 + i_src1;
-    dst_t * dst_row = dst + i_dst;
+    auto global_mem_size = prop.get_global_mem_size()/1000000;
 
-    const int i10 = i0 % ne10;
-    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
+            name.c_str(), version.c_str(), prop.get_max_compute_units(),
+            prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
+            global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
 }
 
-static void acc_f32(const float * x, const float * y, float * dst, const int ne,
-    const int ne10, const int ne11, const int ne12,
-    const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= ne) {
-        return;
+void ggml_backend_sycl_print_sycl_devices() {
+    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
+    int device_count = dpct::dev_mgr::instance().device_count();
+    std::map<std::string, int> DeviceNums;
+    fprintf(stderr, "found %d SYCL devices:\n", device_count);
+    fprintf(stderr, "|  |                   |                                       |       |Max    |        |Max  |Global |                     |\n");
+    fprintf(stderr, "|  |                   |                                       |       |compute|Max work|sub  |mem    |                     |\n");
+    fprintf(stderr, "|ID|        Device Type|                                   Name|Version|units  |group   |group|size   |       Driver version|\n");
+    fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n");
+    for (int id = 0; id < device_count; ++id) {
+        sycl::device device = dpct::dev_mgr::instance().get_device(id);
+        sycl::backend backend = device.get_backend();
+        std::string backend_type = get_device_backend_and_type(device);
+        int type_id=DeviceNums[backend_type]++;
+        std::stringstream device_type;
+        device_type << "[" <<  backend_type << ":" << std::to_string(type_id) << "]";
+        print_device_detail(id, device, device_type.str());
     }
-    int src1_idx = i - offset;
-    int oz = src1_idx / nb2;
-    int oy = (src1_idx - (oz * nb2)) / nb1;
-    int ox = src1_idx % nb1;
-    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
-        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+}
+
+static inline int get_sycl_env(const char *env_name, int default_val) {
+    char *user_device_string = getenv(env_name);
+    int user_number = default_val;
+
+    unsigned n;
+    if (user_device_string != NULL &&
+        sscanf(user_device_string, " %u", &n) == 1) {
+        user_number = (int)n;
     } else {
-        dst[i] = x[i];
+        user_number = default_val;
     }
+    return user_number;
 }
 
-static void gelu_f32(const float * x, float * dst, const int k,
-                     const sycl::nd_item<3> &item_ct1) {
-    const float GELU_COEF_A    = 0.044715f;
-    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+static void ggml_check_sycl() try {
+    static bool initialized = false;
 
-    if (i >= k) {
-        return;
-    }
+    if (!initialized) {
+        fprintf(stderr, "[SYCL] call ggml_check_sycl\n");
+        g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
 
-    float xi = x[i];
-    dst[i] = 0.5f * xi *
-             (1.0f +
-              sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
-}
+        fprintf(stderr, "%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);
 
-static void silu_f32(const float * x, float * dst, const int k,
-                     const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+#if defined(GGML_SYCL_F16)
+        fprintf(stderr, "%s: GGML_SYCL_F16: yes\n", __func__);
+#else
+        fprintf(stderr, "%s: GGML_SYCL_F16: no\n", __func__);
+#endif
 
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
-}
+/* DO NOT REMOVE; keep it for future XMX optimizations.
+#if defined(SYCL_USE_XMX)
+        fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
+#else
+        fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
+#endif
+*/
 
-static void gelu_quick_f32(const float *x, float *dst, int k,
-                           const sycl::nd_item<3> &item_ct1) {
-    const float GELU_QUICK_COEF = -1.702f;
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
+        if (CHECK_TRY_ERROR(g_all_sycl_device_count =
+                            dpct::dev_mgr::instance().device_count()) != 0) {
+            initialized = true;
+            g_sycl_loaded = false;
+            return;
+        }
+        GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);
+        ggml_backend_sycl_print_sycl_devices();
+        initialized = true;
+        g_sycl_loaded = true;
     }
-    dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
 }
-
-static void tanh_f32(const float *x, float *dst, int k,
-                     const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::tanh((float)(x[i]));
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
 }
 
-static void relu_f32(const float * x, float * dst, const int k,
-                     const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::fmax((float)(x[i]), (float)0);
+/*
+device_index: device index from 0 to n (contiguous numbers).
+    It is used for device selection in the SYCL backend's internal data structures.
+*/
+inline void check_allow_gpu_index(const int device_index) {
+  if (device_index >= ggml_sycl_info().device_count) {
+    char error_buf[256];
+    snprintf(
+        error_buf,
+        sizeof(error_buf),
+        "%s error: device_index:%d is out of range: [0-%d]",
+        __func__,
+        device_index,
+        ggml_sycl_info().device_count - 1);
+    fprintf(stderr, "%s\n", error_buf);
+    assert(false);
+  }
 }
 
-static void hardsigmoid_f32(const float * x, float * dst, const int k,
-                            const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+GGML_API void ggml_backend_sycl_get_gpu_list(int *id_list, int max_len) try {
+    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_gpu_list\n");
+    for(int i=0;i<max_len;i++) id_list[i] = -1;
 
-    if (i >= k) {
-        return;
+    for (int i=0;i< ggml_sycl_info().device_count;i++){
+        if (i>=max_len) break;
+        id_list[i] = i;
     }
-    dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
+    return;
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
 }
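
Callers size id_list themselves; slots beyond the detected device count keep the -1 sentinel set by the first loop. A short usage sketch:

    // Usage sketch: enumerate SYCL device ids; unused slots stay -1.
    #include <cstdio>

    static void list_sycl_gpus() {
        int ids[16];
        ggml_backend_sycl_get_gpu_list(ids, 16);
        for (int i = 0; i < 16 && ids[i] >= 0; ++i) {
            printf("SYCL device %d available\n", ids[i]);
        }
    }
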
 
-static void hardswish_f32(const float * x, float * dst, const int k,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+// sycl buffer
 
-    if (i >= k) {
-        return;
-    }
-    dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
-}
+struct ggml_backend_sycl_buffer_context {
+    int device;
+    void * dev_ptr = nullptr;
+    queue_ptr stream;
+    std::string name;
 
-static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
-                           const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-    if (i >= k) {
-        return;
-    }
-    dst[i] = sycl::fmax((float)(x[i]), (float)0) +
-             sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
-}
+     ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
+        device(device), dev_ptr(dev_ptr), stream(stream) {
+            check_allow_gpu_index(device);
+            name = (GGML_SYCL_NAME + std::to_string(device));
+        }
 
-static void sqr_f32(const float * x, float * dst, const int k,
-                    const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
 
-    if (i >= k) {
-        return;
+    ~ggml_backend_sycl_buffer_context() {
+        if (dev_ptr != nullptr) {
+            ggml_sycl_set_device(device);
+            SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
+        }
     }
-    dst[i] = x[i] * x[i];
+};
+
+static const char * ggml_backend_sycl_buffer_get_name(ggml_backend_buffer_t buffer) {
+    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
+    return ctx->name.c_str();
 }
 
-static void upscale_f32(const float  *x, float *dst, const int nb00, const int nb01,
-                        const int nb02, const int nb03, const int ne10, const int ne11,
-                        const int ne12, const int ne13, const float sf0, const float sf1,
-                        const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
-    int index = item_ct1.get_local_id(0) +
-               item_ct1.get_group(0) * item_ct1.get_local_range(0);
-    if (index >= ne10 * ne11 * ne12 * ne13) {
-        return;
-    }
-    // operation
-    int i10 = index % ne10;
-    int i11 = (index / ne10) % ne11;
-    int i12 = (index / (ne10 * ne11)) % ne12;
-    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
+static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) {
+    return buffer->iface.get_name == ggml_backend_sycl_buffer_get_name;
+}
 
-    int i00 = i10 / sf0;
-    int i01 = i11 / sf1;
-    int i02 = i12 / sf2;
-    int i03 = i13 / sf3;
+static void
+ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
+    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+    ggml_sycl_set_device(ctx->device);
 
-    dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
+    delete ctx;
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
 }
 
-static void pad_f32(const float  *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
-                    const sycl::nd_item<3> &item_ct1) {
-    int nidx = item_ct1.get_local_id(2) +
-               item_ct1.get_group(2) * item_ct1.get_local_range(2);
-    if (nidx >= ne0) {
-        return;
-    }
-
-    // operation
-    int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
-                     item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
-    if (nidx < ne00 && item_ct1.get_group(1) < ne01 &&
-        item_ct1.get_group(0) < ne02) {
-        int offset_src = nidx + item_ct1.get_group(1) * ne00 +
-                         item_ct1.get_group(0) * ne00 * ne01;
-            dst[offset_dst] = x[offset_src];
-    } else {
-        dst[offset_dst] = 0.0f;
-    }
+static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
+    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+    return ctx->dev_ptr;
 }
 
-template <int QUANT_BLOCK_TILE>
-static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
-                          const sycl::nd_item<3> &item_ct1) {
-    const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
+static void
+ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
+                                     ggml_tensor *tensor) try {
+    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
 
-    if (ix >= kx_padded) {
+    if (tensor->view_src != NULL && tensor->view_offs == 0) {
+        assert(tensor->view_src->buffer->buft == buffer->buft);
+        tensor->backend = tensor->view_src->backend;
+        tensor->extra = tensor->view_src->extra;
         return;
     }
 
-    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                   item_ct1.get_local_id(1);
-
-    const int i_padded = iy*kx_padded + ix;
-
-    block_q8_1 * y = (block_q8_1 *) vy;
 
-    const int ib = i_padded / QK8_1; // block index
-    const int iqs = i_padded % QK8_1; // quant index
-    typedef  sycl::vec<float, QUANT_BLOCK_TILE> TC;
-    typedef  sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
-    TC zeros;
-    TQ qzeros;
-#pragma unroll
-    for (int i = 0; i < QUANT_BLOCK_TILE; i++)
-    {
-        zeros[i] = 0.f;
-        qzeros[i] = 0;
-    }
-    const TC xi = ix < kx ? *(TC *)&x[iy * kx + ix] : zeros;
-    float sum = xi[0];
-    float amax = sycl::fabs(xi[0]);
-#pragma unroll
-    for (int i = 1; i < QUANT_BLOCK_TILE; i++)
-    {
-        sum += xi[i];
-        amax = sycl::fmax(sycl::fabs(xi[i]), amax);
-    }
-    sum = warp_reduce_sum(sum, item_ct1);
-    amax = warp_reduce_max(amax, item_ct1);
+    if (ggml_is_quantized(tensor->type)) {
+        // initialize padding to 0 to avoid possible NaN values
+        size_t original_size = ggml_nbytes(tensor);
+        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
 
-    const float d = amax / 127;
-    TQ q = qzeros;
-    if (amax != 0.0f)
-    {
-#pragma unroll
-        for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
-            q[i] = sycl::round(xi[i] / d);
+        if (padded_size > original_size && tensor->view_src == nullptr) {
+            SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset(
+                (char *)tensor->data + original_size, 0,
+                padded_size - original_size).wait()));
         }
     }
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
 
-    *(TQ *)&y[ib].qs[iqs] = q;
+static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
+                                                ggml_tensor *tensor,
+                                                const void *data, size_t offset,
+                                                size_t size) try {
 
-    if (iqs > 0) {
-        return;
-    }
+    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
 
-    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
-    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
+    ggml_sycl_set_device(ctx->device);
+    auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
+    char* host_buf = (char*)malloc(size);
+    memcpy(host_buf, data, size);
+    SYCL_CHECK(
+        CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
+                             .wait()));
+    free(host_buf);
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
 }
 
-template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static void k_get_rows(
-            const void * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12,
-            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
-
-    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                     item_ct1.get_local_id(2)) *
-                    2;
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
-
-    if (i00 >= ne00) {
-        return;
-    }
-
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
+                                                const ggml_tensor *tensor,
+                                                void *data, size_t offset,
+                                                size_t size) try {
 
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
+    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
 
-    const int ib = i00/qk; // block index
-    const int iqs = (i00%qk)/qr; // quant index
-    const int iybs = i00 - i00%qk; // dst block start index
-    const int y_offset = qr == 1 ? 1 : qk/2;
+    ggml_sycl_set_device(ctx->device);
+    auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue();
 
-    // dequantize
-    dfloat2 v;
-    dequantize_kernel(src0_row, ib, iqs, v);
+    SYCL_CHECK(CHECK_TRY_ERROR(
+        stream.memcpy(data, (const char *)tensor->data + offset, size)
+            .wait()));
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
 
-    dst_row[iybs + iqs + 0] = v.x();
-    dst_row[iybs + iqs + y_offset] = v.y();
+void dev2dev_memcpy(sycl::queue &q_dst, sycl::queue &q_src, void *ptr_dst,
+                    const void *ptr_src, size_t size) {
+    char *host_buf = (char *)malloc(size);
+    q_src.memcpy(host_buf, (const char *)ptr_src, size).wait();
+    q_dst.memcpy((char *)ptr_dst, host_buf, size).wait();
+    free(host_buf);
 }
 
-template<typename src0_t, typename dst_t>
-static void k_get_rows_float(
-            const src0_t * src0, const int32_t * src1, dst_t * dst,
-            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
-            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
-            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
-            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
-            size_t s10, size_t s11, size_t s12,
-            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
+static bool
+ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
+                                    const ggml_tensor *src,
+                                    ggml_tensor *dst) try {
+    if (ggml_backend_buffer_is_sycl(src->buffer)) {
+        ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
+        ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;
 
-    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                    item_ct1.get_local_id(2);
-    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) /
-                    ne12;
-    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
-                     item_ct1.get_local_id(0)) %
-                    ne12;
+        ggml_sycl_set_device(src_ctx->device);
+        /*
+        DPCT1009:198: SYCL uses exceptions to report errors and does not use the
+        error codes. The original code was commented out and a warning string
+        was inserted. You need to rewrite this code.
+        */
+        SYCL_CHECK(CHECK_TRY_ERROR(
+            dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw()));
+        ggml_sycl_set_device(dst_ctx->device);
+        /*
+        DPCT1009:199: SYCL uses exceptions to report errors and does not use the
+        error codes. The original code was commented out and a warning string
+        was inserted. You need to rewrite this code.
+        */
+        SYCL_CHECK(CHECK_TRY_ERROR(
+            dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
+        /*
+        DPCT1009:200: SYCL uses exceptions to report errors and does not use the
+        error codes. The original code was commented out and a warning string
+        was inserted. You need to rewrite this code.
+        */
 
-    if (i00 >= ne00) {
-        return;
-    }
+        queue_ptr stream_dst = dst_ctx->stream;
+        queue_ptr stream_src = src_ctx->stream;
+        size_t size = ggml_nbytes(src);
 
-    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+        //todo. it's a dirty workaround for a known issue: device2device copy across GPUs.
+        dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size);
 
-    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
-    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+//todo, it's a known issue: device2device copy errors across GPUs. Re-enable when the issue is fixed. DON'T remove.
+#if 0
+        SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(
+            (char *)dst->data, (const char *)src->data, size).wait()));
 
-    dst_row[i00] = src0_row[i00];
+        /*
+        DPCT1009:201: SYCL uses exceptions to report errors and does not use the
+        error codes. The original code was commented out and a warning string
+        was inserted. You need to rewrite this code.
+        */
+        SYCL_CHECK(CHECK_TRY_ERROR(
+            dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
+#endif
+        return true;
+    }
+    return false;
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
 }
 
-static void mul_mat_p021_f16_f32(
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
-    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
-    const sycl::nd_item<3> &item_ct1) {
-
-    const sycl::half *x = (const sycl::half *)vx;
 
-    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                      item_ct1.get_local_id(1);
-    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                        item_ct1.get_local_id(0);
-    const int channel_x = channel / (nchannels_y / nchannels_x);
+static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
+                                           uint8_t value) try {
+     ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
 
-    const int nrows_y = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst = row_x;
+    ggml_sycl_set_device(ctx->device);
+    queue_ptr stream = ctx->stream;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));
 
-    float tmp = 0.0f;
+    SYCL_CHECK(CHECK_TRY_ERROR((*stream)
+                                    .memset(ctx->dev_ptr, value, buffer->size)
+                                    .wait()));
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
 
-    for (int col_x0 = 0; col_x0 < ncols_x;
-         col_x0 += item_ct1.get_local_range(2)) {
-        const int col_x = col_x0 + item_ct1.get_local_id(2);
+static const ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
+    /* .get_name        = */ ggml_backend_sycl_buffer_get_name,
+    /* .free_buffer     = */ ggml_backend_sycl_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_sycl_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_sycl_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
+    /* .set_tensor      = */ ggml_backend_sycl_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_sycl_buffer_get_tensor,
+    /* .cpy_tensor      = */ ggml_backend_sycl_buffer_cpy_tensor,
+    /* .clear           = */ ggml_backend_sycl_buffer_clear,
+    /* .reset           = */ NULL,
+};
 
-        if (col_x >= ncols_x) {
-            break;
-        }
+// sycl buffer type
+struct ggml_backend_sycl_buffer_type_context {
+    int device;
+    std::string name;
 
-        // x is transposed and permuted
-        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
-        const float xi =
-            sycl::vec<sycl::half, 1>(x[ix])
-                .convert<float, sycl::rounding_mode::automatic>()[0];
+    // each buffer type has its own stream
+    queue_ptr stream = nullptr;
+};
 
-        const int row_y = col_x;
+static const char * ggml_backend_sycl_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
 
+    return ctx->name.c_str();
+}
 
-        // y is not transposed but permuted
-        const int iy = channel*nrows_y + row_y;
+static ggml_backend_buffer_t
+ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
+                                           size_t size) try {
+    ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
+    ggml_sycl_set_device(buft_ctx->device);
+    const queue_ptr stream = buft_ctx->stream;
+    size = std::max(size, (size_t)1); // syclMalloc returns null for size 0
 
-        tmp += xi * y[iy];
+    void * dev_ptr;
+    SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device(
+                                    size, *stream)));
+    if (!dev_ptr) {
+        fprintf(stderr, "%s: can't malloc %lu Bytes memory on device", __func__, size);
+        return nullptr;
     }
+    ggml_backend_sycl_buffer_context * ctx = new  ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream);
+    return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size);
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
 
-    // dst is not transposed and not permuted
-    const int idst = channel*nrows_dst + row_dst;
+static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return 128;
+    GGML_UNUSED(buft);
+}
 
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
+static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
+    return dpct::get_current_device().get_max_mem_alloc_size();
 
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[idst] = tmp;
-    }
+    GGML_UNUSED(buft);
 }
 
-static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
-    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
-    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
-    const sycl::nd_item<3> &item_ct1) {
+static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    size_t size = ggml_nbytes(tensor);
+    int64_t ne0 = tensor->ne[0];
 
-    const sycl::half *x = (const sycl::half *)vx;
+    if (ggml_is_quantized(tensor->type)) {
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
+    }
 
-    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                      item_ct1.get_local_id(1);
-    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
-                        item_ct1.get_local_id(0);
-    const int channel_x = channel / channel_x_divisor;
+    return size;
 
-    const int nrows_y   = ncols_x;
-    const int nrows_dst = nrows_x;
-    const int row_dst   = row_x;
+    GGML_UNUSED(buft);
+}
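
For quantized tensors the allocation is rounded up so every row can be processed in full MATRIX_ROW_PADDING-element chunks. A worked sketch of the arithmetic, assuming a granularity of 512 elements (the actual value is backend-defined):

    // Padding arithmetic sketch: ne0 = 1000 with an assumed granularity of
    // 512 leaves a remainder of 488, so 24 extra elements are reserved.
    #include <cstdio>

    int main() {
        const long long MATRIX_ROW_PADDING = 512;  // assumption for illustration
        long long ne0 = 1000;                      // elements in the first dim
        long long rem = ne0 % MATRIX_ROW_PADDING;  // 488
        long long pad = rem ? MATRIX_ROW_PADDING - rem : 0;  // 24
        printf("extra elements reserved: %lld\n", pad);
    }
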
 
-    const int idst = channel*nrows_dst + row_dst;
-
-    float tmp = 0.0f;
-
-    for (int col_x0 = 0; col_x0 < ncols_x;
-         col_x0 += item_ct1.get_local_range(2)) {
-        const int col_x = col_x0 + item_ct1.get_local_id(2);
-
-        if (col_x >= ncols_x) {
-            break;
-        }
+static const ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_sycl_buffer_type_get_name,
+    /* .alloc_buffer     = */ ggml_backend_sycl_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_sycl_buffer_type_get_alignment,
+    /* .get_max_size     = */ ggml_backend_sycl_buffer_type_get_max_size,
+    /* .get_alloc_size   = */ ggml_backend_sycl_buffer_type_get_alloc_size,
+    /* .is_host          = */ NULL,
+};
 
-        const int row_y = col_x;
+ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
 
-        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const int iy = channel*nrows_y + row_y;
+    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
 
-        const float xi =
-            sycl::vec<sycl::half, 1>(x[ix])
-                .convert<float, sycl::rounding_mode::automatic>()[0];
+    auto dev_count = ggml_backend_sycl_get_device_count();
 
-        tmp += xi * y[iy];
+    if (device>=dev_count or device<0) {
+        printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
+            device, dev_count-1);
+        GGML_ASSERT(device<dev_count);
+    }
+    static ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
 
-    // sum up partial sums and write back result
-#pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
-        tmp +=
-            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
-    }
+    static bool ggml_backend_sycl_buffer_type_initialized = false;
 
-    if (item_ct1.get_local_id(2) == 0) {
-        dst[idst] = tmp;
+    if (!ggml_backend_sycl_buffer_type_initialized) {
+        for (int i = 0; i < dev_count; i++) {
+            auto & device_i = dpct::dev_mgr::instance().get_device(i);
+            queue_ptr stream = &(device_i.default_queue());
+            ggml_backend_sycl_buffer_types[i] = {
+                /* .iface    = */ ggml_backend_sycl_buffer_type_interface,
+                /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), i),
+                /* .context  = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), stream},
+            };
+        }
+        ggml_backend_sycl_buffer_type_initialized = true;
     }
+    return &ggml_backend_sycl_buffer_types[device];
 }
 
-static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    float * dsti = (float *) cdsti;
+ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
+    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
 
-    *dsti = *xi;
-}
+    int device = ctx->device;
+    if (device>=ggml_sycl_info().device_count or device<0) {
+        printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
+            device, ggml_sycl_info().device_count-1);
+        GGML_ASSERT(device<ggml_sycl_info().device_count);
+    }
+    static ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
+
+    static bool ggml_backend_sycl_buffer_type_initialized = false;
 
-static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    sycl::half *dsti = (sycl::half *)cdsti;
-
-    *dsti = sycl::vec<float, 1>(*xi)
-                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
-}
+    if (!ggml_backend_sycl_buffer_type_initialized) {
+        for (int i = 0; i < ggml_sycl_info().device_count; i++) {
+            ggml_backend_sycl_buffer_types[i] = {
+                /* .iface    = */ ggml_backend_sycl_buffer_type_interface,
+                /* .device   = */ nullptr,
+                /* .context  = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},
+            };
+        }
+        ggml_backend_sycl_buffer_type_initialized = true;
+    }
+    return &ggml_backend_sycl_buffer_types[device];
 }
 
-static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
-    const sycl::half *xi = (const sycl::half *)cxi;
-    sycl::half *dsti = (sycl::half *)cdsti;
-
-    *dsti = *xi;
-}
+// sycl split buffer
 
-static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
-    const sycl::half *xi = (const sycl::half *)cxi;
-    float * dsti = (float *) cdsti;
+static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) {
+    int64_t min_compute_capability = INT_MAX;
+    int64_t max_compute_capability = INT_MIN;
+    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+        if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) {
+            if (min_compute_capability > ggml_sycl_info().devices[i].cc) {
+                min_compute_capability = ggml_sycl_info().devices[i].cc;
+            }
+            if (max_compute_capability < ggml_sycl_info().devices[i].cc) {
+                max_compute_capability = ggml_sycl_info().devices[i].cc;
+            }
+        }
+    }
 
-    *dsti = *xi;
+    switch(type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+            return max_compute_capability >= VER_GEN9 ? 128 : 64;
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+            return 64;
+        case GGML_TYPE_F16:
+        case GGML_TYPE_F32:
+            return 1;
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_IQ2_XXS:
+        case GGML_TYPE_IQ2_XS:
+        case GGML_TYPE_IQ2_S:
+        case GGML_TYPE_IQ1_S:
+        case GGML_TYPE_IQ1_M:
+        case GGML_TYPE_IQ3_XXS:
+        case GGML_TYPE_IQ4_XS:
+        case GGML_TYPE_IQ4_NL:
+            return max_compute_capability >= VER_GEN9 ? 128 : 64;
+        case GGML_TYPE_IQ3_S:
+            return max_compute_capability >= VER_GEN9 ? 128 : 64;
+        case GGML_TYPE_Q6_K:
+            return 64;
+        default:
+            GGML_ABORT("fatal error");
+    }
 }
 
-static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
-    const int16_t *xi = (const int16_t *)cxi;
-    int16_t *dsti = (int16_t *)cdsti;
+static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {
+    const int64_t nrows = ggml_nrows(tensor);
+    const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
 
-    *dsti = *xi;
+    *row_low = id == 0 ? 0 : nrows*tensor_split[id];
+    *row_low -= *row_low % rounding;
+    if (id == ggml_sycl_info().device_count - 1) {
+        *row_high = nrows;
+    } else {
+        *row_high = nrows*tensor_split[id + 1];
+        *row_high -= *row_high % rounding;
+    }
 }
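
get_row_split turns the normalized split fractions into half-open row ranges, rounding each boundary down to the quantization-friendly granularity from get_row_rounding. A worked sketch with assumed values:

    // Worked sketch: splits {0.0, 0.25, 0.75} over 1000 rows with a rounding
    // granularity of 64 yield ranges [0,192), [192,704), [704,1000).
    #include <cstdio>

    int main() {
        const long long nrows = 1000, rounding = 64;
        const float split[4] = { 0.0f, 0.25f, 0.75f, 1.0f };  // sentinel end
        for (int id = 0; id < 3; ++id) {
            long long lo = (long long)(nrows * split[id]);
            lo -= lo % rounding;                   // round down to granularity
            long long hi = (id == 2) ? nrows : (long long)(nrows * split[id + 1]);
            if (id != 2) hi -= hi % rounding;
            printf("device %d: rows [%lld, %lld)\n", id, lo, hi);
        }
    }
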
 
-static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
-    const int32_t *xi = (const int32_t *)cxi;
-    int32_t *dsti = (int32_t *)cdsti;
+static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
-    *dsti = *xi;
+    return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
 }
 
-template <cpy_kernel_t cpy_1>
-static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
-                        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                        const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+struct ggml_backend_sycl_split_buffer_type_context {
+    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
+};
 
-    if (i >= ne) {
-        return;
+struct ggml_backend_sycl_split_buffer_context {
+    ~ggml_backend_sycl_split_buffer_context() try {
+        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
+            for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+                for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
+                    if (extra->events[i][is] != nullptr) {
+                        /*
+                        DPCT1009:206: SYCL uses exceptions to report errors and
+                        does not use the error codes. The original code was
+                        commented out and a warning string was inserted. You
+                        need to rewrite this code.
+                        */
+                        SYCL_CHECK(CHECK_TRY_ERROR(
+                            dpct::destroy_event(extra->events[i][is])));
+                    }
+                }
+                if (extra->data_device[i] != nullptr) {
+                    /*
+                    DPCT1009:207: SYCL uses exceptions to report errors and does
+                    not use the error codes. The original code was commented out
+                    and a warning string was inserted. You need to rewrite this
+                    code.
+                    */
+                    ggml_sycl_set_device(i);
+                    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
+                        extra->data_device[i], *(streams[i]))));
+                }
+            }
+            delete extra;
+        }
+    }
+    catch (sycl::exception const &exc) {
+      std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+                << ", line:" << __LINE__ << std::endl;
+      std::exit(1);
     }
 
-    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
-    // then combine those indices with the corresponding byte offsets to get the total offsets
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+    std::vector tensor_extras;
+    std::vector streams;
+};
 
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
+static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backend_buffer_t buffer) {
+    return GGML_SYCL_NAME "_Split";
 
-    cpy_1(cx + x_offset, cdst + dst_offset);
+    GGML_UNUSED(buffer);
 }
 
-static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q8_0 * dsti = (block_q8_0 *) cdsti;
+static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
+    return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
+}
 
-    float amax = 0.0f; // absolute max
+static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
+    delete ctx;
+}
 
-    for (int j = 0; j < QK8_0; j++) {
-        const float v = xi[j];
-        amax = sycl::fmax(amax, sycl::fabs((float)v));
-    }
+static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
+    // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
+    return (void *)0x1000;
 
-    const float d = amax / ((1 << 7) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
+    GGML_UNUSED(buffer);
+}
 
-    dsti->d = d;
+static void
+ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
+                                           ggml_tensor *tensor) try {
+    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
 
-    for (int j = 0; j < QK8_0; ++j) {
-        const float x0 = xi[j]*id;
+    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
+    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
 
-        dsti->qs[j] = sycl::round((float)x0);
-    }
-}
+    const int64_t ne0 = tensor->ne[0];
 
-static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_0 * dsti = (block_q4_0 *) cdsti;
+    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
 
-    float amax = 0.0f;
-    float vmax = 0.0f;
+    ctx->tensor_extras.push_back(extra);
+    ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
 
-    for (int j = 0; j < QK4_0; ++j) {
-        const float v = xi[j];
-        if (amax < sycl::fabs((float)v)) {
-            amax = sycl::fabs((float)v);
-            vmax = v;
+    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
+
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
         }
-    }
 
-    const float d  = vmax / -8;
-    const float id = d ? 1.0f/d : 0.0f;
+        size_t size = ggml_nbytes_split(tensor, nrows_split);
+        const size_t original_size = size;
 
-    dsti->d = d;
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
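+        // e.g. (hypothetical): ne0 = 1000 with MATRIX_ROW_PADDING = 512 adds room for
+        // 512 - 1000 % 512 = 24 elements, so quantized kernels can safely read a full
+        // block past the logical end of the last row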
 
-    for (int j = 0; j < QK4_0/2; ++j) {
-        const float x0 = xi[0       + j]*id;
-        const float x1 = xi[QK4_0/2 + j]*id;
+        // FIXME: do not crash if sycl::malloc_device fails
+        // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
+        ggml_sycl_set_device(i);
+        const queue_ptr stream = ctx->streams[i];
+        char * buf;
+        /*
+        DPCT1009:208: SYCL uses exceptions to report errors and does not use the
+        error codes. The original code was commented out and a warning string
+        was inserted. You need to rewrite this code.
+        */
+        SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device(
+                                        size, *stream)));
+        if (!buf) {
+            char err_buf[1024];
+            snprintf(err_buf, sizeof(err_buf), "%s: can't malloc %zu bytes of memory on device", __func__, size);
+            throw std::runtime_error(err_buf);
+        }
+        // set padding to 0 to avoid possible NaN values
+        if (size > original_size) {
+            /*
+            DPCT1009:209: SYCL uses exceptions to report errors and does not use
+            the error codes. The original code was commented out and a warning
+            string was inserted. You need to rewrite this code.
+            */
+            SYCL_CHECK(CHECK_TRY_ERROR(
+                (*stream)
+                    .memset(buf + original_size, 0, size - original_size)
+                    .wait()));
+        }
 
-        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
+        extra->data_device[i] = buf;
 
-        dsti->qs[j]  = xi0;
-        dsti->qs[j] |= xi1 << 4;
+        for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
+            /*
+            DPCT1009:210: SYCL uses exceptions to report errors and does not use
+            the error codes. The original code was commented out and a warning
+            string was inserted. You need to rewrite this code.
+            */
+            SYCL_CHECK(
+                CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
+        }
     }
+    tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT;
+    tensor->extra = extra;
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
 }
 
-static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
-    const float * xi = (const float *) cxi;
-    block_q4_1 * dsti = (block_q4_1 *) cdsti;
-
-    float vmin = FLT_MAX;
-    float vmax = -FLT_MAX;
+static void
+ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
+                                          ggml_tensor *tensor, const void *data,
+                                          size_t offset, size_t size) try {
+    // split tensors must always be set in their entirety at once
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
 
-    for (int j = 0; j < QK4_1; ++j) {
-        const float v = xi[j];
+    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
+    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
 
-        if (v < vmin) vmin = v;
-        if (v > vmax) vmax = v;
-    }
+    const int64_t ne0 = tensor->ne[0];
+    const size_t nb1 = tensor->nb[1];
+    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
 
-    const float d  = (vmax - vmin) / ((1 << 4) - 1);
-    const float id = d ? 1.0f/d : 0.0f;
+    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
 
-    dsti->dm.x() = d;
-    dsti->dm.y() = vmin;
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
+        }
 
-    for (int j = 0; j < QK4_1/2; ++j) {
-        const float x0 = (xi[0       + j] - vmin)*id;
-        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
+        const size_t offset_split = row_low*nb1;
+        size_t size = ggml_nbytes_split(tensor, nrows_split);
+        const size_t original_size = size;
 
-        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
-        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
 
-        dsti->qs[j]  = xi0;
-        dsti->qs[j] |= xi1 << 4;
+        const char * buf_host = (const char *)data + offset_split;
+        /*
+        DPCT1009:211: SYCL uses exceptions to report errors and does not use the
+        error codes. The original code was commented out and a warning string
+        was inserted. You need to rewrite this code.
+        */
+        ggml_sycl_set_device(i);
+        const queue_ptr stream = ctx->streams[i];
+        SYCL_CHECK(CHECK_TRY_ERROR(
+            (*stream)
+                .memcpy(extra->data_device[i], buf_host, original_size)
+                .wait()));
     }
 }
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
 
-template <cpy_kernel_t cpy_blck, int qk>
-static void cpy_f32_q(const char * cx, char * cdst, const int ne,
-                      const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
-                      const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
-                      const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
-    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                   item_ct1.get_local_id(2)) *
-                  qk;
-
-    if (i >= ne) {
-        return;
-    }
+static void
+ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
+                                          const ggml_tensor *tensor, void *data,
+                                          size_t offset, size_t size) try {
+    // split tensors must always be read in their entirety at once
+    GGML_ASSERT(offset == 0);
+    GGML_ASSERT(size == ggml_nbytes(tensor));
 
-    const int i03 = i/(ne00 * ne01 * ne02);
-    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
-    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
-    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
-    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
+    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
+    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
 
-    const int i13 = i/(ne10 * ne11 * ne12);
-    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
-    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
-    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
-    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
+    const int64_t ne0 = tensor->ne[0];
+    const size_t nb1 = tensor->nb[1];
+    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
 
-    cpy_blck(cx + x_offset, cdst + dst_offset);
-}
+    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
 
-static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
-                           const sycl::nd_item<3> &item_ct1) {
-    const int row = item_ct1.get_group(1);
-    const int col = item_ct1.get_local_id(2);
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
+        }
 
-    float sum = 0.0f;
-    for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) {
-        sum += x[row * ncols + i];
-    }
+        const size_t offset_split = row_low*nb1;
+        size_t size = ggml_nbytes_split(tensor, nrows_split);
+        const size_t original_size = size;
 
-    sum = warp_reduce_sum(sum, item_ct1);
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
 
-    if (col == 0) {
-        dst[row] = sum;
+        char * buf_host = (char *)data + offset_split;
+        /*
+        DPCT1009:212: SYCL uses exceptions to report errors and does not use the
+        error codes. The original code was commented out and a warning string
+        was inserted. You need to rewrite this code.
+        */
+        ggml_sycl_set_device(i);
+        const queue_ptr stream = ctx->streams[i];
+        SYCL_CHECK(CHECK_TRY_ERROR(
+            (*stream)
+                .memcpy(buf_host, extra->data_device[i], original_size)
+                .wait()));
     }
 }
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
 
-
-template<typename T>
-static inline void ggml_sycl_swap(T & a, T & b) {
-    T tmp = a;
-    a = b;
-    b = tmp;
+static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
+    GGML_UNUSED(buffer);
+    GGML_UNUSED(value);
 }
 
-template <ggml_sort_order order>
-__dpct_inline__ static void
-k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
-                  const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
-    // bitonic sort
-    int col = item_ct1.get_local_id(2);
-    int row = item_ct1.get_group(1);
+static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = {
+    /* .get_name        = */ ggml_backend_sycl_split_buffer_get_name,
+    /* .free_buffer     = */ ggml_backend_sycl_split_buffer_free_buffer,
+    /* .get_base        = */ ggml_backend_sycl_split_buffer_get_base,
+    /* .init_tensor     = */ ggml_backend_sycl_split_buffer_init_tensor,
+    /* .memset_tensor   = */ NULL,
+    /* .set_tensor      = */ ggml_backend_sycl_split_buffer_set_tensor,
+    /* .get_tensor      = */ ggml_backend_sycl_split_buffer_get_tensor,
+    /* .cpy_tensor      = */ NULL,
+    /* .clear           = */ ggml_backend_sycl_split_buffer_clear,
+    /* .reset           = */ NULL,
+};
 
-    if (col >= ncols_pad) {
-        return;
-    }
+// sycl split buffer type
 
-    const float * x_row = x + row * ncols;
-    auto dst_row = (int *)dpct_local;
+static const char * ggml_backend_sycl_split_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
+    return GGML_SYCL_NAME "_Split";
 
-    // initialize indices
-    dst_row[col] = col;
+    GGML_UNUSED(buft);
+}
 
-    item_ct1.barrier(sycl::access::fence_space::local_space);
+static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
+    // instead, we allocate them for each tensor separately in init_tensor
+    // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
+    // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
+    ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context();
 
-    for (int k = 2; k <= ncols_pad; k *= 2) {
-        for (int j = k / 2; j > 0; j /= 2) {
-            int ixj = col ^ j;
-            if (ixj > col) {
-                if ((col & k) == 0) {
-                    if (dst_row[col] >= ncols ||
-                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
-                    }
-                } else {
-                    if (dst_row[ixj] >= ncols ||
-                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
-                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
-                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
-                    ) {
-                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
-                    }
-                }
-            }
-            /*
-            DPCT1118:1: SYCL group functions and algorithms must be encountered
-            in converged control flow. You may need to adjust the code.
-            */
-            item_ct1.barrier(sycl::access::fence_space::local_space);
-        }
-    }
+    return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size);
+}
 
-    // copy the result to dst without the padding
-    if (col < ncols) {
-        dst[row * ncols + col] = dst_row[col];
-    }
+static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
+    return 128;
+    GGML_UNUSED(buft);
 }
 
+static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
+    ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context;
 
-static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
-                              const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
+    size_t total_size = 0;
 
-    if (col >= ncols) {
-        return;
-    }
+    const int64_t ne0 = tensor->ne[0];
 
-    const int i = row*ncols + col;
-    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
-    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
-    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
-}
+    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+        int64_t row_low, row_high;
+        get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i);
 
-static void scale_f32(const float * x, float * dst, const float scale, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
+        int64_t nrows_split = row_high - row_low;
+        if (nrows_split == 0) {
+            continue;
+        }
 
-    if (i >= k) {
-        return;
+        total_size += ggml_nbytes_split(tensor, nrows_split);
+
+        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
+        if (ne0 % MATRIX_ROW_PADDING != 0) {
+            total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
+        }
     }
 
-    dst[i] = scale * x[i];
+    return total_size;
 }
 
-static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
-                      const sycl::nd_item<3> &item_ct1) {
-    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                  item_ct1.get_local_id(2);
-
-    if (i >= k) {
-        return;
-    }
+static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
+    return false;
 
-    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
+    GGML_UNUSED(buft);
 }
 
-template <typename Ti, typename To>
-static  void pool2d_nchw_kernel(
-        const int ih, const int iw, const int oh, const int ow,
-        const int kh, const int kw, const int sh, const int sw,
-        const int ph, const int pw, const int parallel_elements,
-        const Ti* src, To* dst, const enum ggml_op_pool op,
-        const sycl::nd_item<3> &item_ct1) {
-        int idx = item_ct1.get_local_id(2) +
-                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
-        if (idx >= parallel_elements) {
-            return;
-        }
+static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = {
+    /* .get_name         = */ ggml_backend_sycl_split_buffer_type_get_name,
+    /* .alloc_buffer     = */ ggml_backend_sycl_split_buffer_type_alloc_buffer,
+    /* .get_alignment    = */ ggml_backend_sycl_split_buffer_type_get_alignment,
+    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
+    /* .get_alloc_size   = */ ggml_backend_sycl_split_buffer_type_get_alloc_size,
+    /* .is_host          = */ ggml_backend_sycl_split_buffer_type_is_host,
+};
 
-        const int I_HW = ih * iw;
-        const int O_HW = oh * ow;
-        const int nc = idx / O_HW;
-        const int cur_oh = idx % O_HW / ow;
-        const int cur_ow = idx % O_HW % ow;
-        const Ti* i_ptr = src + nc * I_HW;
-        To* o_ptr = dst + nc * O_HW;
-        const int start_h = cur_oh * sh - ph;
-        const int bh = sycl::max(0, start_h);
-        const int eh = sycl::min(ih, start_h + kh);
-        const int start_w = cur_ow * sw - pw;
-        const int bw = sycl::max(0, start_w);
-        const int ew = sycl::min(iw, start_w + kw);
+ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
+    static std::mutex mutex;
+    std::lock_guard<std::mutex> lock(mutex);
 
-        To res = 0;
+    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
+    ggml_check_sycl();
+    // access to buft_map is serialized by the lock above
+    static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
 
-        switch (op) {
-            case GGML_OP_POOL_AVG: res = 0; break;
-            case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
-        }
+    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split_arr = {};
 
-        for (int i = bh; i < eh; i += 1) {
-            for (int j = bw; j < ew; j += 1) {
-#if DPCT_COMPATIBILITY_TEMP >= 350
-                /*
-                DPCT1098:106: The '*' expression is used instead of the __ldg
-                call. These two expressions do not provide the exact same
-                functionality. Check the generated code for potential precision
-                and/or performance issues.
-                */
-                Ti cur = *(i_ptr + i * iw + j);
-#else
-                Ti cur = i_ptr[i * iw + j];
-#endif
-                switch (op) {
-                    case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
-                    case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
-                }
-            }
+    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; });
+    if (all_zero) {
+        tensor_split_arr = ggml_sycl_info().default_tensor_split;
+    } else {
+        float split_sum = 0.0f;
+        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+            tensor_split_arr[i] = split_sum;
+            split_sum += tensor_split[i];
         }
-        o_ptr[cur_oh * ow + cur_ow] = res;
-}
+        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+            tensor_split_arr[i] /= split_sum;
+        }
+    }
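+    // e.g. (hypothetical): tensor_split = {2, 1, 1} becomes the normalized prefix
+    // sums {0.0, 0.5, 0.75}: device 0 covers the first half of the rows, device 1
+    // the next quarter, and device 2 the rest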
 
-template <int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                          ggml_tensor *dst, const void *src0_dd,
-                          const int32_t *src1_dd, float *dst_dd,
-                          queue_ptr stream) {
+    auto it = buft_map.find(tensor_split_arr);
+    if (it != buft_map.end()) {
+        return &it->second;
+    }
 
-    GGML_TENSOR_BINARY_OP_LOCALS
+    struct ggml_backend_buffer_type buft {
+        /* .iface   = */ ggml_backend_sycl_split_buffer_type_interface,
+        /* .device  = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
+        /* .context = */ new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr},
+    };
 
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
+    auto result = buft_map.emplace(tensor_split_arr, buft);
+    return &result.first->second;
+}
 
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
+// host buffer type
 
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
+static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
+    return GGML_SYCL_NAME "_Host";
 
-    GGML_ASSERT(ne00 % 2 == 0);
+    GGML_UNUSED(buft);
+}
 
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             k_get_rows<qk, qr, dq>(
-                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-                         });
+static const char * ggml_backend_sycl_host_buffer_name(ggml_backend_buffer_t buffer) {
+    return GGML_SYCL_NAME "_Host";
 
-    (void) dst;
+    GGML_UNUSED(buffer);
 }
 
-template <typename src0_t>
-static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const src0_t *src0_dd, const int32_t *src1_dd,
-                                float *dst_dd, queue_ptr stream) {
-
-    GGML_TENSOR_BINARY_OP_LOCALS
+static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
+    ggml_sycl_host_free(buffer->context);
+}
 
-    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
-    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
-    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
+static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
+    void * ptr = ggml_sycl_host_malloc(size);
 
-    // strides in elements
-    //const size_t s0 = nb0 / ggml_element_size(dst);
-    const size_t s1 = nb1 / ggml_element_size(dst);
-    const size_t s2 = nb2 / ggml_element_size(dst);
-    const size_t s3 = nb3 / ggml_element_size(dst);
+    if (ptr == nullptr) {
+        // fallback to cpu buffer
+        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
+    }
 
-    const size_t s10 = nb10 / ggml_element_size(src1);
-    const size_t s11 = nb11 / ggml_element_size(src1);
-    const size_t s12 = nb12 / ggml_element_size(src1);
-    //const size_t s13 = nb13 / ggml_element_size(src1);
+    // FIXME: this is a hack to avoid having to implement a new buffer type
+    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
+    buffer->buft = buft;
+    buffer->iface.get_name = ggml_backend_sycl_host_buffer_name;
+    buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer;
 
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+    return buffer;
+}
 
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) {
-                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
-                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
-            });
-    }
+ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() {
+    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n");
+    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = {
+        /* .iface    = */ {
+            /* .get_name         = */ ggml_backend_sycl_host_buffer_type_name,
+            /* .alloc_buffer     = */ ggml_backend_sycl_host_buffer_type_alloc_buffer,
+            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
+            /* .get_max_size     = */ NULL, // TODO: return device.maxBufferLength
+            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
+            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
+        },
+        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), 0),
+        /* .context  = */ nullptr,
+    };
 
-    (void) dst;
+    return &ggml_backend_sycl_buffer_type_host;
 }
 
-template<float (*bin_op)(const float, const float)>
-struct bin_bcast_sycl {
-    template <typename src0_t, typename src1_t, typename dst_t>
-    void operator()(ggml_backend_sycl_context & ctx,
-                    const struct ggml_tensor *src0,
-                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
-                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
-                    queue_ptr stream) {
-
-        GGML_TENSOR_BINARY_OP_LOCALS
-
-        int nr0 = ne10/ne0;
-        int nr1 = ne11/ne1;
-        int nr2 = ne12/ne2;
-        int nr3 = ne13/ne3;
+// buffer pool for sycl (legacy)
+struct ggml_sycl_pool_leg : public ggml_sycl_pool {
+    static const int MAX_SYCL_BUFFERS = 256;
 
-        int nr[4] = { nr0, nr1, nr2, nr3 };
+    int device;
+    queue_ptr qptr;
+    struct ggml_sycl_buffer {
+        void * ptr = nullptr;
+        size_t size = 0;
+    };
 
-        // collapse dimensions until first broadcast dimension
-        int64_t cne0[] = {ne0, ne1, ne2, ne3};
-        int64_t cne1[] = {ne10, ne11, ne12, ne13};
-        size_t cnb0[] = {nb0, nb1, nb2, nb3};
-        size_t cnb1[] = {nb10, nb11, nb12, nb13};
-        auto collapse = [](int64_t cne[]) {
-            cne[0] *= cne[1];
-            cne[1] = cne[2];
-            cne[2] = cne[3];
-            cne[3] = 1;
-        };
+    ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {};
+    size_t pool_size = 0;
 
-        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
-            cnb[1] *= cne[1];
-            cnb[2] *= cne[2];
-            cnb[3] *= cne[3];
-        };
+    explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) :
+        device(device_),
+        qptr(qptr_) {
+    }
 
-        for (int i = 0; i < 4; i++) {
-            if (nr[i] != 1) {
-                break;
+    ~ggml_sycl_pool_leg() {
+        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
+            ggml_sycl_buffer & b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
+                pool_size -= b.size;
             }
-            if (i > 0) {
-                collapse_nb(cnb0, cne0);
-                collapse_nb(cnb1, cne1);
-                collapse(cne0);
-                collapse(cne1);
+        }
+        GGML_ASSERT(pool_size == 0);
+    }
+
+    void * alloc(size_t size, size_t * actual_size) override {
+#ifdef DEBUG_SYCL_MALLOC
+        int nnz = 0;
+        size_t max_size = 0;
+#endif
+        size_t best_diff = 1ull << 36;
+        int ibest = -1;
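+        // best-fit scan: pick the pooled buffer that wastes the least space while
+        // still being large enough; an exact size match is returned immediately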
+        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
+            ggml_sycl_buffer& b = buffer_pool[i];
+            if (b.ptr != nullptr) {
+#ifdef DEBUG_SYCL_MALLOC
+                ++nnz;
+                if (b.size > max_size) max_size = b.size;
+#endif
+                if (b.size >= size) {
+                    size_t diff = b.size - size;
+                    if (diff < best_diff) {
+                        best_diff = diff;
+                        ibest = i;
+                        if (!best_diff) {
+                            void * ptr = b.ptr;
+                            *actual_size = b.size;
+                            b.ptr = nullptr;
+                            b.size = 0;
+                            return ptr;
+                        }
+                    }
+                }
             }
         }
-        {
-            int64_t ne0 = cne0[0];
-            int64_t ne1 = cne0[1];
-            int64_t ne2 = cne0[2];
-            int64_t ne3 = cne0[3];
+        if (ibest >= 0) {
+            ggml_sycl_buffer& b = buffer_pool[ibest];
+            void * ptr = b.ptr;
+            *actual_size = b.size;
+            b.ptr = nullptr;
+            b.size = 0;
+            return ptr;
+        }
+        void * ptr;
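+        // no pooled buffer fits: over-allocate by 5% so slightly larger future
+        // requests can still reuse this allocation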
+        size_t look_ahead_size = (size_t) (1.05 * size);
 
-            int64_t ne10 = cne1[0];
-            int64_t ne11 = cne1[1];
-            int64_t ne12 = cne1[2];
-            int64_t ne13 = cne1[3];
+        SYCL_CHECK(
+            CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device(
+                                look_ahead_size, *qptr)));
+        if (!ptr) {
+            fprintf(stderr, "%s: can't malloc %lu Bytes memory on device", __func__, look_ahead_size);
+            return nullptr;
+        }
 
-            size_t nb0 = cnb0[0];
-            size_t nb1 = cnb0[1];
-            size_t nb2 = cnb0[2];
-            size_t nb3 = cnb0[3];
+        *actual_size = look_ahead_size;
+        pool_size += look_ahead_size;
 
-            size_t nb10 = cnb1[0];
-            size_t nb11 = cnb1[1];
-            size_t nb12 = cnb1[2];
-            size_t nb13 = cnb1[3];
+    #ifdef DEBUG_SYCL_MALLOC
+        fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
+                (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
+    #endif
+        // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr);
+        return ptr;
+    }
 
-            size_t s0 = nb0 / sizeof(dst_t);
-            size_t s1 = nb1 / sizeof(dst_t);
-            size_t s2 = nb2 / sizeof(dst_t);
-            size_t s3 = nb3 / sizeof(dst_t);
+    void free(void * ptr, size_t size) override {
+        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
+            ggml_sycl_buffer& b = buffer_pool[i];
+            if (b.ptr == nullptr) {
+                b.ptr = ptr;
+                b.size = size;
+                return;
+            }
+        }
+        fprintf(stderr, "WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n");
+        SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
+        pool_size -= size;
+    }
+};
 
-            size_t s10 = nb10 / sizeof(src1_t);
-            size_t s11 = nb11 / sizeof(src1_t);
-            size_t s12 = nb12 / sizeof(src1_t);
-            size_t s13 = nb13 / sizeof(src1_t);
+std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
+    // TBD: NO VMM support
+    // if (ggml_sycl_info().devices[device].vmm) {
+    //     return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device));
+    // }
+    return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
+}
 
-            GGML_ASSERT(s0 == 1);
-            GGML_ASSERT(s10 == 1);
+// TBD pool with virtual memory management
+// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
 
-            const int block_size = 128;
+/// kernels
 
-            int64_t hne0 = std::max(ne0/2LL, 1LL);
+typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
+typedef void (*ggml_sycl_func_t)(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+typedef void (*ggml_sycl_op_mul_mat_t)(
+    ggml_backend_sycl_context & ctx,
+    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+    const int64_t src1_ncols, const int64_t src1_padded_row_size,
+    const queue_ptr &stream);
+typedef void (*ggml_sycl_op_flatten_t)(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                       const ggml_tensor *src1,
+                                       ggml_tensor *dst, const float *src0_dd,
+                                       const float *src1_dd, float *dst_dd,
+                                       const queue_ptr &main_stream);
 
-            sycl::range<3> block_dims(1, 1, 1);
-            block_dims[2] = std::min<unsigned int>(hne0, block_size);
-            block_dims[1] = std::min<unsigned int>(
-                ne1, block_size / (unsigned int)block_dims[2]);
-            block_dims[0] = std::min(
-                std::min<unsigned int>(
-                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
-                                   (unsigned int)block_dims[1]),
-                64U);
+static __dpct_inline__ float op_repeat(const float a, const float b) {
+    return b;
+    GGML_UNUSED(a);
+}
 
-            sycl::range<3> block_nums(
-                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
-                (ne1 + block_dims[1] - 1) / block_dims[1],
-                (hne0 + block_dims[2] - 1) / block_dims[2]);
+static __dpct_inline__ float op_add(const float a, const float b) {
+    return a + b;
+}
 
-            if (block_nums[0] > 65535) {
-                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
-                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
-                {
-                    dpct::has_capability_or_fail(stream->get_device(),
-                                                 {sycl::aspect::fp16});
+static __dpct_inline__ float op_mul(const float a, const float b) {
+    return a * b;
+}
 
-                    stream->parallel_for(
-                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
-                                              sycl::range<3>(1, 1, block_size),
-                                          sycl::range<3>(1, 1, block_size)),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_bin_bcast_unravel<bin_op>(
-                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
-                                ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
-                                s13, item_ct1);
-                        });
-                }
-            } else {
-                /*
-                DPCT1049:16: The work-group size passed to the SYCL kernel may
-                exceed the limit. To get the device limit, query
-                info::device::max_work_group_size. Adjust the work-group size if
-                needed.
-                */
-                dpct::has_capability_or_fail(stream->get_device(),
-                                             {sycl::aspect::fp16});
+static __dpct_inline__ float op_div(const float a, const float b) {
+    return a / b;
+}
 
-                stream->parallel_for(
-                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                    [=](sycl::nd_item<3> item_ct1) {
-                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
-                                            ne2, ne3, ne10, ne11, ne12, ne13,
-                                            s1, s2, s3, s11, s12, s13,
-                                            item_ct1);
-                    });
-            }
-        }
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s10,*/ int s11, int s12, int s13,
+        const sycl::nd_item<3> &item_ct1) {
+    const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2);
+    const int i1 = (item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1));
+    const int i2 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                    item_ct1.get_local_id(0)) /
+                   ne3;
+    const int i3 = (item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                    item_ct1.get_local_id(0)) %
+                   ne3;
+
+    if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
     }
-};
 
-static void acc_f32_sycl(const float *x, const float *y, float *dst,
-                         const int n_elements, const int ne10, const int ne11,
-                         const int ne12, const int nb1, const int nb2,
-                         const int offset, queue_ptr stream) {
-    int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
-                    item_ct1);
-        });
-}
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
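+    // src1 indices wrap around via modulo, broadcasting it along dims 1..3;
+    // dim 0 is broadcast element-wise in the loop below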
 
-static void gelu_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            gelu_f32(x, dst, k, item_ct1);
-        });
-}
+    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
 
-static void silu_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            silu_f32(x, dst, k, item_ct1);
-        });
-}
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
 
-static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
-                                queue_ptr stream) {
-    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            gelu_quick_f32(x, dst, k, item_ct1);
-        });
+    for (int i0 = i0s; i0 < ne0;
+         i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) {
+        const int i10 = i0 % ne10;
+        dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
+    }
 }
 
-static void tanh_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            tanh_f32(x, dst, k, item_ct1);
-        });
-}
+template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
+static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
+        int ne0, int ne1, int ne2, int ne3,
+        int ne10, int ne11, int ne12, int ne13,
+        /*int s0, */ int s1,  int s2,  int s3,
+        /*int s10,*/ int s11, int s12, int s13,
+        const sycl::nd_item<3> &item_ct1) {
 
-static void relu_f32_sycl(const float *x, float *dst, const int k,
-                          queue_ptr stream) {
-    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            relu_f32(x, dst, k, item_ct1);
-        });
-}
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
-                                 queue_ptr stream) {
-    const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            hardsigmoid_f32(x, dst, k, item_ct1);
-        });
-}
+    const int i3 = i/(ne2*ne1*ne0);
+    const int i2 = (i/(ne1*ne0)) % ne2;
+    const int i1 = (i/ne0) % ne1;
+    const int i0 = i % ne0;
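+    // 1D fallback of k_bin_bcast: unravel the flat index into (i3, i2, i1, i0) for
+    // launches where the 3D grid would exceed the work-group count limit in z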
 
-static void hardswish_f32_sycl(const float *x, float *dst, const int k,
-                               queue_ptr stream) {
-    const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            hardswish_f32(x, dst, k, item_ct1);
-        });
-}
+    if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
+        return;
+    }
 
-static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
-                                const float negative_slope,
-                                queue_ptr stream) {
-    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            leaky_relu_f32(x, dst, k, negative_slope, item_ct1);
-        });
-}
+    const int i11 = i1 % ne11;
+    const int i12 = i2 % ne12;
+    const int i13 = i3 % ne13;
 
-static void sqr_f32_sycl(const float *x, float *dst, const int k,
-                         queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            sqr_f32(x, dst, k, item_ct1);
-        });
-}
+    const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
+    const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
+    const size_t i_dst  = i_src0;
 
-static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
-                             const int nb02, const int nb03, const int ne10, const int ne11,
-                             const int ne12, const int ne13, const float sf0, const float sf1,
-                             const float sf2, const float sf3, queue_ptr stream) {
-    int dst_size = ne10 * ne11 * ne12 * ne13;
-    int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
-    sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
-    stream->parallel_for(
-        sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
-        [=](sycl::nd_item<1> item_ct1) {
-            upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
-        });
+    const src0_t * src0_row = src0 + i_src0;
+    const src1_t * src1_row = src1 + i_src1;
+    dst_t * dst_row = dst + i_dst;
+
+    const int i10 = i0 % ne10;
+    dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
 }
 
-static void pad_f32_sycl(const float *x, float *dst, const int ne00,
-                         const int ne01, const int ne02, const int ne0,
-                         const int ne1, const int ne2, queue_ptr stream) {
-    int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
-    sycl::range<3> gridDim(ne2, ne1, num_blocks);
-    stream->parallel_for(
-        sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1);
-        });
+static void acc_f32(const float * x, const float * y, float * dst, const int ne,
+    const int ne10, const int ne11, const int ne12,
+    const int nb1, const int nb2, int offset, const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+    if (i >= ne) {
+        return;
+    }
+    int src1_idx = i - offset;
+    int oz = src1_idx / nb2;
+    int oy = (src1_idx - (oz * nb2)) / nb1;
+    int ox = src1_idx % nb1;
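+    // positions that fall outside the src1 view window copy x through unchanged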
+    if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
+        dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+    } else {
+        dst[i] = x[i];
+    }
 }
 
-static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
-                                   const int ky, const int kx_padded,
-                                   queue_ptr stream) {
-    const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
-    const sycl::range<3> num_blocks(1, ky, block_num_x);
-    int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
-    static_assert(QK8_1 % WARP_SIZE == 0);
-    const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+static void gelu_f32(const float * x, float * dst, const int k,
+                     const sycl::nd_item<3> &item_ct1) {
+    const float GELU_COEF_A    = 0.044715f;
+    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
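+    // tanh approximation of GELU: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))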
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-        stream->parallel_for(
-            sycl::nd_range<3>(num_blocks * block_size, block_size),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
-            });
+    if (i >= k) {
+        return;
     }
-}
 
-static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
-                                           float *dst, const int ncols_x,
-                                           const int nrows_x,
-                                           const int nchannels_x,
-                                           const int nchannels_y,
-                                           queue_ptr stream) {
+    float xi = x[i];
+    dst[i] = 0.5f * xi *
+             (1.0f +
+              sycl::tanh(SQRT_2_OVER_PI * xi * (1.0f + GELU_COEF_A * xi * xi)));
+}
 
-    const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
-    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+static void silu_f32(const float * x, float * dst, const int k,
+                     const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
-                                     nchannels_y, item_ct1);
-            });
+    if (i >= k) {
+        return;
     }
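+    // SiLU (swish): x * sigmoid(x) = x / (1 + e^(-x))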
+    dst[i] = x[i] / (1.0f + sycl::native::exp(-x[i]));
 }
 
-static void ggml_mul_mat_vec_nc_f16_f32_sycl(
-    const void *vx, const float *y, float *dst, const int ncols_x,
-    const int nrows_x, const int row_stride_x, const int nchannels_x,
-    const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
-
-    const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
-    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(block_nums * block_dims, block_dims),
-            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
-                                       row_stride_x, channel_stride_x,
-                                       nchannels_y / nchannels_x, item_ct1);
-            });
+static void gelu_quick_f32(const float *x, float *dst, int k,
+                           const sycl::nd_item<3> &item_ct1) {
+    const float GELU_QUICK_COEF = -1.702f;
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+    if (i >= k) {
+        return;
     }
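+    // quick GELU: x * sigmoid(1.702 * x), a cheaper sigmoid-based approximation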
+    dst[i] = x[i] * (1.0f / (1.0f + sycl::native::exp(GELU_QUICK_COEF * x[i])));
 }
 
-static void
-ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
-                      const int ne01, const int ne02, const int nb00,
-                      const int nb01, const int nb02, const int nb03,
-                      const int ne10, const int ne11, const int ne12,
-                      const int nb10, const int nb11, const int nb12,
-                      const int nb13, queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00,
-                                           nb01, nb02, nb03, ne10, ne11, ne12,
-                                           nb10, nb11, nb12, nb13, item_ct1);
-            });
+static void tanh_f32(const float *x, float *dst, int k,
+                     const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+    if (i >= k) {
+        return;
     }
+    dst[i] = sycl::tanh((float)(x[i]));
 }
 
-static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+static void relu_f32(const float * x, float * dst, const int k,
+                     const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
+    if (i >= k) {
+        return;
     }
+    dst[i] = sycl::fmax((float)(x[i]), (float)0);
 }
 
-static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+static void hardsigmoid_f32(const float * x, float * dst, const int k,
+                            const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
+    if (i >= k) {
+        return;
     }
+    dst[i] = sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
 }
 
-static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
-                                   const int ne00, const int ne01,
-                                   const int ne02, const int nb00,
-                                   const int nb01, const int nb02,
-                                   const int nb03, const int ne10,
-                                   const int ne11, const int ne12,
-                                   const int nb10, const int nb11,
-                                   const int nb12, const int nb13,
-                                   queue_ptr stream) {
+static void hardswish_f32(const float * x, float * dst, const int k,
+                          const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-    GGML_ASSERT(ne % QK8_0 == 0);
-    const int num_blocks = ne / QK8_0;
-    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
-                                           sycl::range<3>(1, 1, 1)),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             cpy_f32_q(
-                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                 item_ct1);
-                         });
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * sycl::fmin(1.0f, sycl::fmax(0.0f, (x[i] + 3.0f) / 6.0f));
 }
 
-static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
-                                   const int ne00, const int ne01,
-                                   const int ne02, const int nb00,
-                                   const int nb01, const int nb02,
-                                   const int nb03, const int ne10,
-                                   const int ne11, const int ne12,
-                                   const int nb10, const int nb11,
-                                   const int nb12, const int nb13,
-                                   queue_ptr stream) {
-
-    GGML_ASSERT(ne % QK4_0 == 0);
-    const int num_blocks = ne / QK4_0;
-    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
-                                           sycl::range<3>(1, 1, 1)),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             cpy_f32_q(
-                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                 item_ct1);
-                         });
+static void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope,
+                           const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
+    if (i >= k) {
+        return;
+    }
+    dst[i] = sycl::fmax((float)(x[i]), (float)0) +
+             sycl::fmin((float)(x[i]), 0.0f) * negative_slope;
 }
 
-static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
-                                   const int ne00, const int ne01,
-                                   const int ne02, const int nb00,
-                                   const int nb01, const int nb02,
-                                   const int nb03, const int ne10,
-                                   const int ne11, const int ne12,
-                                   const int nb10, const int nb11,
-                                   const int nb12, const int nb13,
-                                   queue_ptr stream) {
+static void sqr_f32(const float * x, float * dst, const int k,
+                    const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-    GGML_ASSERT(ne % QK4_1 == 0);
-    const int num_blocks = ne / QK4_1;
-    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
-                                           sycl::range<3>(1, 1, 1)),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             cpy_f32_q(
-                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                 item_ct1);
-                         });
+    if (i >= k) {
+        return;
+    }
+    dst[i] = x[i] * x[i];
 }
 
-static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
+static void upscale_f32(const float  *x, float *dst, const int nb00, const int nb01,
+                        const int nb02, const int nb03, const int ne10, const int ne11,
+                        const int ne12, const int ne13, const float sf0, const float sf1,
+                        const float sf2, const float sf3, const sycl::nd_item<1> &item_ct1) {
+    int index = item_ct1.get_local_id(0) +
+               item_ct1.get_group(0) * item_ct1.get_local_range(0);
+    if (index >= ne10 * ne11 * ne12 * ne13) {
+        return;
+    }
+    // unravel the flat index into destination coordinates
+    int i10 = index % ne10;
+    int i11 = (index / ne10) % ne11;
+    int i12 = (index / (ne10 * ne11)) % ne12;
+    int i13 = (index / (ne10 * ne11 * ne12)) % ne13;
 
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        dpct::has_capability_or_fail(stream->get_device(),
-                                     {sycl::aspect::fp16});
+    int i00 = i10 / sf0;
+    int i01 = i11 / sf1;
+    int i02 = i12 / sf2;
+    int i03 = i13 / sf3;
 
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
-    }
+    dst[index] = *(float *)((char *)x + i03 * nb03 + i02 * nb02 + i01 * nb01 + i00 * nb00);
 }
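+// upscale_f32 is a nearest-neighbor resize: the flat destination index is
+// unravelled into (i10, i11, i12, i13), each coordinate is divided by its
+// scale factor sf0..sf3 to locate the source element, and the load goes
+// through the byte strides nb00..nb03 (hence the char* arithmetic).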
 
-static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        // dpct::has_capability_or_fail(stream->get_device(),
-        //                              {sycl::aspect::fp16});
-
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
+static void pad_f32(const float  *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02,
+                    const sycl::nd_item<3> &item_ct1) {
+    int nidx = item_ct1.get_local_id(2) +
+               item_ct1.get_group(2) * item_ct1.get_local_range(2);
+    if (nidx >= ne0) {
+        return;
     }
-}
-
-static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
-                                  const int ne00, const int ne01,
-                                  const int ne02, const int nb00,
-                                  const int nb01, const int nb02,
-                                  const int nb03, const int ne10,
-                                  const int ne11, const int ne12,
-                                  const int nb10, const int nb11,
-                                  const int nb12, const int nb13,
-                                  queue_ptr stream) {
-
-    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
-    {
-        // dpct::has_capability_or_fail(stream->get_device(),
-        //                              {sycl::aspect::fp16});
 
-        stream->parallel_for(
-            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
-                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
-            [=](sycl::nd_item<3> item_ct1) {
-                cpy_f32_f16(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
-                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
-                                           item_ct1);
-            });
+    // compute the flat destination offset for this element
+    int offset_dst = nidx + item_ct1.get_group(1) * ne0 +
+                     item_ct1.get_group(0) * ne0 * item_ct1.get_group_range(1);
+    if (nidx < ne00 && item_ct1.get_group(1) < ne01 &&
+        item_ct1.get_group(0) < ne02) {
+        int offset_src = nidx + item_ct1.get_group(1) * ne00 +
+                         item_ct1.get_group(0) * ne00 * ne01;
+        dst[offset_dst] = x[offset_src];
+    } else {
+        dst[offset_dst] = 0.0f;
     }
 }
 
-static void scale_f32_sycl(const float *x, float *dst, const float scale,
-                           const int k, queue_ptr stream) {
-    const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            scale_f32(x, dst, scale, k, item_ct1);
-        });
-}
-
-static void clamp_f32_sycl(const float *x, float *dst, const float min,
-                           const float max, const int k,
-                           queue_ptr stream) {
-    const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
-    stream->parallel_for(
-        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
-                              sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            clamp_f32(x, dst, min, max, k, item_ct1);
-        });
-}
-
-static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
-                              const int nrows, queue_ptr stream) {
-    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
-    const sycl::range<3> block_nums(1, nrows, 1);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1)
-                             [[intel::reqd_sub_group_size(WARP_SIZE)]] {
-                                 k_sum_rows_f32(x, dst, ncols, item_ct1);
-                             });
-}
+template <int QUANT_BLOCK_TILE>
+static void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded,
+                          const sycl::nd_item<3> &item_ct1) {
+    const int ix = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2)) * QUANT_BLOCK_TILE;
 
-static int next_power_of_2(int x) {
-    int n = 1;
-    while (n < x) {
-        n *= 2;
+    if (ix >= kx_padded) {
+        return;
     }
-    return n;
-}
 
-static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
-                                 const int nrows, ggml_sort_order order,
-                                 queue_ptr stream) {
-    // bitonic sort requires ncols to be power of 2
-    const int ncols_pad = next_power_of_2(ncols);
+    const int iy = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                   item_ct1.get_local_id(1);
 
-    const sycl::range<3> block_dims(1, 1, ncols_pad);
-    const sycl::range<3> block_nums(1, nrows, 1);
-    const size_t shared_mem = ncols_pad * sizeof(int);
+    const int i_padded = iy*kx_padded + ix;
 
-    if (order == GGML_SORT_ORDER_ASC) {
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
-                sycl::range<1>(shared_mem), cgh);
+    block_q8_1 * y = (block_q8_1 *) vy;
 
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
-                        x, dst, ncols, ncols_pad, item_ct1,
-                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
-                            .get());
-                });
-        });
-    } else if (order == GGML_SORT_ORDER_DESC) {
-        stream->submit([&](sycl::handler &cgh) {
-            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
-                sycl::range<1>(shared_mem), cgh);
+    const int ib = i_padded / QK8_1; // block index
+    const int iqs = i_padded % QK8_1; // quant index
+    typedef  sycl::vec<float, QUANT_BLOCK_TILE> TC;
+    typedef  sycl::vec<int8_t, QUANT_BLOCK_TILE> TQ;
+    TC zeros;
+    TQ qzeros;
+#pragma unroll
+    for (int i = 0; i < QUANT_BLOCK_TILE; i++)
+    {
+        zeros[i] = 0.f;
+        qzeros[i] = 0;
+    }
+    const TC xi = ix < kx ? *(TC *)&x[iy * kx + ix] : zeros;
+    float sum = xi[0];
+    float amax = sycl::fabs(xi[0]);
+#pragma unroll
+    for (int i = 1; i < QUANT_BLOCK_TILE; i++)
+    {
+        sum += xi[i];
+        amax = sycl::fmax(sycl::fabs(xi[i]), amax);
+    }
+    sum = warp_reduce_sum(sum, item_ct1);
+    amax = warp_reduce_max(amax, item_ct1);
 
-            cgh.parallel_for(
-                sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                [=](sycl::nd_item<3> item_ct1) {
-                    k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
-                        x, dst, ncols, ncols_pad, item_ct1,
-                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
-                            .get());
-                });
-        });
-    } else {
-        GGML_ABORT("fatal error");
+    const float d = amax / 127;
+    TQ q = qzeros;
+    if (amax != 0.0f)
+    {
+#pragma unroll
+        for (int i = 0; i < QUANT_BLOCK_TILE; i++) {
+            q[i] = sycl::round(xi[i] / d);
+        }
     }
-}
 
-static void diag_mask_inf_f32_sycl(const float *x, float *dst,
-                                   const int ncols_x, const int nrows_x,
-                                   const int rows_per_channel, const int n_past,
-                                   queue_ptr stream) {
-    const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);
-    const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;
-    const sycl::range<3> block_nums(1, block_num_x, nrows_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             diag_mask_inf_f32(x, dst, ncols_x,
-                                               rows_per_channel, n_past,
-                                               item_ct1);
-                         });
-}
+    *(TQ *)&y[ib].qs[iqs] = q;
 
-static bool g_sycl_loaded = false;
+    if (iqs > 0) {
+        return;
+    }
 
-bool ggml_sycl_loaded(void) {
-    return g_sycl_loaded;
+    reinterpret_cast<sycl::half &>(y[ib].ds.x()) = d;
+    reinterpret_cast<sycl::half &>(y[ib].ds.y()) = sum;
 }
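+// quantize_q8_1: each work-item converts QUANT_BLOCK_TILE consecutive floats.
+// After sub-group reductions of the absolute max and the sum, the scale is
+// d = amax / 127 (quants land in [-127, 127]); d and the block sum are stored
+// side by side in y[ib].ds. Padding lanes (ix >= kx) feed zeros into both
+// reductions, so they do not perturb the scale.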
 
-void print_device_detail(int id, sycl::device &device, std::string device_type) {
+template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
+static void k_get_rows(
+            const void * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12,
+            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
 
-    dpct::device_info prop;
-    SYCL_CHECK(CHECK_TRY_ERROR(
-        dpct::get_device_info(prop, device)));
+    const int i00 = (item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                     item_ct1.get_local_id(2)) *
+                    2;
+    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1);
+    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne12;
+    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne12;
 
-    std::string version;
-    version += std::to_string(prop.get_major_version());
-    version += ".";
-    version += std::to_string(prop.get_minor_version());
+    if (i00 >= ne00) {
+        return;
+    }
 
-    device_type = std::regex_replace(device_type, std::regex("ext_oneapi_"), "");
-    std::string name = std::string(prop.get_name());
-    name = std::regex_replace(name, std::regex("\\(R\\)"), "");
-    name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
 
-    auto global_mem_size = prop.get_global_mem_size()/1000000;
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
 
-    fprintf(stderr, "|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
-            name.c_str(), version.c_str(), prop.get_max_compute_units(),
-            prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
-            global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
-}
+    const int ib = i00/qk; // block index
+    const int iqs = (i00%qk)/qr; // quant index
+    const int iybs = i00 - i00%qk; // dst block start index
+    const int y_offset = qr == 1 ? 1 : qk/2;
 
-void ggml_backend_sycl_print_sycl_devices() {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_print_sycl_devices\n");
-    int device_count = dpct::dev_mgr::instance().device_count();
-    std::map<std::string, int> DeviceNums;
-    fprintf(stderr, "found %d SYCL devices:\n", device_count);
-    fprintf(stderr, "|  |                   |                                       |       |Max    |        |Max  |Global |                     |\n");
-    fprintf(stderr, "|  |                   |                                       |       |compute|Max work|sub  |mem    |                     |\n");
-    fprintf(stderr, "|ID|        Device Type|                                   Name|Version|units  |group   |group|size   |       Driver version|\n");
-    fprintf(stderr, "|--|-------------------|---------------------------------------|-------|-------|--------|-----|-------|---------------------|\n");
-    for (int id = 0; id < device_count; ++id) {
-        sycl::device device = dpct::dev_mgr::instance().get_device(id);
-        sycl::backend backend = device.get_backend();
-        std::string backend_type = get_device_backend_and_type(device);
-        int type_id=DeviceNums[backend_type]++;
-        std::stringstream device_type;
-        device_type << "[" <<  backend_type << ":" << std::to_string(type_id) << "]";
-        print_device_detail(id, device, device_type.str());
-    }
+    // dequantize
+    dfloat2 v;
+    dequantize_kernel(src0_row, ib, iqs, v);
+
+    dst_row[iybs + iqs + 0] = v.x();
+    dst_row[iybs + iqs + y_offset] = v.y();
 }
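+// k_get_rows gathers rows of a quantized src0 selected by the i32 indices in
+// src1. Each work-item covers a pair of values (i00 advances in steps of 2);
+// dequantize_kernel returns them as a dfloat2 whose halves are written
+// y_offset apart, matching the interleaved layout of the quant blocks.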
 
-static inline int get_sycl_env(const char *env_name, int default_val) {
-    char *user_device_string = getenv(env_name);
-    int user_number = default_val;
+template<typename src0_t, typename dst_t>
+static void k_get_rows_float(
+            const src0_t * src0, const int32_t * src1, dst_t * dst,
+            int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+            /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+            /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+            /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+            size_t s10, size_t s11, size_t s12,
+            const sycl::nd_item<3> &item_ct1/*, size_t s13*/) {
 
-    unsigned n;
-    if (user_device_string != NULL &&
-        sscanf(user_device_string, " %u", &n) == 1) {
-        user_number = (int)n;
-    } else {
-        user_number = default_val;
+    const int i00 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                    item_ct1.get_local_id(2);
+    const int i10 = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1);
+    const int i11 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) /
+                    ne12;
+    const int i12 = (item_ct1.get_group(0) * item_ct1.get_local_range(0) +
+                     item_ct1.get_local_id(0)) %
+                    ne12;
+
+    if (i00 >= ne00) {
+        return;
     }
-    return user_number;
+
+    const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+    dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+    const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+    dst_row[i00] = src0_row[i00];
 }
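+// k_get_rows_float is the plain f32/f16 counterpart: the same row-gather
+// indexing, but one element is copied per work-item with no dequantization.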
 
-static void ggml_check_sycl() try {
-    static bool initialized = false;
+static void mul_mat_p021_f16_f32(
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst,
+    const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y,
+    const sycl::nd_item<3> &item_ct1) {
 
-    if (!initialized) {
-        fprintf(stderr, "[SYCL] call ggml_check_sycl\n");
-        g_ggml_sycl_debug = get_sycl_env("GGML_SYCL_DEBUG", 0);
+    const sycl::half *x = (const sycl::half *)vx;
 
-        fprintf(stderr, "%s: GGML_SYCL_DEBUG: %d\n", __func__, g_ggml_sycl_debug);
+    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                      item_ct1.get_local_id(1);
+    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                        item_ct1.get_local_id(0);
+    const int channel_x = channel / (nchannels_y / nchannels_x);
 
-#if defined(GGML_SYCL_F16)
-        fprintf(stderr, "%s: GGML_SYCL_F16: yes\n", __func__);
-#else
-        fprintf(stderr, "%s: GGML_SYCL_F16: no\n", __func__);
-#endif
+    const int nrows_y = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst = row_x;
 
-/* NOT REMOVE, keep it for next optimize for XMX.
-#if defined(SYCL_USE_XMX)
-        fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
-#else
-        fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
-#endif
-*/
+    float tmp = 0.0f;
 
-        if (CHECK_TRY_ERROR(g_all_sycl_device_count =
-                            dpct::dev_mgr::instance().device_count()) != 0) {
-            initialized = true;
-            g_sycl_loaded = false;
-            return;
-        }
-        GGML_ASSERT(g_all_sycl_device_count <= GGML_SYCL_MAX_DEVICES);
-        ggml_backend_sycl_print_sycl_devices();
-        initialized = true;
-        g_sycl_loaded = true;
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
+    for (int col_x0 = 0; col_x0 < ncols_x;
+         col_x0 += item_ct1.get_local_range(2)) {
+        const int col_x = col_x0 + item_ct1.get_local_id(2);
 
-static ggml_sycl_device_info ggml_sycl_init() {
-    ggml_sycl_device_info info = {};
+        if (col_x >= ncols_x) {
+            break;
+        }
 
-    info.device_count = dpct::dev_mgr::instance().device_count();
-    if (info.device_count == 0) {
-        fprintf(stderr, "%s: failed to initialize " GGML_SYCL_NAME ": %s\n", __func__);
-        return info;
-    }
+        // x is transposed and permuted
+        const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x;
+        const float xi =
+            sycl::vec<sycl::half, 1>(x[ix])
+                .convert<float, sycl::rounding_mode::automatic>()[0];
 
-    GGML_ASSERT(info.device_count <= GGML_SYCL_MAX_DEVICES);
+        const int row_y = col_x;
 
-    int64_t total_vram = 0;
-#if defined(GGML_SYCL_FORCE_MMQ)
-    fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ:   yes\n", __func__);
-#else
-    fprintf(stderr, "%s: GGML_SYCL_FORCE_MMQ:   no\n", __func__);
-#endif
-#if defined(SYCL_USE_XMX)
-    fprintf(stderr, "%s: SYCL_USE_XMX: yes\n", __func__);
-#else
-    fprintf(stderr, "%s: SYCL_USE_XMX: no\n", __func__);
-#endif
-    fprintf(stderr, "%s: found %d " GGML_SYCL_NAME " devices:\n", __func__, info.device_count);
 
-    for (int i = 0; i < info.device_count; ++i) {
-        info.devices[i].vmm = 0;
-        dpct::device_info prop;
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-            prop, dpct::dev_mgr::instance().get_device(i))));
+        // y is not transposed but permuted
+        const int iy = channel*nrows_y + row_y;
 
-        info.default_tensor_split[i] = total_vram;
-        total_vram += prop.get_global_mem_size();
+        tmp += xi * y[iy];
+    }
 
-        info.devices[i].cc =
-            100 * prop.get_major_version() + 10 * prop.get_minor_version();
+    // dst is not transposed and not permuted
+    const int idst = channel*nrows_dst + row_dst;
 
-        info.max_work_group_sizes[i] = prop.get_max_work_group_size();
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
 
-    for (int id = 0; id < info.device_count; ++id) {
-        info.default_tensor_split[id] /= total_vram;
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[idst] = tmp;
     }
-    return info;
 }
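+// mul_mat_p021_f16_f32 handles f16 x stored permuted as (0, 2, 1): each
+// sub-group strides across ncols_x accumulating a private partial sum, then
+// the xor-butterfly over dpct::permute_sub_group_by_xor folds the partials so
+// that lane 0 holds the full dot product before the single store to dst.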
 
-const ggml_sycl_device_info & ggml_sycl_info() {
-    static ggml_sycl_device_info info = ggml_sycl_init();
-    return info;
-}
+static void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
+    const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x,
+    const int row_stride_x, const int channel_stride_x, const int channel_x_divisor,
+    const sycl::nd_item<3> &item_ct1) {
 
-/*
-device_index: device index from 0 to n (continue numbers).
-    It is used for device select/set in SYCL backend internal data structure.
-*/
-inline void check_allow_gpu_index(const int device_index) {
-  if (device_index >= ggml_sycl_info().device_count) {
-    char error_buf[256];
-    snprintf(
-        error_buf,
-        sizeof(error_buf),
-        "%s error: device_index:%d is out of range: [0-%d]",
-        __func__,
-        device_index,
-        ggml_sycl_info().device_count - 1);
-    fprintf(stderr, "%s\n", error_buf);
-    assert(false);
-  }
-}
+    const sycl::half *x = (const sycl::half *)vx;
 
-// buffer pool for sycl (legacy)
-struct ggml_sycl_pool_leg : public ggml_sycl_pool {
-    static const int MAX_SYCL_BUFFERS = 256;
+    const int row_x = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                      item_ct1.get_local_id(1);
+    const int channel = item_ct1.get_local_range(0) * item_ct1.get_group(0) +
+                        item_ct1.get_local_id(0);
+    const int channel_x = channel / channel_x_divisor;
 
-    int device;
-    queue_ptr qptr;
-    struct ggml_sycl_buffer {
-        void * ptr = nullptr;
-        size_t size = 0;
-    };
+    const int nrows_y   = ncols_x;
+    const int nrows_dst = nrows_x;
+    const int row_dst   = row_x;
 
-    ggml_sycl_buffer buffer_pool[MAX_SYCL_BUFFERS] = {};
-    size_t pool_size = 0;
+    const int idst = channel*nrows_dst + row_dst;
 
-    explicit ggml_sycl_pool_leg(queue_ptr qptr_, int device_) :
-        qptr(qptr_),
-        device(device_) {
-    }
+    float tmp = 0.0f;
 
-    ~ggml_sycl_pool_leg() {
-        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
-            ggml_sycl_buffer & b = buffer_pool[i];
-            if (b.ptr != nullptr) {
-                SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(b.ptr, *qptr)));
-                pool_size -= b.size;
-            }
+    for (int col_x0 = 0; col_x0 < ncols_x;
+         col_x0 += item_ct1.get_local_range(2)) {
+        const int col_x = col_x0 + item_ct1.get_local_id(2);
+
+        if (col_x >= ncols_x) {
+            break;
         }
-        GGML_ASSERT(pool_size == 0);
+
+        const int row_y = col_x;
+
+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
+        const int iy = channel*nrows_y + row_y;
+
+        const float xi =
+            sycl::vec<sycl::half, 1>(x[ix])
+                .convert<float, sycl::rounding_mode::automatic>()[0];
+
+        tmp += xi * y[iy];
     }
 
-    void * alloc(size_t size, size_t * actual_size) override {
-#ifdef DEBUG_sycl_MALLOC
-        int nnz = 0;
-        size_t max_size = 0;
-#endif
-        size_t best_diff = 1ull << 36;
-        int ibest = -1;
-        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
-            ggml_sycl_buffer& b = buffer_pool[i];
-            if (b.ptr != nullptr) {
-#ifdef DEBUG_sycl_MALLOC
-                ++nnz;
-                if (b.size > max_size) max_size = b.size;
-#endif
-                if (b.size >= size) {
-                    size_t diff = b.size - size;
-                    if (diff < best_diff) {
-                        best_diff = diff;
-                        ibest = i;
-                        if (!best_diff) {
-                            void * ptr = b.ptr;
-                            *actual_size = b.size;
-                            b.ptr = nullptr;
-                            b.size = 0;
-                            return ptr;
-                        }
-                    }
-                }
-            }
-        }
-        if (ibest >= 0) {
-            ggml_sycl_buffer& b = buffer_pool[ibest];
-            void * ptr = b.ptr;
-            *actual_size = b.size;
-            b.ptr = nullptr;
-            b.size = 0;
-            return ptr;
-        }
-        void * ptr;
-        size_t look_ahead_size = (size_t) (1.05 * size);
+    // sum up partial sums and write back result
+#pragma unroll
+    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+        tmp +=
+            dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
+    }
 
-        SYCL_CHECK(
-            CHECK_TRY_ERROR(ptr = (void *)sycl::malloc_device(
-                                look_ahead_size, *qptr)));
-        if (!ptr) {
-            fprintf(stderr, "%s: can't malloc %lu Bytes memory on device", __func__, look_ahead_size);
-            return nullptr;
-        }
+    if (item_ct1.get_local_id(2) == 0) {
+        dst[idst] = tmp;
+    }
+}
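+// mul_mat_vec_nc_f16_f32 repeats the same reduce-and-store pattern for
+// non-contiguous x, addressing via explicit row_stride_x/channel_stride_x
+// rather than the packed layout assumed by the kernel above.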
 
-        *actual_size = look_ahead_size;
-        pool_size += look_ahead_size;
+static void cpy_1_f32_f32(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    float * dsti = (float *) cdsti;
 
-    #ifdef DEBUG_SYCL_MALLOC
-        fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
-                (uint32_t)(max_size/1024/1024), (uint32_t)(g_sycl_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
-    #endif
-        // GGML_SYCL_DEBUG("ggml_sycl_pool_malloc_leg look_ahead_size=%lu, return %p\n", look_ahead_size, ptr);
-        return ptr;
-    }
+    *dsti = *xi;
+}
 
-    void free(void * ptr, size_t size) override {
-        for (int i = 0; i < MAX_SYCL_BUFFERS; ++i) {
-            ggml_sycl_buffer& b = buffer_pool[i];
-            if (b.ptr == nullptr) {
-                b.ptr = ptr;
-                b.size = size;
-                return;
-            }
-        }
-        fprintf(stderr, "WARNING: sycl buffer pool full, increase MAX_sycl_BUFFERS\n");
-        SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(ptr, *qptr)));
-        pool_size -= size;
-    }
-};
+static void cpy_1_f32_f16(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    sycl::half *dsti = (sycl::half *)cdsti;
 
-std::unique_ptr<ggml_sycl_pool> ggml_backend_sycl_context::new_pool_for_device(queue_ptr qptr, int device) {
-    // TBD: NO VMM support
-    // if (ggml_sycl_info().devices[device].vmm) {
-    //     return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_vmm(device));
-    // }
-   return std::unique_ptr<ggml_sycl_pool>(new ggml_sycl_pool_leg(qptr, device));
+    *dsti = sycl::vec<float, 1>(*xi)
+                .convert<sycl::half, sycl::rounding_mode::automatic>()[0];
 }
 
-// TBD pool with virtual memory management
-// struct ggml_sycl_pool_vmm : public ggml_sycl_pool
+static void cpy_1_f16_f16(const char * cxi, char * cdsti) {
+    const sycl::half *xi = (const sycl::half *)cxi;
+    sycl::half *dsti = (sycl::half *)cdsti;
 
-static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
-                                          const struct ggml_tensor *src,
-                                          int64_t i3, int64_t i2,
-                                          int64_t i1_low, int64_t i1_high,
-                                          queue_ptr stream) try {
+    *dsti = *xi;
+}
 
-    dpct::memcpy_direction kind;
-    char * src_ptr;
-    if (src->backend == GGML_BACKEND_TYPE_CPU) {
-        kind = dpct::host_to_device;
-        src_ptr = (char *) src->data;
-        // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d  GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
-    } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
-        GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
-        kind = dpct::device_to_device;
-        ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
-        int id;
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            id = get_current_device_id()));
-        // GGML_SYCL_DEBUG("current device index %d\n", id);
-        src_ptr = (char *) extra->data_device[id];
-    } else {
-        // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
-        GGML_ABORT("fatal error");
-    }
-    char * dst_ptr = (char *) dst;
+static void cpy_1_f16_f32(const char * cxi, char * cdsti) {
+    const sycl::half *xi = (const sycl::half *)cxi;
+    float * dsti = (float *) cdsti;
 
-    GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
-    GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
-    const enum ggml_type type = src->type;
-    const int64_t ts = ggml_type_size(type);
-    const int64_t bs = ggml_blck_size(type);
-    int64_t i1_diff = i1_high - i1_low;
+    *dsti = *xi;
+}
 
-    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
-    if (nb0 == ts && nb1 == ts*ne0/bs) {
-        // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
-        // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
-        return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
-                                    kind, *stream));
+static void cpy_1_i16_i16(const char * cxi, char * cdsti) {
+    const int16_t *xi = (const int16_t *)cxi;
+    int16_t *dsti = (int16_t *)cdsti;
 
-    } else if (nb0 == ts) {
-        return CHECK_TRY_ERROR(
-            dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
-                                    ts * ne0 / bs, i1_diff, kind, *stream));
-    } else {
-        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
-            const void * rx = (const void *) ((const char *) x + i1*nb1);
-            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
-            // pretend the row is a matrix with cols=1
-            dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
-                rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
-            /*
-            DPCT1001:85: The statement could not be removed.
-            */
-            /*
-            DPCT1000:86: Error handling if-stmt was detected but could not be
-            rewritten.
-            */
-            if (r != 0) return r;
-        }
-        return 0;
-    }
+    *dsti = *xi;
 }
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+
+static void cpy_1_i32_i32(const char * cxi, char * cdsti) {
+    const int32_t *xi = (const int32_t *)cxi;
+    int32_t *dsti = (int32_t *)cdsti;
+
+    *dsti = *xi;
 }
 
-static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_d, const float *src1_d,
-                                  float *dst_d, const queue_ptr &stream) {
+template <cpy_kernel_t cpy_1>
+static void cpy_f32_f16(const char * cx, char * cdst, const int ne,
+                        const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                        const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                        const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-    GGML_ASSERT(src1->type == GGML_TYPE_I32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    if (i >= ne) {
+        return;
+    }
 
-    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
-    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
-    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
+    // determine indices i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
+    // then combine those indices with the corresponding byte offsets to get the total offsets
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
 
-    const int32_t * src1_i32 = (const int32_t *) src1_d;
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
 
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
-                                src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_F32:
-            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_0:
-            get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q4_1:
-            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_0:
-            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q5_1:
-            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        case GGML_TYPE_Q8_0:
-            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
-            break;
-        default:
-            // TODO: k-quants
-            fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
-            GGML_ABORT("fatal error");
-            break;
-    }
+    cpy_1(cx + x_offset, cdst + dst_offset);
 }
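+// cpy_f32_f16 is the generic strided-copy driver (despite the name it serves
+// the cpy_1_* kernels above): the flat index i is unravelled twice, once
+// against the source shape (ne00..ne02) and once against the destination
+// shape (ne10..ne12); byte offsets built from the respective strides are
+// handed to cpy_1 for the per-element conversion.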
 
-template <class op>
-inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                   const ggml_tensor *src1, ggml_tensor *dst,
-                                   const float *src0_dd, const float *src1_dd,
-                                   float *dst_dd,
-                                   const queue_ptr &main_stream) {
+static void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q8_0 * dsti = (block_q8_0 *) cdsti;
 
-    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
-             (sycl::half *)dst_dd, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
-             main_stream);
-    } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
-        op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
-             main_stream);
-    } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
-        op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
-             main_stream);
-    } else {
-        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
-            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ABORT("fatal error");
-    }
-}
+    float amax = 0.0f; // absolute max
 
-static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const float *src0_d, const float *src1_d,
-                                float *dst_d,
-                                const queue_ptr &main_stream) {
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = xi[j];
+        amax = sycl::fmax(amax, sycl::fabs((float)v));
+    }
 
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
 
-    (void) src1;
-    (void) src1_d;
-}
+    dsti->d = d;
 
-inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = xi[j]*id;
 
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+        dsti->qs[j] = sycl::round((float)x0);
+    }
 }
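+// q8_0 quantization, worked through for one value (illustrative numbers):
+// with amax = 2.54 over the block, d = 2.54 / 127 = 0.02, so an input of 1.27
+// is stored as round(1.27 / 0.02) = 64 and later dequantizes to
+// 64 * 0.02 = 1.28, an error of at most d/2.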
 
-inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+static void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_0 * dsti = (block_q4_0 *) cdsti;
 
-    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
-    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
-    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
-    int offset = dst->op_params[3] / 4; // offset in bytes
+    float amax = 0.0f;
+    float vmax = 0.0f;
 
-    acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
+    for (int j = 0; j < QK4_0; ++j) {
+        const float v = xi[j];
+        if (amax < sycl::fabs((float)v)) {
+            amax = sycl::fabs((float)v);
+            vmax = v;
+        }
+    }
 
-    (void) dst;
-}
+    const float d  = vmax / -8;
+    const float id = d ? 1.0f/d : 0.0f;
 
-inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
+    dsti->d = d;
 
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
-}
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = xi[0       + j]*id;
+        const float x1 = xi[QK4_0/2 + j]*id;
 
-inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
+        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 8.5f));
 
-    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
 }
 
-inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+static void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
+    const float * xi = (const float *) cxi;
+    block_q4_1 * dsti = (block_q4_1 *) cdsti;
 
-    gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+    float vmin = FLT_MAX;
+    float vmax = -FLT_MAX;
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
+    for (int j = 0; j < QK4_1; ++j) {
+        const float v = xi[j];
 
-inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
+        if (v < vmin) vmin = v;
+        if (v > vmax) vmax = v;
+    }
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    const float d  = (vmax - vmin) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
 
-    silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+    dsti->dm.x() = d;
+    dsti->dm.y() = vmin;
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (xi[0       + j] - vmin)*id;
+        const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
 
-inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                    const ggml_tensor *src1, ggml_tensor *dst,
-                                    const float *src0_dd, const float *src1_dd,
-                                    float *dst_dd,
-                                    const queue_ptr &main_stream) {
+        const uint8_t xi0 = dpct::min(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = dpct::min(15, (int8_t)(x1 + 0.5f));
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+        dsti->qs[j]  = xi0;
+        dsti->qs[j] |= xi1 << 4;
+    }
+}
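+// The 4-bit variants pack two quants per byte (low and high nibble). q4_0 is
+// symmetric: d = vmax / -8 sends the largest-magnitude value to -8, and
+// quants are biased by +8 into [0, 15]. q4_1 is asymmetric: d = (vmax - vmin) / 15
+// and vmin are stored together in dsti->dm, so a value is rebuilt as q * d + vmin.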
 
-    gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+template <cpy_kernel_t cpy_blck, int qk>
+static void cpy_f32_q(const char * cx, char * cdst, const int ne,
+                      const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
+                      const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
+                      const int nb12, const int nb13, const sycl::nd_item<3> &item_ct1) {
+    const int i = (item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                   item_ct1.get_local_id(2)) *
+                  qk;
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
+    if (i >= ne) {
+        return;
+    }
 
-inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
+    const int i03 = i/(ne00 * ne01 * ne02);
+    const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
+    const int i01 = (i - i03*ne00*ne01*ne02  -  i02*ne01*ne00) / ne00;
+    const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
+    const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-    tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+    const int i13 = i/(ne10 * ne11 * ne12);
+    const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
+    const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
+    const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
+    const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+    cpy_blck(cx + x_offset, cdst + dst_offset);
 }
 
-inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                              ggml_tensor *dst, const float *src0_dd,
-                              const float *src1_dd, float *dst_dd,
-                              const queue_ptr &main_stream) {
+static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
+                           const sycl::nd_item<3> &item_ct1) {
+    const int row = item_ct1.get_group(1);
+    const int col = item_ct1.get_local_id(2);
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    float sum = 0.0f;
+    for (int i = col; i < ncols; i += item_ct1.get_local_range(2)) {
+        sum += x[row * ncols + i];
+    }
 
-    relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+    sum = warp_reduce_sum(sum, item_ct1);
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+    if (col == 0) {
+        dst[row] = sum;
+    }
 }
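+// k_sum_rows_f32 assigns one work-group per row: each work-item accumulates a
+// strided partial sum over the columns, warp_reduce_sum folds the partials,
+// and lane 0 writes the row total.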
 
-static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                     const ggml_tensor *src1, ggml_tensor *dst,
-                                     const float *src0_dd, const float *src1_dd,
-                                     float *dst_dd,
-                                     const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+template<typename T>
+static inline void ggml_sycl_swap(T & a, T & b) {
+    T tmp = a;
+    a = b;
+    b = tmp;
 }
 
-static void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                   const ggml_tensor *src1, ggml_tensor *dst,
-                                   const float *src0_dd, const float *src1_dd,
-                                   float *dst_dd, const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+template <ggml_sort_order order>
+__dpct_inline__ static void
+k_argsort_f32_i32(const float *x, int *dst, const int ncols, int ncols_pad,
+                  const sycl::nd_item<3> &item_ct1, uint8_t *dpct_local) {
+    // bitonic sort
+    int col = item_ct1.get_local_id(2);
+    int row = item_ct1.get_group(1);
 
-    hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+    if (col >= ncols_pad) {
+        return;
+    }
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
+    const float * x_row = x + row * ncols;
+    auto dst_row = (int *)dpct_local;
 
-inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                    const ggml_tensor *src1, ggml_tensor *dst,
-                                    const float *src0_dd, const float *src1_dd,
-                                    float *dst_dd,
-                                    const queue_ptr &main_stream) {
+    // initialize indices
+    dst_row[col] = col;
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    item_ct1.barrier(sycl::access::fence_space::local_space);
 
-    float negative_slope;
-    memcpy(&negative_slope, dst->op_params, sizeof(float));
-
-    leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
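+    // bitonic merge network over ncols_pad (a power of two); padded indices >= ncols always lose comparisons and sink to the end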
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k / 2; j > 0; j /= 2) {
+            int ixj = col ^ j;
+            if (ixj > col) {
+                if ((col & k) == 0) {
+                    if (dst_row[col] >= ncols ||
+                        (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
+                    }
+                } else {
+                    if (dst_row[ixj] >= ncols ||
+                        (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+                            x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+                            x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+                    ) {
+                        ggml_sycl_swap(dst_row[col], dst_row[ixj]);
+                    }
+                }
+            }
+            /*
+            DPCT1118:1: SYCL group functions and algorithms must be encountered
+            in converged control flow. You may need to adjust the code.
+            */
+            item_ct1.barrier(sycl::access::fence_space::local_space);
+        }
+    }
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+    // copy the result to dst without the padding
+    if (col < ncols) {
+        dst[row * ncols + col] = dst_row[col];
+    }
 }
 
-inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past,
+                              const sycl::nd_item<3> &item_ct1) {
+    const int col = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
+                    item_ct1.get_local_id(1);
+    const int row = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                    item_ct1.get_local_id(2);
 
-    sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+    if (col >= ncols) {
+        return;
+    }
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+    const int i = row*ncols + col;
+    //dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
+    //dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
+    dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1, ggml_tensor *dst,
-                                 const float *src0_dd, const float *src1_dd,
-                                 float *dst_dd,
-                                 const queue_ptr &main_stream) {
+static void scale_f32(const float * x, float * dst, const float scale, const int k,
+                      const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    if (i >= k) {
+        return;
+    }
 
-    const float sf0 = (float)dst->ne[0]/src0->ne[0];
-    const float sf1 = (float)dst->ne[1]/src0->ne[1];
-    const float sf2 = (float)dst->ne[2]/src0->ne[2];
-    const float sf3 = (float)dst->ne[3]/src0->ne[3];
+    dst[i] = scale * x[i];
+}
 
-    upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
-                     dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
-                     main_stream);
+static void clamp_f32(const float * x, float * dst, const float min, const float max, const int k,
+                      const sycl::nd_item<3> &item_ct1) {
+    const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
+                  item_ct1.get_local_id(2);
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
 }
 
-inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                             ggml_tensor *dst, const float *src0_dd,
-                             const float *src1_dd, float *dst_dd,
-                             const queue_ptr &main_stream) {
+template <typename Ti, typename To>
+static  void pool2d_nchw_kernel(
+        const int ih, const int iw, const int oh, const int ow,
+        const int kh, const int kw, const int sh, const int sw,
+        const int ph, const int pw, const int parallel_elements,
+        const Ti* src, To* dst, const enum ggml_op_pool op,
+        const sycl::nd_item<3> &item_ct1) {
+        int idx = item_ct1.get_local_id(2) +
+                  item_ct1.get_group(2) * item_ct1.get_local_range(2);
+        if (idx >= parallel_elements) {
+            return;
+        }
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->type == GGML_TYPE_F32);
-    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+        const int I_HW = ih * iw;
+        const int O_HW = oh * ow;
+        const int nc = idx / O_HW;
+        const int cur_oh = idx % O_HW / ow;
+        const int cur_ow = idx % O_HW % ow;
+        const Ti* i_ptr = src + nc * I_HW;
+        To* o_ptr = dst + nc * O_HW;
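+        // pooling window [bh, eh) x [bw, ew), clamped to the input bounds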
+        const int start_h = cur_oh * sh - ph;
+        const int bh = sycl::max(0, start_h);
+        const int eh = sycl::min(ih, start_h + kh);
+        const int start_w = cur_ow * sw - pw;
+        const int bw = sycl::max(0, start_w);
+        const int ew = sycl::min(iw, start_w + kw);
 
-    pad_f32_sycl(src0_dd, dst_dd,
-        src0->ne[0], src0->ne[1], src0->ne[2],
-        dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+        To res = 0;
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
+        switch (op) {
+            case GGML_OP_POOL_AVG: res = 0; break;
+            case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
+        }
 
-static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split) {
-    int64_t min_compute_capability = INT_MAX;
-    int64_t max_compute_capability = INT_MIN;
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        if (tensor_split[i] < (i + 1 < ggml_sycl_info().device_count ? tensor_split[i + 1] : 1.0f)) {
-            if (min_compute_capability > ggml_sycl_info().devices[i].cc) {
-                min_compute_capability = ggml_sycl_info().devices[i].cc;
-            }
-            if (max_compute_capability < ggml_sycl_info().devices[i].cc) {
-                max_compute_capability = ggml_sycl_info().devices[i].cc;
+        for (int i = bh; i < eh; i += 1) {
+            for (int j = bw; j < ew; j += 1) {
+#if DPCT_COMPATIBILITY_TEMP >= 350
+                /*
+                DPCT1098:106: The '*' expression is used instead of the __ldg
+                call. These two expressions do not provide the exact same
+                functionality. Check the generated code for potential precision
+                and/or performance issues.
+                */
+                Ti cur = *(i_ptr + i * iw + j);
+#else
+                Ti cur = i_ptr[i * iw + j];
+#endif
+                switch (op) {
+                    case GGML_OP_POOL_AVG: res += (cur / (kh * kw)); break;
+                    case GGML_OP_POOL_MAX: res = sycl::max(res, (To)cur); break;
+                }
             }
         }
-    }
+        o_ptr[cur_oh * ow + cur_ow] = res;
+}
 
-    switch(type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-            return max_compute_capability >= VER_GEN9 ? 128 : 64;
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-            return 64;
-        case GGML_TYPE_F16:
-        case GGML_TYPE_F32:
-            return 1;
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ4_NL:
-            return max_compute_capability >= VER_GEN9 ? 128 : 64;
-        case GGML_TYPE_IQ3_S:
-            return max_compute_capability >= VER_GEN9 ? 128 : 64;
-        case GGML_TYPE_Q6_K:
-            return 64;
-        default:
-            GGML_ABORT("fatal error");
-    }
+template <int qk, int qr, dequantize_kernel_t dq>
+static void get_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                          ggml_tensor *dst, const void *src0_dd,
+                          const int32_t *src1_dd, float *dst_dd,
+                          queue_ptr stream) {
 
-}
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-inline void ggml_sycl_op_mul_mat_sycl(
-    ggml_backend_sycl_context & ctx,
-    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
-    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
-    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
-    const int64_t src1_ncols, const int64_t src1_padded_row_size,
-    const queue_ptr &stream) try {
+    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+    const int block_num_x = (ne00 + 2*SYCL_GET_ROWS_BLOCK_SIZE - 1) / (2*SYCL_GET_ROWS_BLOCK_SIZE);
+    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
 
-    GGML_ASSERT(src0_dd_i  != nullptr);
-    GGML_ASSERT(src1_ddf_i != nullptr);
-    GGML_ASSERT(dst_dd_i   != nullptr);
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne10 = src1->ne[0];
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
 
-    const int64_t ne0 = dst->ne[0];
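+    // the dequantize kernel emits two values per work-item, so the row length must be even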
+    GGML_ASSERT(ne00 % 2 == 0);
 
-    const int64_t row_diff = row_high - row_low;
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             k_get_rows<qk, qr, dq>(
+                                 src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
+                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+                         });
 
-    int id;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(id = get_current_device_id()));
+    (void) dst;
+}
 
-    // the main device has a larger memory buffer to hold the results from all GPUs
-    // ldc == nrows of the matrix that cuBLAS writes into
-    int ldc = id == ctx.device ? ne0 : row_diff;
+template <typename src0_t>
+static void get_rows_sycl_float(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                const ggml_tensor *src1, ggml_tensor *dst,
+                                const src0_t *src0_dd, const int32_t *src1_dd,
+                                float *dst_dd, queue_ptr stream) {
 
-#ifdef GGML_SYCL_F16
-    bool use_fp16 = true;  // TODO(Yu) SYCL capability check
-#else
-    bool use_fp16 = false;
-#endif
-    if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
-        use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] &&
-        dst->op_params[0] == GGML_PREC_DEFAULT) {
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
-        ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
-        if (src0->type != GGML_TYPE_F16) {
-            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type);
-            GGML_ASSERT(to_fp16_sycl != nullptr);
-            size_t ne = row_diff*ne00;
-            src0_as_f16.alloc(ne);
-            to_fp16_sycl(src0_dd_i, src0_as_f16.get(), ne, stream);
-        }
-        const sycl::half *src0_ptr = src0->type == GGML_TYPE_F16
-                                         ? (const sycl::half *)src0_dd_i
-                                         : src0_as_f16.get();
+    const sycl::range<3> block_dims(1, 1, SYCL_GET_ROWS_BLOCK_SIZE);
+    const int block_num_x = (ne00 + SYCL_GET_ROWS_BLOCK_SIZE - 1) / SYCL_GET_ROWS_BLOCK_SIZE;
+    const sycl::range<3> block_nums(ne11 * ne12, ne10, block_num_x);
 
-        ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
-        if (src1->type != GGML_TYPE_F16) {
-            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
-            GGML_ASSERT(to_fp16_sycl != nullptr);
-            size_t ne = src1_ncols*ne10;
-            src1_as_f16.alloc(ne);
-            to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream);
-        }
-        const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
-                ? (const sycl::half *)src1->data + src1_padded_row_size
-                                         : src1_as_f16.get();
-        ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
+    // strides in elements
+    //const size_t s0 = nb0 / ggml_element_size(dst);
+    const size_t s1 = nb1 / ggml_element_size(dst);
+    const size_t s2 = nb2 / ggml_element_size(dst);
+    const size_t s3 = nb3 / ggml_element_size(dst);
 
-        const sycl::half alpha_f16 = 1.0f;
-        const sycl::half beta_f16 = 0.0f;
-#if !GGML_SYCL_DNNL
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
-            *stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
-            &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
-            src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
-            dst_f16.get(), dpct::library_data_t::real_half, ldc,
-            dpct::library_data_t::real_half)));
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
-#else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
-        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
-#endif
-    }
-    else {
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
-        ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
-        ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
-        if (src0->type != GGML_TYPE_F32) {
-            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type);
-            GGML_ASSERT(to_fp32_sycl != nullptr);
-            src0_ddq_as_f32.alloc(row_diff*ne00);
-            to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
-        }
-        if (src1->type != GGML_TYPE_F32) {
-            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type);
-            GGML_ASSERT(to_fp32_sycl != nullptr);
-            src1_ddq_as_f32.alloc(src1_ncols*ne10);
-            to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
-        }
-        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
-        const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
+    const size_t s10 = nb10 / ggml_element_size(src1);
+    const size_t s11 = nb11 / ggml_element_size(src1);
+    const size_t s12 = nb12 / ggml_element_size(src1);
+    //const size_t s13 = nb13 / ggml_element_size(src1);
 
-        const float alpha = 1.0f;
-        const float beta = 0.0f;
-#if !GGML_SYCL_DNNL
-        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
-            *stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
-            dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
-            src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
-            dst_dd_i, ldc)));
-#else
-        auto dnnl_stream = ctx.stream_dnnl(stream);
-         DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
-            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
-#endif
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
+
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) {
+                k_get_rows_float(src0_dd, src1_dd, dst_dd, ne00, ne12, s1, s2,
+                                 s3, nb01, nb02, nb03, s10, s11, s12, item_ct1);
+            });
     }
+
     (void) dst;
-    (void) src1_ddq_i;
-    (void) src1_padded_row_size;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
 }
 
-static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                const ggml_tensor *src1, ggml_tensor *dst,
-                                const float *src0_dd, const float *src1_dd,
-                                float *dst_dd, const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+template<float (*bin_op)(const float, const float)>
+struct bin_bcast_sycl {
+    template <typename src0_t, typename src1_t, typename dst_t>
+    void operator()(ggml_backend_sycl_context & ctx,
+                    const struct ggml_tensor *src0,
+                    const struct ggml_tensor *src1, struct ggml_tensor *dst,
+                    const src0_t *src0_dd, const src1_t *src1_dd, dst_t *dst_dd,
+                    queue_ptr stream) {
 
-    const int32_t * opts = (const int32_t *)dst->op_params;
-    enum ggml_op_pool op = static_cast<enum ggml_op_pool>(opts[0]);
-    const int k0 = opts[1];
-    const int k1 = opts[2];
-    const int s0 = opts[3];
-    const int s1 = opts[4];
-    const int p0 = opts[5];
-    const int p1 = opts[6];
+        GGML_TENSOR_BINARY_OP_LOCALS
 
-    const int64_t IH = src0->ne[1];
-    const int64_t IW = src0->ne[0];
+        int nr0 = ne10/ne0;
+        int nr1 = ne11/ne1;
+        int nr2 = ne12/ne2;
+        int nr3 = ne13/ne3;
 
-    const int64_t N = dst->ne[3];
-    const int64_t OC = dst->ne[2];
-    const int64_t OH = dst->ne[1];
-    const int64_t OW = dst->ne[0];
+        int nr[4] = { nr0, nr1, nr2, nr3 };
 
-    const int parallel_elements = N * OC * OH * OW;
-    const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
-    sycl::range<3> block_nums(1, 1, num_blocks);
-    main_stream->parallel_for(
-        sycl::nd_range<3>(block_nums *
-                              sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
-                          sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
-        [=](sycl::nd_item<3> item_ct1) {
-            pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,
-                               parallel_elements, src0_dd, dst_dd, op,
-                               item_ct1);
-        });
+        // collapse dimensions until first broadcast dimension
+        int64_t cne0[] = {ne0, ne1, ne2, ne3};
+        int64_t cne1[] = {ne10, ne11, ne12, ne13};
+        size_t cnb0[] = {nb0, nb1, nb2, nb3};
+        size_t cnb1[] = {nb10, nb11, nb12, nb13};
+        auto collapse = [](int64_t cne[]) {
+            cne[0] *= cne[1];
+            cne[1] = cne[2];
+            cne[2] = cne[3];
+            cne[3] = 1;
+        };
 
-    (void) src1;
-    (void) src1_dd;
-}
+        auto collapse_nb = [](size_t cnb[], int64_t cne[]) {
+            cnb[1] *= cne[1];
+            cnb[2] *= cne[2];
+            cnb[3] *= cne[3];
+        };
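+        // the loop below folds leading non-broadcast dimensions into one, stopping at the first dimension that is actually broadcast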
 
-inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                  const ggml_tensor *src1, ggml_tensor *dst,
-                                  const float *src0_dd, const float *src1_dd,
-                                  float *dst_dd,
-                                  const queue_ptr &main_stream) {
+        for (int i = 0; i < 4; i++) {
+            if (nr[i] != 1) {
+                break;
+            }
+            if (i > 0) {
+                collapse_nb(cnb0, cne0);
+                collapse_nb(cnb1, cne1);
+                collapse(cne0);
+                collapse(cne1);
+            }
+        }
+        {
+            int64_t ne0 = cne0[0];
+            int64_t ne1 = cne0[1];
+            int64_t ne2 = cne0[2];
+            int64_t ne3 = cne0[3];
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+            int64_t ne10 = cne1[0];
+            int64_t ne11 = cne1[1];
+            int64_t ne12 = cne1[2];
+            int64_t ne13 = cne1[3];
 
-    const int64_t ncols = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+            size_t nb0 = cnb0[0];
+            size_t nb1 = cnb0[1];
+            size_t nb2 = cnb0[2];
+            size_t nb3 = cnb0[3];
 
-    sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
+            size_t nb10 = cnb1[0];
+            size_t nb11 = cnb1[1];
+            size_t nb12 = cnb1[2];
+            size_t nb13 = cnb1[3];
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
+            size_t s0 = nb0 / sizeof(dst_t);
+            size_t s1 = nb1 / sizeof(dst_t);
+            size_t s2 = nb2 / sizeof(dst_t);
+            size_t s3 = nb3 / sizeof(dst_t);
 
-inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1, ggml_tensor *dst,
-                                 const float *src0_dd, const float *src1_dd,
-                                 float *dst_dd,
-                                 const queue_ptr &main_stream) {
+            size_t s10 = nb10 / sizeof(src1_t);
+            size_t s11 = nb11 / sizeof(src1_t);
+            size_t s12 = nb12 / sizeof(src1_t);
+            size_t s13 = nb13 / sizeof(src1_t);
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_I32);
+            GGML_ASSERT(s0 == 1);
+            GGML_ASSERT(s10 == 1);
 
-    const int64_t ncols = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+            const int block_size = 128;
 
-    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
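+            // the kernel covers two elements per work-item along dim 0, so the grid is sized on ne0/2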
+            int64_t hne0 = std::max(ne0/2LL, 1LL);
 
-    argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
+            sycl::range<3> block_dims(1, 1, 1);
+            block_dims[2] = std::min<unsigned int>(hne0, block_size);
+            block_dims[1] = std::min<unsigned int>(
+                ne1, block_size / (unsigned int)block_dims[2]);
+            block_dims[0] = std::min(
+                std::min<unsigned int>(
+                    ne2 * ne3, block_size / (unsigned int)block_dims[2] /
+                                   (unsigned int)block_dims[1]),
+                64U);
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
-}
-
-inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                       const ggml_tensor *src1,
-                                       ggml_tensor *dst, const float *src0_dd,
-                                       const float *src1_dd, float *dst_dd,
-                                       const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+            sycl::range<3> block_nums(
+                (ne2 * ne3 + block_dims[0] - 1) / block_dims[0],
+                (ne1 + block_dims[1] - 1) / block_dims[1],
+                (hne0 + block_dims[2] - 1) / block_dims[2]);
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int nrows0 = ggml_nrows(src0);
+            if (block_nums[0] > 65535) {
+                // this is the maximum number of blocks in z direction, fallback to 1D grid kernel
+                int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
+                {
+                    dpct::has_capability_or_fail(stream->get_device(),
+                                                 {sycl::aspect::fp16});
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+                    stream->parallel_for(
+                        sycl::nd_range<3>(sycl::range<3>(1, 1, block_num) *
+                                              sycl::range<3>(1, 1, block_size),
+                                          sycl::range<3>(1, 1, block_size)),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_bin_bcast_unravel<bin_op>(
+                                src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3,
+                                ne10, ne11, ne12, ne13, s1, s2, s3, s11, s12,
+                                s13, item_ct1);
+                        });
+                }
+            } else {
+                /*
+                DPCT1049:16: The work-group size passed to the SYCL kernel may
+                exceed the limit. To get the device limit, query
+                info::device::max_work_group_size. Adjust the work-group size if
+                needed.
+                */
+                dpct::has_capability_or_fail(stream->get_device(),
+                                             {sycl::aspect::fp16});
 
-    diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
+                stream->parallel_for(
+                    sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                    [=](sycl::nd_item<3> item_ct1) {
+                        k_bin_bcast<bin_op>(src0_dd, src1_dd, dst_dd, ne0, ne1,
+                                            ne2, ne3, ne10, ne11, ne12, ne13,
+                                            s1, s2, s3, s11, s12, s13,
+                                            item_ct1);
+                    });
+            }
+        }
+    }
+};
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+static void acc_f32_sycl(const float *x, const float *y, float *dst,
+                         const int n_elements, const int ne10, const int ne11,
+                         const int ne12, const int nb1, const int nb2,
+                         const int offset, queue_ptr stream) {
+    int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset,
+                    item_ct1);
+        });
 }
 
-inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                               ggml_tensor *dst, const float *src0_dd,
-                               const float *src1_dd, float *dst_dd,
-                               const queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+static void gelu_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            gelu_f32(x, dst, k, item_ct1);
+        });
+}
 
-    float scale;
-    memcpy(&scale, dst->op_params, sizeof(float));
+static void silu_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_SILU_BLOCK_SIZE - 1) / SYCL_SILU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_SILU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            silu_f32(x, dst, k, item_ct1);
+        });
+}
 
-    scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
-    /*
-    DPCT1010:87: SYCL uses exceptions to report errors and does not use the
-    error codes. The call was replaced with 0. You need to rewrite this code.
-    */
-    SYCL_CHECK(0);
+static void gelu_quick_f32_sycl(const float *x, float *dst, const int k,
+                                queue_ptr stream) {
+    const int num_blocks = (k + SYCL_GELU_BLOCK_SIZE - 1) / SYCL_GELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_GELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            gelu_quick_f32(x, dst, k, item_ct1);
+        });
+}
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+static void tanh_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_TANH_BLOCK_SIZE - 1) / SYCL_TANH_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_TANH_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            tanh_f32(x, dst, k, item_ct1);
+        });
 }
 
-inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                               ggml_tensor *dst, const float *src0_dd,
-                               const float *src1_dd, float *dst_dd,
-                               const queue_ptr &main_stream) {
+static void relu_f32_sycl(const float *x, float *dst, const int k,
+                          queue_ptr stream) {
+    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            relu_f32(x, dst, k, item_ct1);
+        });
+}
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+static void hardsigmoid_f32_sycl(const float *x, float *dst, const int k,
+                                 queue_ptr stream) {
+    const int num_blocks = (k + SYCL_HARDSIGMOID_BLOCK_SIZE - 1) / SYCL_HARDSIGMOID_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_HARDSIGMOID_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            hardsigmoid_f32(x, dst, k, item_ct1);
+        });
+}
 
-    float min;
-    float max;
-    memcpy(&min, dst->op_params, sizeof(float));
-    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
+static void hardswish_f32_sycl(const float *x, float *dst, const int k,
+                               queue_ptr stream) {
+    const int num_blocks = (k + SYCL_HARDSWISH_BLOCK_SIZE - 1) / SYCL_HARDSWISH_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_HARDSWISH_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            hardswish_f32(x, dst, k, item_ct1);
+        });
+}
 
-    clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
-    /*
-    DPCT1010:88: SYCL uses exceptions to report errors and does not use the
-    error codes. The call was replaced with 0. You need to rewrite this code.
-    */
-    SYCL_CHECK(0);
+static void leaky_relu_f32_sycl(const float *x, float *dst, const int k,
+                                const float negative_slope,
+                                queue_ptr stream) {
+    const int num_blocks = (k + SYCL_RELU_BLOCK_SIZE - 1) / SYCL_RELU_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_RELU_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            leaky_relu_f32(x, dst, k, negative_slope, item_ct1);
+        });
+}
 
-    (void) src1;
-    (void) dst;
-    (void) src1_dd;
+static void sqr_f32_sycl(const float *x, float *dst, const int k,
+                         queue_ptr stream) {
+    const int num_blocks = (k + SYCL_SQR_BLOCK_SIZE - 1) / SYCL_SQR_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_SQR_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            sqr_f32(x, dst, k, item_ct1);
+        });
 }
 
-static void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1, ggml_tensor *dst,
-                                 const ggml_sycl_op_flatten_t op) try {
-    const int64_t nrows0 = ggml_nrows(src0);
+static void upscale_f32_sycl(const float *x, float *dst, const int nb00, const int nb01,
+                             const int nb02, const int nb03, const int ne10, const int ne11,
+                             const int ne12, const int ne13, const float sf0, const float sf1,
+                             const float sf2, const float sf3, queue_ptr stream) {
+    int dst_size = ne10 * ne11 * ne12 * ne13;
+    int num_blocks = (dst_size + SYCL_UPSCALE_BLOCK_SIZE - 1) / SYCL_UPSCALE_BLOCK_SIZE;
+    sycl::range<1> gridDim(num_blocks * SYCL_UPSCALE_BLOCK_SIZE);
+    stream->parallel_for(
+        sycl::nd_range<1>(gridDim, sycl::range<1>(SYCL_UPSCALE_BLOCK_SIZE)),
+        [=](sycl::nd_item<1> item_ct1) {
+            upscale_f32(x, dst, nb00, nb01, nb02, nb03, ne10, ne11, ne12, ne13, sf0, sf1, sf2, sf3, item_ct1);
+        });
+}
 
-    const bool use_src1 = src1 != nullptr;
-    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
+static void pad_f32_sycl(const float *x, float *dst, const int ne00,
+                         const int ne01, const int ne02, const int ne0,
+                         const int ne1, const int ne2, queue_ptr stream) {
+    int num_blocks = (ne0 + SYCL_PAD_BLOCK_SIZE - 1) / SYCL_PAD_BLOCK_SIZE;
+    sycl::range<3> gridDim(ne2, ne1, num_blocks);
+    stream->parallel_for(
+        sycl::nd_range<3>(gridDim * sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_PAD_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            pad_f32(x, dst, ne0, ne00, ne01, ne02, item_ct1);
+        });
+}
 
-    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(              dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+static void quantize_row_q8_1_sycl(const float *x, void *vy, const int kx,
+                                   const int ky, const int kx_padded,
+                                   queue_ptr stream) {
+    const int block_num_x = (kx_padded + SYCL_QUANTIZE_BLOCK_SIZE - 1) / SYCL_QUANTIZE_BLOCK_SIZE;
+    const sycl::range<3> num_blocks(1, ky, block_num_x);
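+    // each work-item quantizes QUANT_BLOCK_TILE values, so one sub-group of WARP_SIZE items covers a whole QK8_1 block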
+    int constexpr QUANT_BLOCK_TILE = QK8_1 / WARP_SIZE;
+    static_assert(QK8_1 % WARP_SIZE == 0);
+    const sycl::range<3> block_size(1, 1, SYCL_QUANTIZE_BLOCK_SIZE / QUANT_BLOCK_TILE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
 
-    ggml_tensor_extra_gpu * src0_extra =            (ggml_tensor_extra_gpu *) src0->extra;
-    ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
-    ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *)  dst->extra;
-
-    // dd = data device
-    float * src0_ddf = (float *) src0->data;
-    float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
-    float *  dst_ddf = (float *) dst->data;
-
-    ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
-    ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
-    ggml_sycl_pool_alloc<float>  dst_f(ctx.pool());
+        stream->parallel_for(
+            sycl::nd_range<3>(num_blocks * block_size, block_size),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                quantize_q8_1<QUANT_BLOCK_TILE>(x, vy, kx, kx_padded, item_ct1);
+            });
+    }
+}
 
-    ggml_sycl_set_device(ctx.device);
-    queue_ptr main_stream = ctx.stream();
-    // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
-        // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
+static void ggml_mul_mat_p021_f16_f32_sycl(const void *vx, const float *y,
+                                           float *dst, const int ncols_x,
+                                           const int nrows_x,
+                                           const int nchannels_x,
+                                           const int nchannels_y,
+                                           queue_ptr stream) {
 
-    // do the computation
-    op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
-    // print_ggml_tensor("tensor", dst);
-}
-catch (sycl::exception const &exc) {
+    const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
+    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
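+    // one work-group of WARP_SIZE items per (channel, row) pair; the sub-group reduces the dot product along ncols_x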
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
 
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                mul_mat_p021_f16_f32(vx, y, dst, ncols_x, nrows_x, nchannels_x,
+                                     nchannels_y, item_ct1);
+            });
+    }
 }
 
-static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
-    static bool peer_access_enabled = false;
-
-    const bool enable_peer_access = n_tokens <= GGML_SYCL_PEER_MAX_BATCH_SIZE;
+static void ggml_mul_mat_vec_nc_f16_f32_sycl(
+    const void *vx, const float *y, float *dst, const int ncols_x,
+    const int nrows_x, const int row_stride_x, const int nchannels_x,
+    const int nchannels_y, const int channel_stride_x, queue_ptr stream) {
 
-    if (peer_access_enabled == enable_peer_access) {
-        return;
-    }
+    const sycl::range<3> block_nums(nchannels_y, nrows_x, 1);
+    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
 
-#ifdef NDEBUG
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        SYCL_CHECK(ggml_sycl_set_device(i));
+        stream->parallel_for(
+            sycl::nd_range<3>(block_nums * block_dims, block_dims),
+            [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                mul_mat_vec_nc_f16_f32(vx, y, dst, ncols_x, nrows_x,
+                                       row_stride_x, channel_stride_x,
+                                       nchannels_y / nchannels_x, item_ct1);
+            });
     }
+}
 
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        SYCL_CHECK(ggml_sycl_set_device(i));
+static void
+ggml_cpy_f16_f32_sycl(const char *cx, char *cdst, const int ne, const int ne00,
+                      const int ne01, const int ne02, const int nb00,
+                      const int nb01, const int nb02, const int nb03,
+                      const int ne10, const int ne11, const int ne12,
+                      const int nb10, const int nb11, const int nb12,
+                      const int nb13, queue_ptr stream) {
 
-        for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) {
-            if (i == id_other) {
-                continue;
-            }
-            if (i != main_device && id_other != main_device) {
-                continue;
-            }
+    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
 
-            // int can_access_peer;
-            // SYCL_CHECK(syclDeviceCanAccessPeer(&can_access_peer, id, id_other));
-            // if (can_access_peer) {
-            //     if (enable_peer_access) {
-            //         SYCL_CHECK(syclDeviceEnablePeerAccess(id_other, 0));
-            //     } else {
-            //         SYCL_CHECK(syclDeviceDisablePeerAccess(id_other));
-            //     }
-            // }
-        }
+        stream->parallel_for(
+            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                cpy_f32_f16<cpy_1_f16_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00,
+                                           nb01, nb02, nb03, ne10, ne11, ne12,
+                                           nb10, nb11, nb12, nb13, item_ct1);
+            });
     }
-#endif // NDEBUG
-
-    peer_access_enabled = enable_peer_access;
 }
 
-struct ggml_backend_sycl_split_buffer_type_context {
-    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
-};
-
-static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1, ggml_tensor *dst,
-                                 ggml_sycl_op_mul_mat_t op,
-                                 const bool convert_src1_to_q8_1) try {
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
+static void ggml_cpy_f32_f32_sycl(const char *cx, char *cdst, const int ne,
+                                  const int ne00, const int ne01,
+                                  const int ne02, const int nb00,
+                                  const int nb01, const int nb02,
+                                  const int nb03, const int ne10,
+                                  const int ne11, const int ne12,
+                                  const int nb10, const int nb11,
+                                  const int nb12, const int nb13,
+                                  queue_ptr stream) {
 
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
-    const int64_t nrows1 = ggml_nrows(src1);
+    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
 
-    GGML_ASSERT(ne03 == ne13);
+        stream->parallel_for(
+            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                cpy_f32_f16<cpy_1_f32_f32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+                                           item_ct1);
+            });
+    }
+}
 
-    const int64_t ne0 = dst->ne[0];
-    const int64_t ne1 = dst->ne[1];
+static void ggml_cpy_f32_f16_sycl(const char *cx, char *cdst, const int ne,
+                                  const int ne00, const int ne01,
+                                  const int ne02, const int nb00,
+                                  const int nb01, const int nb02,
+                                  const int nb03, const int ne10,
+                                  const int ne11, const int ne12,
+                                  const int nb10, const int nb11,
+                                  const int nb12, const int nb13,
+                                  queue_ptr stream) {
 
-    const int nb2 = dst->nb[2];
-    const int nb3 = dst->nb[3];
+    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
 
-    GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
+        stream->parallel_for(
+            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                cpy_f32_f16<cpy_1_f32_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+                                           item_ct1);
+            });
+    }
+}
 
-    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
+static void ggml_cpy_f32_q8_0_sycl(const char *cx, char *cdst, const int ne,
+                                   const int ne00, const int ne01,
+                                   const int ne02, const int nb00,
+                                   const int nb01, const int nb02,
+                                   const int nb03, const int ne10,
+                                   const int ne11, const int ne12,
+                                   const int nb10, const int nb11,
+                                   const int nb12, const int nb13,
+                                   queue_ptr stream) {
 
-    const int64_t i02_divisor = ne12 / ne02;
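+    // one work-item per quantized block: ne must be a whole number of QK8_0 blocks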
+    GGML_ASSERT(ne % QK8_0 == 0);
+    const int num_blocks = ne / QK8_0;
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
+                                           sycl::range<3>(1, 1, 1)),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             cpy_f32_q<cpy_blck_f32_q8_0, QK8_0>(
+                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+                                 item_ct1);
+                         });
+}
 
-    const size_t src0_ts = ggml_type_size(src0->type);
-    const size_t src0_bs = ggml_blck_size(src0->type);
-    const size_t q8_1_ts = sizeof(block_q8_1);
-    const size_t q8_1_bs = QK8_1;
+static void ggml_cpy_f32_q4_0_sycl(const char *cx, char *cdst, const int ne,
+                                   const int ne00, const int ne01,
+                                   const int ne02, const int nb00,
+                                   const int nb01, const int nb02,
+                                   const int nb03, const int ne10,
+                                   const int ne11, const int ne12,
+                                   const int nb10, const int nb11,
+                                   const int nb12, const int nb13,
+                                   queue_ptr stream) {
 
-    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
-    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
-    ggml_tensor_extra_gpu *  dst_extra = (ggml_tensor_extra_gpu *)  dst->extra;
+    GGML_ASSERT(ne % QK4_0 == 0);
+    const int num_blocks = ne / QK4_0;
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
+                                           sycl::range<3>(1, 1, 1)),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             cpy_f32_q<cpy_blck_f32_q4_0, QK4_0>(
+                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+                                 item_ct1);
+                         });
+}
 
-    const bool src0_is_contiguous = ggml_is_contiguous(src0);
-    const bool src1_is_contiguous = ggml_is_contiguous(src1);
+static void ggml_cpy_f32_q4_1_sycl(const char *cx, char *cdst, const int ne,
+                                   const int ne00, const int ne01,
+                                   const int ne02, const int nb00,
+                                   const int nb01, const int nb02,
+                                   const int nb03, const int ne10,
+                                   const int ne11, const int ne12,
+                                   const int nb10, const int nb11,
+                                   const int nb12, const int nb13,
+                                   queue_ptr stream) {
 
-    int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+    GGML_ASSERT(ne % QK4_1 == 0);
+    const int num_blocks = ne / QK4_1;
+    stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks),
+                                           sycl::range<3>(1, 1, 1)),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             cpy_f32_q<cpy_blck_f32_q4_1, QK4_1>(
+                                 cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+                                 nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+                                 item_ct1);
+                         });
+}
 
-    const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
-    GGML_ASSERT(!(split && ne02 > 1));
-    GGML_ASSERT(!(split && ne03 > 1));
-    GGML_ASSERT(!(split && ne02 < ne12));
-
-    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
-    if (split) {
-        // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check
-        // GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...);
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
-        tensor_split = buft_ctx->tensor_split;
-    }
-
-    struct dev_data {
-        ggml_sycl_pool_alloc<char> src0_dd_alloc;
-        ggml_sycl_pool_alloc<float> src1_ddf_alloc;
-        ggml_sycl_pool_alloc<char> src1_ddq_alloc;
-        ggml_sycl_pool_alloc<float> dst_dd_alloc;
+static void ggml_cpy_f16_f16_sycl(const char *cx, char *cdst, const int ne,
+                                  const int ne00, const int ne01,
+                                  const int ne02, const int nb00,
+                                  const int nb01, const int nb02,
+                                  const int nb03, const int ne10,
+                                  const int ne11, const int ne12,
+                                  const int nb10, const int nb11,
+                                  const int nb12, const int nb13,
+                                  queue_ptr stream) {
 
-        char *src0_dd = nullptr;
-        float *src1_ddf = nullptr; // float
-        char *src1_ddq = nullptr;  // q8_1
-        float *dst_dd = nullptr;
+    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+    {
+        dpct::has_capability_or_fail(stream->get_device(),
+                                     {sycl::aspect::fp16});
 
-        int64_t row_low;
-        int64_t row_high;
-    };
+        stream->parallel_for(
+            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                cpy_f32_f16<cpy_1_f16_f16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+                                           item_ct1);
+            });
+    }
+}
 
-    dev_data dev[GGML_SYCL_MAX_DEVICES];
+static void ggml_cpy_i16_i16_sycl(const char *cx, char *cdst, const int ne,
+                                  const int ne00, const int ne01,
+                                  const int ne02, const int nb00,
+                                  const int nb01, const int nb02,
+                                  const int nb03, const int ne10,
+                                  const int ne11, const int ne12,
+                                  const int nb10, const int nb11,
+                                  const int nb12, const int nb13,
+                                  queue_ptr stream) {
 
-    int used_devices = 0;
-    queue_ptr main_stream = ctx.stream();
+    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+    {
+        // dpct::has_capability_or_fail(stream->get_device(),
+        //                              {sycl::aspect::fp16});
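+        // note: the fp16 aspect check is presumably unnecessary for integer
+        // element copies, hence it is left commented out here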
 
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        // by default, use all rows
-        dev[i].row_low  = 0;
-        dev[i].row_high = ne01;
+        stream->parallel_for(
+            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                cpy_f32_f16<cpy_1_i16_i16>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+                                           item_ct1);
+            });
+    }
+}
 
-        // for multi GPU, get the row boundaries from tensor split
-        // and round to mul_mat_q tile sizes
-        if (split) {
-            const int64_t rounding = get_row_rounding(src0->type, tensor_split);
+static void ggml_cpy_i32_i32_sycl(const char *cx, char *cdst, const int ne,
+                                  const int ne00, const int ne01,
+                                  const int ne02, const int nb00,
+                                  const int nb01, const int nb02,
+                                  const int nb03, const int ne10,
+                                  const int ne11, const int ne12,
+                                  const int nb10, const int nb11,
+                                  const int nb12, const int nb13,
+                                  queue_ptr stream) {
 
-            if (i != 0) {
-                dev[i].row_low  = ne01*tensor_split[i];
-                if (dev[i].row_low < ne01) {
-                    dev[i].row_low -= dev[i].row_low % rounding;
-                }
-            }
+    const int num_blocks = (ne + SYCL_CPY_BLOCK_SIZE - 1) / SYCL_CPY_BLOCK_SIZE;
+    {
+        // dpct::has_capability_or_fail(stream->get_device(),
+        //                              {sycl::aspect::fp16});
 
-            if (i != ggml_sycl_info().device_count - 1) {
-                dev[i].row_high  = ne01*tensor_split[i + 1];
-                if (dev[i].row_high < ne01) {
-                    dev[i].row_high -= dev[i].row_high % rounding;
-                }
-            }
-        }
+        stream->parallel_for(
+            sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                                  sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE),
+                              sycl::range<3>(1, 1, SYCL_CPY_BLOCK_SIZE)),
+            [=](sycl::nd_item<3> item_ct1) {
+                cpy_f32_f16<cpy_1_i32_i32>(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02,
+                                           nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13,
+                                           item_ct1);
+            });
     }
+}
 
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
-            continue;
-        }
-
-        used_devices++;
+static void scale_f32_sycl(const float *x, float *dst, const float scale,
+                           const int k, queue_ptr stream) {
+    const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            scale_f32(x, dst, scale, k, item_ct1);
+        });
+}
 
-        const bool src1_on_device = i == ctx.device;
-        const bool  dst_on_device = i == ctx.device;
+static void clamp_f32_sycl(const float *x, float *dst, const float min,
+                           const float max, const int k,
+                           queue_ptr stream) {
+    const int num_blocks = (k + SYCL_CLAMP_BLOCK_SIZE - 1) / SYCL_CLAMP_BLOCK_SIZE;
+    stream->parallel_for(
+        sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) *
+                              sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_CLAMP_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            clamp_f32(x, dst, min, max, k, item_ct1);
+        });
+}
 
-        ggml_sycl_set_device(i);
-        queue_ptr stream = ctx.stream(i, 0);
+static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
+                              const int nrows, queue_ptr stream) {
+    const sycl::range<3> block_dims(1, 1, WARP_SIZE);
+    const sycl::range<3> block_nums(1, nrows, 1);
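+    // one WARP_SIZE-wide work-group per row: k_sum_rows_f32 reduces each row
+    // within a single sub-group, hence the pinned sub-group size below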
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1)
+                             [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                                 k_sum_rows_f32(x, dst, ncols, item_ct1);
+                             });
+}
 
-        if (src0_is_contiguous) {
-            dev[i].src0_dd = (char *) src0->data;
-        } else {
-            dev[i].src0_dd = dev[i].src0_dd_alloc.alloc(ctx.pool(i), ggml_nbytes(src0));
-        }
+static int next_power_of_2(int x) {
+    int n = 1;
+    while (n < x) {
+        n *= 2;
+    }
+    return n;
+}
 
-        if (src1_on_device && src1_is_contiguous) {
-            dev[i].src1_ddf = (float *) src1->data;
-        } else {
-            dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
-        }
+static void argsort_f32_i32_sycl(const float *x, int *dst, const int ncols,
+                                 const int nrows, ggml_sort_order order,
+                                 queue_ptr stream) {
+    // bitonic sort requires ncols to be power of 2
+    const int ncols_pad = next_power_of_2(ncols);
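+    // e.g. ncols == 1000 pads to ncols_pad == 1024, so the bitonic network
+    // below always operates on a power-of-2 element count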
 
-        if (convert_src1_to_q8_1) {
-            dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
+    const sycl::range<3> block_dims(1, 1, ncols_pad);
+    const sycl::range<3> block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
 
-            if (src1_on_device && src1_is_contiguous) {
-                quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
-                /*
-                DPCT1010:90: SYCL uses exceptions to report errors and does not
-                use the error codes. The call was replaced with 0. You need to
-                rewrite this code.
-                */
-                SYCL_CHECK(0);
-            }
-        }
+    if (order == GGML_SORT_ORDER_ASC) {
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
+                sycl::range<1>(shared_mem), cgh);
 
-        if (dst_on_device) {
-            dev[i].dst_dd = (float *) dst->data;
-        } else {
-            const size_t size_dst_ddf = split ? (dev[i].row_high - dev[i].row_low)*ne1 : ggml_nelements(dst);
-            dev[i].dst_dd = dev[i].dst_dd_alloc.alloc(ctx.pool(i), size_dst_ddf);
-        }
-    }
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    k_argsort_f32_i32<GGML_SORT_ORDER_ASC>(
+                        x, dst, ncols, ncols_pad, item_ct1,
+                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+                            .get());
+                });
+        });
+    } else if (order == GGML_SORT_ORDER_DESC) {
+        stream->submit([&](sycl::handler &cgh) {
+            sycl::local_accessor<uint8_t, 1> dpct_local_acc_ct1(
+                sycl::range<1>(shared_mem), cgh);
 
-    // if multiple devices are used they need to wait for the main device
-    // here an event is recorded that signals that the main device has finished calculating the input data
-    if (split && used_devices > 1) {
-        ggml_sycl_set_device(ctx.device);
-        /*
-        DPCT1024:91: The original code returned the error code that was further
-        consumed by the program logic. This original code was replaced with 0.
-        You may need to rewrite the program logic consuming the error code.
-        */
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            *src0_extra->events[ctx.device][0] =
-                ctx.stream()->ext_oneapi_submit_barrier()));
+            cgh.parallel_for(
+                sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                [=](sycl::nd_item<3> item_ct1) {
+                    k_argsort_f32_i32<GGML_SORT_ORDER_DESC>(
+                        x, dst, ncols, ncols_pad, item_ct1,
+                        dpct_local_acc_ct1.get_multi_ptr<sycl::access::decorated::no>()
+                            .get());
+                });
+        });
+    } else {
+        GGML_ABORT("fatal error");
     }
+}
 
-    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
-    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
-        const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
-        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
-
-        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-            if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
-                continue;
-            }
-
-            const bool src1_on_device = i == ctx.device;
-            const bool  dst_on_device = i == ctx.device;
-            const int64_t row_diff = dev[i].row_high - dev[i].row_low;
-
-            ggml_sycl_set_device(i);
-            queue_ptr stream = ctx.stream(i, is);
-
-            // wait for main GPU data if necessary
-            if (split && (i != ctx.device || is != 0)) {
-                /*
-                DPCT1009:163: SYCL uses exceptions to report errors and does not
-                use the error codes. The original code was commented out and a
-                warning string was inserted. You need to rewrite this code.
-                */
-                SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
-                    {*src0_extra->events[ctx.device][0]})));
-            }
-
-            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
-                const int64_t i03 = i0 / ne12;
-                const int64_t i02 = i0 % ne12;
-
-                const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
-
-                // for split tensors the data begins at i0 == i0_offset_low
-                char  *  src0_dd_i =  dev[i].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
-                float * src1_ddf_i = dev[i].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
-                char  * src1_ddq_i = dev[i].src1_ddq +  src1_ddq_i_offset;
-                float *   dst_dd_i =   dev[i].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
-
-                // the main device memory buffer can be on VRAM scratch, with space for all partial results
-                // in that case an offset on dst_ddf_i is needed
-                if (i == ctx.device) {
-                    dst_dd_i += dev[i].row_low; // offset is 0 if no tensor split
-                }
-
-                // copy src0, src1 to device if necessary
-                if (src1_is_contiguous) {
-                    if (i != ctx.device) {
-                        if (convert_src1_to_q8_1) {
-                            char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
-                          SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
-                                src1_ddq_i, src1_ddq_i_source,
-                                src1_ncols * src1_padded_col_size * q8_1_ts /
-                                    q8_1_bs).wait()));
-                        } else {
-
-                            float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
-                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
-
-                            SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream,
-                                src1_ddf_i, src1_ddf_i_source,
-                                src1_ncols * ne10 * sizeof(float))));
-                        }
-                    }
-                } else if (src1_on_device && !src1_is_contiguous) {
-                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
-                                   src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
-                } else {
-                    GGML_ABORT("fatal error");
-                }
-
-                if (convert_src1_to_q8_1 && !src1_is_contiguous) {
-                    quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
-                    /*
-                    DPCT1010:92: SYCL uses exceptions to report errors and does
-                    not use the error codes. The call was replaced with 0. You
-                    need to rewrite this code.
-                    */
-                    SYCL_CHECK(0);
-                }
-
-                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
-                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream));
-                }
-                if (src1->type == GGML_TYPE_F16) {
-                    src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
-                }
-                // do the computation
-                SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
-                    dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
-                /*
-                DPCT1010:93: SYCL uses exceptions to report errors and does not
-                use the error codes. The call was replaced with 0. You need to
-                rewrite this code.
-                */
-                SYCL_CHECK(0);
-
-                // copy dst to host or other device if necessary
-                if (!dst_on_device) {
-                    void * dst_off_device = dst->data;
-                    if (split) {
-                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.
-                        // dst is NOT transposed.
-                        // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
-                        // Instead they need to be copied to the correct slice in ne0 = dst row index.
-                        // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
-                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
-                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
-                        dhf_dst_i += src1_col_0*ne0 + dev[i].row_low;
+static void diag_mask_inf_f32_sycl(const float *x, float *dst,
+                                   const int ncols_x, const int nrows_x,
+                                   const int rows_per_channel, const int n_past,
+                                   queue_ptr stream) {
+    const sycl::range<3> block_dims(1, SYCL_DIAG_MASK_INF_BLOCK_SIZE, 1);
+    const int block_num_x = (ncols_x + SYCL_DIAG_MASK_INF_BLOCK_SIZE - 1) / SYCL_DIAG_MASK_INF_BLOCK_SIZE;
+    const sycl::range<3> block_nums(1, block_num_x, nrows_x);
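+    // one work-item per element; entries past the causal boundary set by
+    // n_past are masked to (effectively) -inf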
+    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
+                         [=](sycl::nd_item<3> item_ct1) {
+                             diag_mask_inf_f32(x, dst, ncols_x,
+                                               rows_per_channel, n_past,
+                                               item_ct1);
+                         });
+}
 
-                        SYCL_CHECK(CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
-                            dhf_dst_i, ne0 * sizeof(float), dst_dd_i,
-                            row_diff * sizeof(float), row_diff * sizeof(float),
-                            src1_ncols, dpct::device_to_device, *stream)));
-                    } else {
-                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
-                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
-                        dhf_dst_i += src1_col_0*ne0;
-                        SYCL_CHECK(CHECK_TRY_ERROR(
-                            stream->memcpy(dhf_dst_i, dst_dd_i,
-                                           src1_ncols * ne0 * sizeof(float)).wait()));
-                    }
-                }
+static dpct::err0 ggml_sycl_cpy_tensor_2d(void *dst,
+                                          const struct ggml_tensor *src,
+                                          int64_t i3, int64_t i2,
+                                          int64_t i1_low, int64_t i1_high,
+                                          queue_ptr stream) try {
 
-                // add event for the main device to wait on until other device is done
-                if (split && (i != ctx.device || is != 0)) {
-                    /*
-                    DPCT1024:94: The original code returned the error code that
-                    was further consumed by the program logic. This original
-                    code was replaced with 0. You may need to rewrite the
-                    program logic consuming the error code.
-                    */
-                    SYCL_CHECK(CHECK_TRY_ERROR(
-                        *src0_extra->events[i][is] =
-                            stream->ext_oneapi_submit_barrier()));
-                }
-            }
-        }
+    dpct::memcpy_direction kind;
+    char * src_ptr;
+    if (src->backend == GGML_BACKEND_TYPE_CPU) {
+        kind = dpct::host_to_device;
+        src_ptr = (char *) src->data;
+        // GGML_SYCL_DEBUG("ggml_sycl_cpy_tensor_2d  GGML_BACKEND_TYPE_CPU src_ptr %p\n", src_ptr);
+    } else if (src->backend == GGML_BACKEND_TYPE_GPU || src->backend == GGML_BACKEND_TYPE_GPU_SPLIT) {
+        GGML_ASSERT(src->backend != GGML_BACKEND_TYPE_GPU_SPLIT || (i1_low == 0 && i1_high == src->ne[1]));
+        kind = dpct::device_to_device;
+        ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) src->extra;
+        int id;
+        SYCL_CHECK(CHECK_TRY_ERROR(
+            id = get_current_device_id()));
+        // GGML_SYCL_DEBUG("current device index %d\n", id);
+        src_ptr = (char *) extra->data_device[id];
+    } else {
+        // GGML_SYCL_DEBUG("GGML_ABORT("fatal error")\n");
+        GGML_ABORT("fatal error");
     }
+    char * dst_ptr = (char *) dst;
 
-    // main device waits for all other devices to be finished
-    if (split && ggml_sycl_info().device_count > 1) {
-        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
-        is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS;
+    GGML_TENSOR_LOCALS_1(int64_t, ne, src, ne);
+    GGML_TENSOR_LOCALS(int64_t, nb, src, nb);
+    const enum ggml_type type = src->type;
+    const int64_t ts = ggml_type_size(type);
+    const int64_t bs = ggml_blck_size(type);
+    int64_t i1_diff = i1_high - i1_low;
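+    // three layouts are handled below: fully contiguous data (one linear
+    // copy), contiguous rows with a row stride (one 2D copy), and strided
+    // elements (one 2D copy per row)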
 
-        ggml_sycl_set_device(ctx.device);
-        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-            if (dev[i].row_low == dev[i].row_high) {
-                continue;
-            }
-            for (int64_t is = 0; is < is_max; ++is) {
-                SYCL_CHECK(CHECK_TRY_ERROR(
-                    ctx.stream()->ext_oneapi_submit_barrier(
-                        {*src0_extra->events[i][is]})));
-            }
+    const char * x = src_ptr + i1_low*nb1 + i2*nb2 + i3*nb3;
+    if (nb0 == ts && nb1 == ts*ne0/bs) {
+        // GGML_SYCL_DEBUG("stream->memcpy: dst_ptr=%p, x=%p, size=%lu\n", dst_ptr, x, i1_diff * nb1);
+        // return CHECK_TRY_ERROR(stream->memcpy(dst_ptr, x, i1_diff * nb1));
+        return CHECK_TRY_ERROR(dpct::async_dpct_memcpy(dst_ptr, x, i1_diff * nb1,
+                                    kind, *stream));
+
+    } else if (nb0 == ts) {
+        return CHECK_TRY_ERROR(
+            dpct::async_dpct_memcpy(dst_ptr, ts * ne0 / bs, x, nb1,
+                                    ts * ne0 / bs, i1_diff, kind, *stream));
+    } else {
+        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
+            const void * rx = (const void *) ((const char *) x + i1*nb1);
+            void * rd = (void *) (dst_ptr + i1*ts*ne0/bs);
+            // pretend the row is a matrix with cols=1
+            dpct::err0 r = CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
+                rd, ts / bs, rx, nb0, ts / bs, ne0, kind, *stream));
+            /*
+            DPCT1001:85: The statement could not be removed.
+            */
+            /*
+            DPCT1000:86: Error handling if-stmt was detected but could not be
+            rewritten.
+            */
+            if (r != 0) return r;
         }
+        return 0;
     }
 }
 catch (sycl::exception const &exc) {
@@ -3112,1044 +2919,993 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
+static void ggml_sycl_op_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                  const ggml_tensor *src1, ggml_tensor *dst,
+                                  const float *src0_d, const float *src1_d,
+                                  float *dst_d, const queue_ptr &stream) {
 
-static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
+    GGML_ASSERT(src1->type == GGML_TYPE_I32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-static void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
+    GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+    GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+    GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
 
-static void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
+    const int32_t * src1_i32 = (const int32_t *) src1_d;
 
-static void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
-
-static void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
+    switch (src0->type) {
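+        // src1 supplies the row indices to gather; block-quantized types are
+        // dequantized on the fly inside get_rows_sycl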
+        case GGML_TYPE_F16:
+            get_rows_sycl_float(ctx, src0, src1, dst, (const sycl::half *)src0_d,
+                                src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_F32:
+            get_rows_sycl_float(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_0:
+            get_rows_sycl<QK4_0, QR4_0, dequantize_q4_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q4_1:
+            get_rows_sycl<QK4_1, QR4_1, dequantize_q4_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_0:
+            get_rows_sycl<QK5_0, QR5_0, dequantize_q5_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q5_1:
+            get_rows_sycl<QK5_1, QR5_1, dequantize_q5_1>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        case GGML_TYPE_Q8_0:
+            get_rows_sycl<QK8_0, QR8_0, dequantize_q8_0>(ctx, src0, src1, dst, src0_d, src1_i32, dst_d, stream);
+            break;
+        default:
+            // TODO: k-quants
+            fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
+            GGML_ABORT("fatal error");
+            break;
+    }
 }
 
-static void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
+template <class op>
+inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd,
+                                   const queue_ptr &main_stream) {
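+    // instantiate the broadcast functor for each supported (src0, dst)
+    // element-type pair; unsupported combinations abort below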
 
-static void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
+    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
+        op()(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
+        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd,
+             (sycl::half *)dst_dd, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+        op()(ctx, src0, src1, dst, (const sycl::half *)src0_dd, src1_dd, dst_dd,
+             main_stream);
+    } else if (src0->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) {
+        op()(ctx, src0, src1, dst, (const int32_t *)src0_dd, (const int32_t *)src1_dd, (int32_t *)dst_dd,
+             main_stream);
+    } else if (src0->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) {
+        op()(ctx, src0, src1, dst, (const int16_t *)src0_dd, (const int16_t *)src1_dd, (int16_t *)dst_dd,
+             main_stream);
+    } else {
+        fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
+            ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
+    }
 }
 
-static void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
+static void ggml_sycl_op_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                const ggml_tensor *src1, ggml_tensor *dst,
+                                const float *src0_d, const float *src1_d,
+                                float *dst_d,
+                                const queue_ptr &main_stream) {
 
-static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_repeat>>(ctx, dst, src0, dst, nullptr, src0_d, dst_d, main_stream);
 
-static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
+    (void) src1;
+    (void) src1_d;
 }
 
-static void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
+inline void ggml_sycl_op_add(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
 
-static void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_add>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
 }
 
-static void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
+inline void ggml_sycl_op_acc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
 
-static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
 
-static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
+    int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
+    int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
+    // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
+    int offset = dst->op_params[3] / 4; // offset in float elements (op_params[3] stores bytes)
 
-static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
-}
+    acc_f32_sycl(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
 
-static void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
+    (void) dst;
 }
 
+inline void ggml_sycl_op_mul(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
 
-static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_SYCL_DEBUG("call %s\n", __func__);
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
-    GGML_SYCL_DEBUG("call %s done\n", __func__);
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_mul>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
 }
 
-static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                       const ggml_tensor *src1,
-                                       ggml_tensor *dst) try {
-    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
-    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+inline void ggml_sycl_op_div(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
+    ggml_sycl_op_bin_bcast<bin_bcast_sycl<op_div>>(ctx, src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
+}
 
-    const int64_t ne12 = src1->ne[2];
+inline void ggml_sycl_op_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                              ggml_tensor *dst, const float *src0_dd,
+                              const float *src1_dd, float *dst_dd,
+                              const queue_ptr &main_stream) {
 
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
+    gelu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
-    ggml_mul_mat_p021_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
 }
 
-static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                     const ggml_tensor *src1,
-                                     ggml_tensor *dst) try {
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-    GGML_ASSERT(!ggml_is_permuted(src0));
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+inline void ggml_sycl_op_silu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                              ggml_tensor *dst, const float *src0_dd,
+                              const float *src1_dd, float *dst_dd,
+                              const queue_ptr &main_stream) {
 
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t nb01 = src0->nb[1];
-    const int64_t nb02 = src0->nb[2];
+    silu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
-    const int64_t ne12 = src1->ne[2];
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
 
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();
+inline void ggml_sycl_op_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                    const ggml_tensor *src1, ggml_tensor *dst,
+                                    const float *src0_dd, const float *src1_dd,
+                                    float *dst_dd,
+                                    const queue_ptr &main_stream) {
 
-    void  * src0_ddq = src0->data;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf  = (float *) dst->data;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const int64_t row_stride_x = nb01 / sizeof(sycl::half);
-    const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
+    gelu_quick_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
-    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
 }
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+
+inline void ggml_sycl_op_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                              ggml_tensor *dst, const float *src0_dd,
+                              const float *src1_dd, float *dst_dd,
+                              const queue_ptr &main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+    tanh_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
 }
 
-static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
-                                   const sycl::half *src1_as_f16, char *dst,
-                                   const void **ptrs_src, void **ptrs_dst,
-                                   int64_t ne12, int64_t ne13, int64_t ne23,
-                                   size_t nb02, size_t nb03, size_t nb12,
-                                   size_t nb13, size_t nbd2, size_t nbd3,
-                                   int64_t r2, int64_t r3,
-                                   const sycl::nd_item<3> &item_ct1) {
-    int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
-                  item_ct1.get_local_id(2);
-    int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
-                  item_ct1.get_local_id(1);
+inline void ggml_sycl_op_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                              ggml_tensor *dst, const float *src0_dd,
+                              const float *src1_dd, float *dst_dd,
+                              const queue_ptr &main_stream) {
 
-    if (i13 >= ne13 || i12 >= ne12) {
-        return;
-    }
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    int64_t i03 = i13 / r3;
-    int64_t i02 = i12 / r2;
+    relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
 }
 
-static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
-                                             const ggml_tensor *src0,
-                                             const ggml_tensor *src1,
-                                             ggml_tensor *dst) try {
-    GGML_ASSERT(!ggml_is_transposed(src0));
-    GGML_ASSERT(!ggml_is_transposed(src1));
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
-    GGML_ASSERT(src0->type == GGML_TYPE_F16);
-
-    GGML_TENSOR_BINARY_OP_LOCALS
+static void ggml_sycl_op_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                     const ggml_tensor *src1, ggml_tensor *dst,
+                                     const float *src0_dd, const float *src1_dd,
+                                     float *dst_dd,
+                                     const queue_ptr &main_stream) {
 
-    const int64_t ne_dst = ggml_nelements(dst);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();;
+    hardsigmoid_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
-    void * src0_ddq = src0->data;
-    sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
-    float * src1_ddf = (float *) src1->data;
-    float * dst_ddf = (float *) dst->data;
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
 
-    // convert src1 to fp16
-    ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
-    if (src1->type != GGML_TYPE_F16) {
-        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
-        const int64_t ne_src1 = ggml_nelements(src1);
-        src1_f16_alloc.alloc(ne_src1);
-        GGML_ASSERT(to_fp16_sycl != nullptr);
-        to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
-    }
-    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
-                                                       : src1_f16_alloc.get();
+static void ggml_sycl_op_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                   const ggml_tensor *src1, ggml_tensor *dst,
+                                   const float *src0_dd, const float *src1_dd,
+                                   float *dst_dd, const queue_ptr &main_stream) {
 
-    char * dst_t;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
-    dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
+    hardswish_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
-    // dst strides
-    size_t nbd2 = dst->nb[2];
-    size_t nbd3 = dst->nb[3];
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
 
-    const float alpha_f32 = 1.0f;
-    const float beta_f32 = 0.0f;
+inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                    const ggml_tensor *src1, ggml_tensor *dst,
+                                    const float *src0_dd, const float *src1_dd,
+                                    float *dst_dd,
+                                    const queue_ptr &main_stream) {
 
-    const void * alpha = &alpha_f32;
-    const void * beta  = &beta_f32;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    dst_t = (char *) dst_ddf;
+    float negative_slope;
+    memcpy(&negative_slope, dst->op_params, sizeof(float));
 
-    GGML_ASSERT(ne12 % ne02 == 0);
-    GGML_ASSERT(ne13 % ne03 == 0);
+    leaky_relu_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
 
-    // broadcast factors
-    const int64_t r2 = ne12/ne02;
-    const int64_t r3 = ne13/ne03;
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
 
-    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
-        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const char *)src0_as_f16, dpct::library_data_t::real_half,
-            nb01 / nb00, nb02 / nb00,
-            (const char *)src1_f16, dpct::library_data_t::real_half,
-            nb11 / nb10, nb12 / nb10, beta,
-            (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
-            ne12 * ne13, cu_compute_type)));
-    } else {
-        const int ne23 = ne12*ne13;
+inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
 
-        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
-        ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-        sycl::range<3> block_dims(1, ne12, ne13);
-        /*
-        DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
-        the limit. To get the device limit, query
-        info::device::max_work_group_size. Adjust the work-group size if needed.
-        */
-        {
-            dpct::has_capability_or_fail(main_stream->get_device(),
-                                         {sycl::aspect::fp16});
+    sqr_f32_sycl(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
 
-            main_stream->submit([&](sycl::handler &cgh) {
-                const void **ptrs_src_get = ptrs_src.get();
-                void **ptrs_dst_get = ptrs_dst.get();
-                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
-                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
-                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
-                                 [=](sycl::nd_item<3> item_ct1) {
-                                     k_compute_batched_ptrs(
-                                         src0_as_f16, src1_f16,
-                                         dst_t, ptrs_src_get,
-                                         ptrs_dst_get, ne12, ne13, ne23,
-                                         nb02, nb03, nb12_scaled, nb13_scaled,
-                                         nbd2, nbd3, r2, r3, item_ct1);
-                                 });
-            });
-        }
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
-            *main_stream, oneapi::mkl::transpose::trans,
-            oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
-            (const void **)(ptrs_src.get() + 0 * ne23),
-            dpct::library_data_t::real_half, nb01 / nb00,
-            (const void **)(ptrs_src.get() + 1 * ne23),
-            dpct::library_data_t::real_half, nb11 / nb10, beta,
-            (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
-            cu_compute_type)));
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
 }
 
-inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
-    // TODO: accuracy issues in MMQ
-    return false;
-}
+inline void ggml_sycl_op_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                 const ggml_tensor *src1, ggml_tensor *dst,
+                                 const float *src0_dd, const float *src1_dd,
+                                 float *dst_dd,
+                                 const queue_ptr &main_stream) {
 
-bool ggml_sycl_supports_dmmv(enum ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_F16:
-            return true;
-        default:
-            return false;
-    }
-}
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
 
-static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
-    int64_t min_compute_capability = INT_MAX;
+    const float sf0 = (float)dst->ne[0]/src0->ne[0];
+    const float sf1 = (float)dst->ne[1]/src0->ne[1];
+    const float sf2 = (float)dst->ne[2]/src0->ne[2];
+    const float sf3 = (float)dst->ne[3]/src0->ne[3];
 
-    if (split) {
-        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
-        auto & tensor_split = buft_ctx->tensor_split;
-        for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
-            // skip devices that are not going to do any work:
-            if (tensor_split[id] >= (id + 1 < ggml_sycl_info().device_count ? tensor_split[id + 1] : 1.0f)) {
-                continue;
-            }
+    upscale_f32_sycl(src0_dd, dst_dd, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                     dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3,
+                     main_stream);
 
-            if (min_compute_capability > ggml_sycl_info().devices[id].cc) {
-                min_compute_capability = ggml_sycl_info().devices[id].cc;
-            }
-        }
-    } else {
-        min_compute_capability    = ggml_sycl_info().devices[ctx.device].cc;
-    }
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
 
-    // check data types and tensor shapes for custom matrix multiplication kernels:
-    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
+inline void ggml_sycl_op_pad(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                             ggml_tensor *dst, const float *src0_dd,
+                             const float *src1_dd, float *dst_dd,
+                             const queue_ptr &main_stream) {
 
-    bool use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
-        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
 
-    bool use_mul_mat_q =  ggml_sycl_supports_mmq(src0->type)
-        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+    pad_f32_sycl(src0_dd, dst_dd,
+        src0->ne[0], src0->ne[1], src0->ne[2],
+        dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
 
-    // mmvq and mmq need the __dp4a instruction which is available for gen12+
-    // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
-    use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
-#ifdef SYCL_USE_XMX
-    use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
-#endif // SYCL_USE_XMX
-
-    // mmvq path is faster in the CUDA backend.
-    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
-        use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
-
-    if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
-        // KQ single-batch
-        ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
-        // KQV single-batch
-        ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
-    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
-        // KQ + KQV multi-batch
-        ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
-    } else if (use_dequantize_mul_mat_vec) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-    } else if (use_mul_mat_vec_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
-    } else if (use_mul_mat_q) {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
-    } else {
-        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-    }
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
 }
 
+inline void ggml_sycl_op_mul_mat_sycl(
+    ggml_backend_sycl_context & ctx,
+    const ggml_tensor *src0, const ggml_tensor *src1, ggml_tensor *dst,
+    const char *src0_dd_i, const float *src1_ddf_i, const char *src1_ddq_i,
+    float *dst_dd_i, const int64_t row_low, const int64_t row_high,
+    const int64_t src1_ncols, const int64_t src1_padded_row_size,
+    const queue_ptr &stream) try {
 
-struct mmid_row_mapping {
-    int32_t i1;
-    int32_t i2;
-};
-
-__dpct_inline__ static void k_copy_src1_to_contiguous(
-    const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
-    int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
-    const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
-    int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
-    const sycl::nd_item<3> &item_ct1, int &src1_row) {
-    int32_t iid1 = item_ct1.get_group(2);
-    int32_t id = item_ct1.get_group(1);
+    GGML_ASSERT(src0_dd_i  != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_dd_i   != nullptr);
 
-    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne10 = src1->ne[0];
 
-    if (row_id_i != i02) {
-        return;
-    }
+    const int64_t ne0 = dst->ne[0];
 
-    const int64_t i11 = id % ne11;
-    const int64_t i12 = iid1;
+    const int64_t row_diff = row_high - row_low;
 
-    if (item_ct1.get_local_id(2) == 0) {
-        src1_row =
-            dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
-                cur_src1_row, 1);
-        row_mapping[src1_row] = {id, iid1};
-    }
-    /*
-    DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
-    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
-    performance if there is no access to global memory.
-    */
-    item_ct1.barrier();
+    int id;
+    SYCL_CHECK(
+        CHECK_TRY_ERROR(id = get_current_device_id()));
 
-    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
-    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+    // the main device has a larger memory buffer to hold the results from all GPUs
+    // ldc == nrows of the matrix that cuBLAS writes into
+    int ldc = id == ctx.device ? ne0 : row_diff;
 
-#pragma unroll
-    for (int i = item_ct1.get_local_id(2); i < ne10;
-         i += item_ct1.get_local_range(2)) {
-        src1_row_contiguous[i] = src1_row_original[i];
-    }
-}
+#ifdef GGML_SYCL_F16
+    bool use_fp16 = true;  // TODO(Yu) SYCL capability check
+#else
+    bool use_fp16 = false;
+#endif
+    if ((src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+        use_fp16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1] &&
+        dst->op_params[0] == GGML_PREC_DEFAULT) {
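+        // fp16 fast path: convert both operands to half precision, run the
+        // GEMM in half, then convert the result back to fp32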
 
-__dpct_inline__ static void k_copy_dst_from_contiguous(
-    char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
-    const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
-    size_t nb2, const sycl::nd_item<3> &item_ct1) {
-    int32_t i = item_ct1.get_group(2);
+        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp16 path\n");
+        ggml_sycl_pool_alloc<sycl::half> src0_as_f16(ctx.pool());
+        if (src0->type != GGML_TYPE_F16) {
+            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src0->type);
+            GGML_ASSERT(to_fp16_sycl != nullptr);
+            size_t ne = row_diff*ne00;
+            src0_as_f16.alloc(ne);
+            to_fp16_sycl(src0_dd_i, src0_as_f16.get(), ne, stream);
+        }
+        const sycl::half *src0_ptr = src0->type == GGML_TYPE_F16
+                                         ? (const sycl::half *)src0_dd_i
+                                         : src0_as_f16.get();
 
-    const int32_t i1 = row_mapping[i].i1;
-    const int32_t i2 = row_mapping[i].i2;
+        ggml_sycl_pool_alloc<sycl::half> src1_as_f16(ctx.pool());
+        if (src1->type != GGML_TYPE_F16) {
+            const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+            GGML_ASSERT(to_fp16_sycl != nullptr);
+            size_t ne = src1_ncols*ne10;
+            src1_as_f16.alloc(ne);
+            to_fp16_sycl(src1_ddf_i, src1_as_f16.get(), ne, stream);
+        }
+        const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
+                                         ? (const sycl::half *)src1->data + src1_padded_row_size
+                                         : src1_as_f16.get();
+        ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
 
-    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
-    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+        const sycl::half alpha_f16 = 1.0f;
+        const sycl::half beta_f16 = 0.0f;
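+        // C = alpha*op(A)*op(B) + beta*C with alpha = 1 and beta = 0, i.e. a plain
+        // [row_diff x src1_ncols] half-precision product; the fp16 result is
+        // converted back to fp32 afterwards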
+#if !GGML_SYCL_DNNL
+        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(
+            *stream, oneapi::mkl::transpose::trans,
+            oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
+            &alpha_f16, src0_ptr, dpct::library_data_t::real_half, ne00,
+            src1_ptr, dpct::library_data_t::real_half, ne10, &beta_f16,
+            dst_f16.get(), dpct::library_data_t::real_half, ldc,
+            dpct::library_data_t::real_half)));
+        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff*src1_ncols, stream);
+#else
+        auto dnnl_stream = ctx.stream_dnnl(stream);
+        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
+            src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(), dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>());
+        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+        to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+#endif
+    }
+    else {
+        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat_sycl - fp32 path\n");
+        ggml_sycl_pool_alloc<float> src0_ddq_as_f32(ctx.pool());
+        ggml_sycl_pool_alloc<float> src1_ddq_as_f32(ctx.pool());
+        if (src0->type != GGML_TYPE_F32) {
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src0->type);
+            GGML_ASSERT(to_fp32_sycl != nullptr);
+            src0_ddq_as_f32.alloc(row_diff*ne00);
+            to_fp32_sycl(src0_dd_i, src0_ddq_as_f32.get(), row_diff*ne00, stream);
+        }
+        if (src1->type != GGML_TYPE_F32) {
+            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(src1->type);
+            GGML_ASSERT(to_fp32_sycl != nullptr);
+            src1_ddq_as_f32.alloc(src1_ncols*ne10);
+            to_fp32_sycl(src1_ddf_i, src1_ddq_as_f32.get(), src1_ncols*ne10, stream);
+        }
+        const float * src0_ddf_i = src0->type == GGML_TYPE_F32 ? (const float *) src0_dd_i : src0_ddq_as_f32.get();
+        const float * src1_ddf1_i = src1->type == GGML_TYPE_F32 ? (const float *) src1_ddf_i : src1_ddq_as_f32.get();
 
-#pragma unroll
-    for (int j = item_ct1.get_local_id(2); j < ne0;
-         j += item_ct1.get_local_range(2)) {
-        dst_row_original[j] = dst_row_contiguous[j];
+        const float alpha = 1.0f;
+        const float beta = 0.0f;
+#if !GGML_SYCL_DNNL
+        SYCL_CHECK(CHECK_TRY_ERROR(oneapi::mkl::blas::column_major::gemm(
+            *stream, oneapi::mkl::transpose::trans,
+            oneapi::mkl::transpose::nontrans, row_diff, src1_ncols, ne10,
+            dpct::get_value(&alpha, *stream), src0_ddf_i, ne00,
+            src1_ddf1_i, ne10, dpct::get_value(&beta, *stream),
+            dst_dd_i, ldc)));
+#else
+        auto dnnl_stream = ctx.stream_dnnl(stream);
+        DnnlGemmWrapper::row_gemm(dnnl_stream, false, true, src1_ncols, row_diff, ne10, src1_ddf1_i, DnnlGemmWrapper::to_dt<float>(),
+            src0_ddf_i, DnnlGemmWrapper::to_dt<float>(), dst_dd_i, DnnlGemmWrapper::to_dt<float>());
+#endif
     }
+    (void) dst;
+    (void) src1_ddq_i;
+    (void) src1_padded_row_size;
+}
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
 }
 
-static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
-                                 const ggml_tensor *src1,
-                                 ggml_tensor *dst) try {
-    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
+static void ggml_sycl_op_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                const ggml_tensor *src1, ggml_tensor *dst,
+                                const float *src0_dd, const float *src1_dd,
+                                float *dst_dd, const queue_ptr &main_stream) {
 
-    const ggml_tensor *ids = dst->src[2];
-    GGML_TENSOR_BINARY_OP_LOCALS
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const queue_ptr stream = ctx.stream();
+    const int32_t * opts = (const int32_t *)dst->op_params;
+    enum ggml_op_pool op = static_cast<enum ggml_op_pool>(opts[0]);
+    const int k0 = opts[1];
+    const int k1 = opts[2];
+    const int s0 = opts[3];
+    const int s1 = opts[4];
+    const int p0 = opts[5];
+    const int p1 = opts[6];
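+    // dst->op_params packs the pooling description as seven int32 values:
+    // [op, k0, k1, s0, s1, p0, p1], i.e. kernel size, stride and padding along ne0/ne1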
 
-    const int64_t n_as = ne02;
-    const int64_t n_ids = ids->ne[0];
+    const int64_t IH = src0->ne[1];
+    const int64_t IW = src0->ne[0];
 
-    std::vector<char> ids_host(ggml_nbytes(ids));
-    const char * ids_dev = (const char *) ids->data;
+    const int64_t N = dst->ne[3];
+    const int64_t OC = dst->ne[2];
+    const int64_t OH = dst->ne[1];
+    const int64_t OW = dst->ne[0];
 
-    SYCL_CHECK(CHECK_TRY_ERROR(
-        stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
-    SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
+    const int parallel_elements = N * OC * OH * OW;
+    const int num_blocks = (parallel_elements + SYCL_POOL2D_BLOCK_SIZE - 1) / SYCL_POOL2D_BLOCK_SIZE;
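+    // one work-item per output element, ceil(parallel_elements / block size) work-groups;
+    // note: this assumes SYCL_IM2COL_BLOCK_SIZE == SYCL_POOL2D_BLOCK_SIZE, since the
+    // launch below uses the im2col constant while the block count uses the pool2d one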
+    sycl::range<3> block_nums(1, 1, num_blocks);
+    main_stream->parallel_for(
+        sycl::nd_range<3>(block_nums *
+                              sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE),
+                          sycl::range<3>(1, 1, SYCL_IM2COL_BLOCK_SIZE)),
+        [=](sycl::nd_item<3> item_ct1) {
+            pool2d_nchw_kernel(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0,
+                               parallel_elements, src0_dd, dst_dd, op,
+                               item_ct1);
+        });
 
-    ggml_tensor src0_row = *src0;
-    ggml_tensor src1_row = *src1;
-    ggml_tensor dst_row = *dst;
+    (void) src1;
+    (void) src1_dd;
+}
 
-    char *src0_original = (char *)src0->data;
-    char *src1_original = (char *)src1->data;
-    char *dst_original = (char *)dst->data;
+inline void ggml_sycl_op_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                  const ggml_tensor *src1, ggml_tensor *dst,
+                                  const float *src0_dd, const float *src1_dd,
+                                  float *dst_dd,
+                                  const queue_ptr &main_stream) {
 
-    src0_row.ne[2] = 1;
-    src0_row.ne[3] = 1;
-    src0_row.nb[3] = nb02;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    src1_row.ne[1] = 1;
-    src1_row.ne[2] = 1;
-    src1_row.ne[3] = 1;
-    src1_row.nb[2] = nb11;
-    src1_row.nb[3] = nb11;
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
 
-    dst_row.ne[1] = 1;
-    dst_row.ne[2] = 1;
-    dst_row.ne[3] = 1;
-    dst_row.nb[2] = nb1;
-    dst_row.nb[3] = nb1;
-    if (ne12 == 1) {
-        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-            for (int64_t id = 0; id < n_ids; id++) {
-                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
-                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+    sum_rows_f32_sycl(src0_dd, dst_dd, ncols, nrows, main_stream);
 
-                const int64_t i11 = id % ne11;
-                const int64_t i12 = iid1;
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
 
-                const int64_t i1 = id;
-                const int64_t i2 = i12;
+inline void ggml_sycl_op_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                 const ggml_tensor *src1, ggml_tensor *dst,
+                                 const float *src0_dd, const float *src1_dd,
+                                 float *dst_dd,
+                                 const queue_ptr &main_stream) {
 
-            src0_row.data = src0_original + i02*nb02;
-            src1_row.data = src1_original + i11*nb11 + i12*nb12;
-            dst_row.data = dst_original + i1*nb1   + i2*nb2;
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_I32);
 
-            ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
-            }
-        }
-    } else {
-        ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
-        ggml_sycl_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
+    const int64_t ncols = src0->ne[0];
+    const int64_t nrows = ggml_nrows(src0);
 
-        src1_row.data = src1_contiguous.get();
-        dst_row.data  =  dst_contiguous.get();
+    enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
-        for (int64_t i02 = 0; i02 < n_as; i02++) {
-            int64_t num_src1_rows = 0;
-            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
-                for (int64_t id = 0; id < n_ids; id++) {
-                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+    argsort_f32_i32_sycl(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream);
 
-                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
 
-                    if (row_id_i != i02) {
-                        continue;
-                    }
+inline void ggml_sycl_op_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                       const ggml_tensor *src1,
+                                       ggml_tensor *dst, const float *src0_dd,
+                                       const float *src1_dd, float *dst_dd,
+                                       const queue_ptr &main_stream) {
 
-                    num_src1_rows++;
-                }
-            }
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-            if (num_src1_rows == 0) {
-                continue;
-            }
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int nrows0 = ggml_nrows(src0);
 
+    const int n_past = ((int32_t *) dst->op_params)[0];
 
-            ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
-            ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
-            SYCL_CHECK(CHECK_TRY_ERROR(
-                stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
+    diag_mask_inf_f32_sycl(src0_dd, dst_dd, ne00, nrows0, ne01, n_past, main_stream);
 
-            {
-                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
-                sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
-                stream->submit([&](sycl::handler &cgh) {
-                    sycl::local_accessor<int, 0> src1_row_acc(cgh);
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
 
-                    char *__restrict src1_contiguous_get =
-                        src1_contiguous.get();
-                    int *__restrict dev_cur_src1_row_get =
-                        dev_cur_src1_row.get();
-                    mmid_row_mapping *__restrict dev_row_mapping_get =
-                        dev_row_mapping.get();
-                    size_t ids_nb_ct6 = ids->nb[1];
-                    size_t ids_nb_ct7 = ids->nb[0];
+inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                               ggml_tensor *dst, const float *src0_dd,
+                               const float *src1_dd, float *dst_dd,
+                               const queue_ptr &main_stream) {
 
-                    cgh.parallel_for(
-                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_copy_src1_to_contiguous(
-                                src1_original, src1_contiguous_get,
-                                dev_cur_src1_row_get,
-                                dev_row_mapping_get, ids_dev, i02,
-                                ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
-                                item_ct1, src1_row_acc);
-                        });
-                });
-            }
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-            src0_row.data = src0_original + i02*nb02;
+    float scale;
+    memcpy(&scale, dst->op_params, sizeof(float));
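+    // op_params is an int32 array; the float scale is stored in it byte-for-byte,
+    // hence the memcpy rather than a cast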
 
-            GGML_ASSERT(nb11 == sizeof(float)*ne10);
-            GGML_ASSERT(nb1 == sizeof(float)*ne0);
-            src1_row.ne[1] = num_src1_rows;
+    scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(src0), main_stream);
+    /*
+    DPCT1010:87: SYCL uses exceptions to report errors and does not use the
+    error codes. The call was replaced with 0. You need to rewrite this code.
+    */
+    SYCL_CHECK(0);
 
-            src1_row.nb[1] = nb11;
-            src1_row.nb[2] = num_src1_rows*nb11;
-            src1_row.nb[3] = num_src1_rows*nb11;
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
 
-            dst_row.ne[1] = num_src1_rows;
-            dst_row.nb[1] = nb1;
-            dst_row.nb[2] = num_src1_rows*nb1;
-            dst_row.nb[3] = num_src1_rows*nb1;
+inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                               ggml_tensor *dst, const float *src0_dd,
+                               const float *src1_dd, float *dst_dd,
+                               const queue_ptr &main_stream) {
 
-            ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-            {
-                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
-                sycl::range<3> grid_dims(1, 1, num_src1_rows);
-                stream->submit([&](sycl::handler &cgh) {
-                    const char *__restrict dst_contiguous_get =
-                        dst_contiguous.get();
-                    const mmid_row_mapping *__restrict dev_row_mapping_get =
-                        dev_row_mapping.get();
+    float min;
+    float max;
+    memcpy(&min, dst->op_params, sizeof(float));
+    memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
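+    // min and max are two consecutive floats serialized into dst->op_params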
 
-                    cgh.parallel_for(
-                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
-                        [=](sycl::nd_item<3> item_ct1) {
-                            k_copy_dst_from_contiguous(dst_original,
-                                                       dst_contiguous_get,
-                                                       dev_row_mapping_get,
-                                                       ne0, nb1, nb2, item_ct1);
-                        });
-                });
-            }
-        }
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
+    clamp_f32_sycl(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
+    /*
+    DPCT1010:88: SYCL uses exceptions to report errors and does not use the
+    error codes. The call was replaced with 0. You need to rewrite this code.
+    */
+    SYCL_CHECK(0);
 
-static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale);
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
 }
 
-static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp);
-}
+static void ggml_sycl_op_flatten(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                 const ggml_tensor *src1, ggml_tensor *dst,
+                                 const ggml_sycl_op_flatten_t op) try {
+    const int64_t nrows0 = ggml_nrows(src0);
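+    // common single-device driver: resolve the raw device data pointers for
+    // src0/src1/dst and hand them to the per-op launcher `op` on the main stream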
 
-static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
-                          ggml_tensor *dst) try {
-    const int64_t ne = ggml_nelements(src0);
-    GGML_ASSERT(ne == ggml_nelements(src1));
+    const bool use_src1 = src1 != nullptr;
+    const int64_t nrows1 = use_src1 ? ggml_nrows(src1) : 1;
 
-    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
-    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
+    GGML_ASSERT(!use_src1 || src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(              dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
 
-    GGML_TENSOR_BINARY_OP_LOCALS01;
+    ggml_tensor_extra_gpu * src0_extra =            (ggml_tensor_extra_gpu *) src0->extra;
+    ggml_tensor_extra_gpu * src1_extra = use_src1 ? (ggml_tensor_extra_gpu *) src1->extra : nullptr;
+    ggml_tensor_extra_gpu * dst_extra  =            (ggml_tensor_extra_gpu *)  dst->extra;
 
-    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
-    queue_ptr main_stream = ctx.stream();
+    // dd = data device
+    float * src0_ddf = (float *) src0->data;
+    float * src1_ddf = use_src1 ? (float *) src1->data : nullptr;
+    float *  dst_ddf = (float *) dst->data;
 
-    char * src0_ddc = (char *) src0->data;
-    char * src1_ddc = (char *) src1->data;
+    ggml_sycl_pool_alloc<float> src0_f(ctx.pool());
+    ggml_sycl_pool_alloc<float> src1_f(ctx.pool());
+    ggml_sycl_pool_alloc<float>  dst_f(ctx.pool());
 
-    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
-        ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
-        ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
-        ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
-        ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
-        ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
-        ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
-        ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
-    } else {
-        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
-                ggml_type_name(src0->type), ggml_type_name(src1->type));
-        GGML_ABORT("fatal error");
-    }
+    ggml_sycl_set_device(ctx.device);
+    queue_ptr main_stream = ctx.stream();
+    // GGML_SYCL_DEBUG("ctx.device=%d, main_stream=%p src0_on_device=%d, src1_on_device=%d, dst_on_device=%d\n",
+        // ctx.device, main_stream, src0_on_device, src1_on_device, dst_on_device);
 
-    (void) dst;
+    // do the computation
+    op(ctx, src0, src1, dst, src0_ddf, src1_ddf, dst_ddf, main_stream);
+    // print_ggml_tensor("tensor", dst);
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
             << ", line:" << __LINE__ << std::endl;
   std::exit(1);
 }
 
-static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    // TODO: why do we pass dst as src1 here?
-    ggml_sycl_cpy(ctx, src0, dst, nullptr);
-    (void) src1;
-}
+static void ggml_sycl_set_peer_access(const int n_tokens, int main_device) {
+    static bool peer_access_enabled = false;
 
-static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf);
-}
+    const bool enable_peer_access = n_tokens <= GGML_SYCL_PEER_MAX_BATCH_SIZE;
 
-static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max);
-}
+    if (peer_access_enabled == enable_peer_access) {
+        return;
+    }
 
-static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope);
-}
+#ifdef NDEBUG
+    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+        SYCL_CHECK(ggml_sycl_set_device(i));
+    }
 
-static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d);
-}
+    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+        SYCL_CHECK(ggml_sycl_set_device(i));
 
-static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
-}
+        for (int id_other = 0; id_other < ggml_sycl_info().device_count; ++id_other) {
+            if (i == id_other) {
+                continue;
+            }
+            if (i != main_device && id_other != main_device) {
+                continue;
+            }
 
-static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
-}
+            // int can_access_peer;
+            // SYCL_CHECK(syclDeviceCanAccessPeer(&can_access_peer, id, id_other));
+            // if (can_access_peer) {
+            //     if (enable_peer_access) {
+            //         SYCL_CHECK(syclDeviceEnablePeerAccess(id_other, 0));
+            //     } else {
+            //         SYCL_CHECK(syclDeviceDisablePeerAccess(id_other));
+            //     }
+            // }
+        }
+    }
+#endif // NDEBUG
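+    // note: the actual peer-access toggling is stubbed out above, as SYCL has no
+    // direct counterpart to cudaDeviceEnablePeerAccess; only the cached flag changes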
 
-static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
+    peer_access_enabled = enable_peer_access;
 }
 
-static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    (void) src0;
-    (void) src1;
-    (void) dst;
-}
+static void ggml_sycl_op_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                 const ggml_tensor *src1, ggml_tensor *dst,
+                                 ggml_sycl_op_mul_mat_t op,
+                                 const bool convert_src1_to_q8_1) try {
 
-static size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
-    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
 
-    return nrows_split*ggml_row_size(tensor->type, tensor->ne[0]);
-}
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
+    const int64_t nrows1 = ggml_nrows(src1);
 
-void ggml_sycl_set_main_device(const int main_device) try {
-    if (dpct::get_current_device_id() == main_device) return;
-    check_allow_gpu_index(main_device);
-    dpct::select_device(main_device);
+    GGML_ASSERT(ne03 == ne13);
 
-    if (g_ggml_sycl_debug) {
-        dpct::device_info prop;
-        SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-            prop, dpct::dev_mgr::instance().get_device(main_device))));
-        fprintf(stderr, "Using device %d (%s) as main device\n",
-                main_device, prop.get_name());
-    }
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
 
-bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) {
-    if (!g_sycl_loaded) return false;
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
 
-    ggml_sycl_func_t func;
+    GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
-    switch (tensor->op) {
-        case GGML_OP_CONV_TRANSPOSE_1D:
-            func = ggml_sycl_op_conv_transpose_1d;
-            break;
-        case GGML_OP_REPEAT:
-            func = ggml_sycl_repeat;
-            break;
-        case GGML_OP_GET_ROWS:
-            func = ggml_sycl_get_rows;
-            break;
-        case GGML_OP_DUP:
-            func = ggml_sycl_dup;
-            break;
-        case GGML_OP_ADD:
-            func = ggml_sycl_add;
-            break;
-        case GGML_OP_ACC:
-            func = ggml_sycl_acc;
-            break;
-        case GGML_OP_MUL:
-            func = ggml_sycl_mul;
-            break;
-        case GGML_OP_DIV:
-            func = ggml_sycl_div;
-            break;
-        case GGML_OP_UNARY:
-            switch (ggml_get_unary_op(tensor)) {
-                case GGML_UNARY_OP_GELU:
-                    func = ggml_sycl_gelu;
-                    break;
-                case GGML_UNARY_OP_SILU:
-                    func = ggml_sycl_silu;
-                    break;
-                case GGML_UNARY_OP_GELU_QUICK:
-                    func = ggml_sycl_gelu_quick;
-                    break;
-                case GGML_UNARY_OP_TANH:
-                    func = ggml_sycl_tanh;
-                    break;
-                case GGML_UNARY_OP_RELU:
-                    func = ggml_sycl_relu;
-                    break;
-                case GGML_UNARY_OP_HARDSIGMOID:
-                    func = ggml_sycl_hardsigmoid;
-                    break;
-                case GGML_UNARY_OP_HARDSWISH:
-                    func = ggml_sycl_hardswish;
-                    break;
-                default:
-                    return false;
-            }
-            break;
-        case GGML_OP_NORM:
-            func = ggml_sycl_norm;
-            break;
-        case GGML_OP_GROUP_NORM:
-            func = ggml_sycl_group_norm;
-            break;
-        case GGML_OP_CONCAT:
-            func = ggml_sycl_op_concat;
-            break;
-        case GGML_OP_UPSCALE:
-            func = ggml_sycl_upscale;
-            break;
-        case GGML_OP_PAD:
-            func = ggml_sycl_pad;
-            break;
-        case GGML_OP_LEAKY_RELU:
-            func = ggml_sycl_leaky_relu;
-            break;
-        case GGML_OP_RMS_NORM:
-            func = ggml_sycl_rms_norm;
-            break;
-        case GGML_OP_MUL_MAT:
-            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
-                return false;
-            }
-            func = ggml_sycl_mul_mat;
-            break;
-        case GGML_OP_MUL_MAT_ID:
-            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
-                return false;
-            }
-            func = ggml_sycl_mul_mat_id;
-            break;
-        case GGML_OP_SCALE:
-            func = ggml_sycl_scale;
-            break;
-        case GGML_OP_SQR:
-            func = ggml_sycl_sqr;
-            break;
-        case GGML_OP_CLAMP:
-            func = ggml_sycl_clamp;
-            break;
-        case GGML_OP_CPY:
-            func = ggml_sycl_cpy;
-            break;
-        case GGML_OP_CONT:
-            func = ggml_sycl_dup;
-            break;
-        case GGML_OP_NONE:
-        case GGML_OP_RESHAPE:
-        case GGML_OP_VIEW:
-        case GGML_OP_PERMUTE:
-        case GGML_OP_TRANSPOSE:
-            func = ggml_sycl_nop;
-            break;
-        case GGML_OP_DIAG_MASK_INF:
-            func = ggml_sycl_diag_mask_inf;
-            break;
-        case GGML_OP_SOFT_MAX:
-            func = ggml_sycl_soft_max;
-            break;
-        case GGML_OP_ROPE:
-            func = ggml_sycl_rope;
-            break;
-        case GGML_OP_IM2COL:
-            func = ggml_sycl_im2col;
-            break;
-        case GGML_OP_POOL_2D:
-            func = ggml_sycl_pool2d;
-            break;
-        case GGML_OP_SUM_ROWS:
-            func = ggml_sycl_sum_rows;
-            break;
-        case GGML_OP_ARGSORT:
-            func = ggml_sycl_argsort;
-            break;
-        case GGML_OP_TIMESTEP_EMBEDDING:
-            func = ggml_sycl_op_timestep_embedding;
-            break;
-        default:
-            return false;
-    }
+    GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
 
-    if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) {
-        ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device);
-    }
+    const int64_t i02_divisor = ne12 / ne02;
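+    // src1 batches may broadcast over src0 (ne12 a multiple of ne02): each src0 matrix
+    // is reused for i02_divisor consecutive src1 batches, e.g. ne02 = 1, ne12 = 32
+    // (one shared weight matrix) gives i02_divisor = 32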
 
-    func(ctx, tensor->src[0], tensor->src[1], tensor);
-    return true;
-}
+    const size_t src0_ts = ggml_type_size(src0->type);
+    const size_t src0_bs = ggml_blck_size(src0->type);
+    const size_t q8_1_ts = sizeof(block_q8_1);
+    const size_t q8_1_bs = QK8_1;
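+    // quantized size bookkeeping: n values occupy n*type_size/block_size bytes, which
+    // is how the q8_1 staging buffer sizes below are computed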
 
-GGML_API void   ggml_sycl_get_gpu_list(int *id_list, int max_len) try {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_sycl_get_gpu_list\n");
-    for(int i=0;i<max_len;i++) id_list[i] = -1;
-
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    ggml_tensor_extra_gpu *  dst_extra = (ggml_tensor_extra_gpu *)  dst->extra;
 
-    for (int i=0;i< ggml_sycl_info().device_count;i++){
-        if (i>=max_len) break;
-        id_list[i] = i;
-    }
-    return;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
+    const bool src0_is_contiguous = ggml_is_contiguous(src0);
+    const bool src1_is_contiguous = ggml_is_contiguous(src1);
 
-int ggml_sycl_get_device_count() try {
-    int device_count;
-    if (CHECK_TRY_ERROR(device_count =
-                             dpct::dev_mgr::instance().device_count()) != 0) {
-        return 0;
+    int64_t src1_padded_col_size = GGML_PAD(ne10, MATRIX_ROW_PADDING);
+
+    const bool split = src0->backend == GGML_BACKEND_TYPE_GPU_SPLIT;
+    GGML_ASSERT(!(split && ne02 > 1));
+    GGML_ASSERT(!(split && ne03 > 1));
+    GGML_ASSERT(!(split && ne02 < ne12));
+
+    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split;
+    if (split) {
+        // TODO: check that src0->buffer->buft is a split buffer type, replace GGML_BACKEND_TYPE_GPU_SPLIT check
+        // GGML_ASSERT(src0->buffer != nullptr && src0->buffer->buft == ...);
+        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+        tensor_split = buft_ctx->tensor_split;
     }
-    return device_count;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
 
-GGML_API void ggml_sycl_get_device_description(int device, char *description,
-                                      size_t description_size) try {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_sycl_get_device_description\n");
-    dpct::device_info prop;
-    SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
-        prop, dpct::dev_mgr::instance().get_device(device))));
-    snprintf(description, description_size, "%s", prop.get_name());
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
+    struct dev_data {
+        ggml_sycl_pool_alloc<char> src0_dd_alloc;
+        ggml_sycl_pool_alloc<float> src1_ddf_alloc;
+        ggml_sycl_pool_alloc<char> src1_ddq_alloc;
+        ggml_sycl_pool_alloc<float> dst_dd_alloc;
 
-void ggml_backend_sycl_get_device_memory(int device, size_t *free,
-                                                   size_t *total) try {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
-    ggml_sycl_set_device(device);
+        char *src0_dd = nullptr;
+        float *src1_ddf = nullptr; // float
+        char *src1_ddq = nullptr;  // q8_1
+        float *dst_dd = nullptr;
 
-    /*
-    DPCT1009:218: SYCL uses exceptions to report errors and does not use the
-    error codes. The original code was commented out and a warning string was
-    inserted. You need to rewrite this code.
-    */
-    /*
-    DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for
-    device information which may not be supported by all compilers or runtimes.
-    You may need to adjust the code.
-    */
-    SYCL_CHECK(CHECK_TRY_ERROR(
-        dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
+        int64_t row_low;
+        int64_t row_high;
+    };
 
-////////////////////////////////////////////////////////////////////////////////
+    dev_data dev[GGML_SYCL_MAX_DEVICES];
 
-// backend interface
+    int used_devices = 0;
+    queue_ptr main_stream = ctx.stream();
 
-#define UNUSED GGML_UNUSED
+    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+        // by default, use all rows
+        dev[i].row_low  = 0;
+        dev[i].row_high = ne01;
 
-// sycl buffer
+        // for multi GPU, get the row boundaries from tensor split
+        // and round to mul_mat_q tile sizes
+        if (split) {
+            const int64_t rounding = get_row_rounding(src0->type, tensor_split);
 
-struct ggml_backend_sycl_buffer_context {
-    int device;
-    void * dev_ptr = nullptr;
-    queue_ptr stream;
-    std::string name;
+            if (i != 0) {
+                dev[i].row_low  = ne01*tensor_split[i];
+                if (dev[i].row_low < ne01) {
+                    dev[i].row_low -= dev[i].row_low % rounding;
+                }
+            }
 
-     ggml_backend_sycl_buffer_context(int device, void * dev_ptr, queue_ptr stream) :
-        device(device), dev_ptr(dev_ptr), stream(stream) {
-            check_allow_gpu_index(device);
-            name = (GGML_SYCL_NAME + std::to_string(device));
+            if (i != ggml_sycl_info().device_count - 1) {
+                dev[i].row_high  = ne01*tensor_split[i + 1];
+                if (dev[i].row_high < ne01) {
+                    dev[i].row_high -= dev[i].row_high % rounding;
+                }
+            }
+        }
+    }
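+    // e.g. ne01 = 100, rounding = 32, tensor_split = {0.0f, 0.5f}: the 50-row boundary
+    // is rounded down to 32, so device 0 gets rows [0, 32) and device 1 rows [32, 100)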
+
+    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+        if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
+            continue;
+        }
+
+        used_devices++;
+
+        const bool src1_on_device = i == ctx.device;
+        const bool  dst_on_device = i == ctx.device;
+
+        ggml_sycl_set_device(i);
+        queue_ptr stream = ctx.stream(i, 0);
+
+        if (src0_is_contiguous) {
+            dev[i].src0_dd = (char *) src0->data;
+        } else {
+            dev[i].src0_dd = dev[i].src0_dd_alloc.alloc(ctx.pool(i), ggml_nbytes(src0));
+        }
+
+        if (src1_on_device && src1_is_contiguous) {
+            dev[i].src1_ddf = (float *) src1->data;
+        } else {
+            dev[i].src1_ddf = dev[i].src1_ddf_alloc.alloc(ctx.pool(i), ggml_nelements(src1));
+        }
+
+        if (convert_src1_to_q8_1) {
+            dev[i].src1_ddq = dev[i].src1_ddq_alloc.alloc(ctx.pool(i), nrows1*src1_padded_col_size*q8_1_ts/q8_1_bs);
+
+            if (src1_on_device && src1_is_contiguous) {
+                quantize_row_q8_1_sycl(dev[i].src1_ddf, dev[i].src1_ddq, ne10, nrows1, src1_padded_col_size, stream);
+                /*
+                DPCT1010:90: SYCL uses exceptions to report errors and does not
+                use the error codes. The call was replaced with 0. You need to
+                rewrite this code.
+                */
+                SYCL_CHECK(0);
+            }
+        }
+
+        if (dst_on_device) {
+            dev[i].dst_dd = (float *) dst->data;
+        } else {
+            const size_t size_dst_ddf = split ? (dev[i].row_high - dev[i].row_low)*ne1 : ggml_nelements(dst);
+            dev[i].dst_dd = dev[i].dst_dd_alloc.alloc(ctx.pool(i), size_dst_ddf);
+        }
+    }
+
+    // if multiple devices are used they need to wait for the main device
+    // here an event is recorded that signals that the main device has finished calculating the input data
+    if (split && used_devices > 1) {
+        ggml_sycl_set_device(ctx.device);
+        /*
+        DPCT1024:91: The original code returned the error code that was further
+        consumed by the program logic. This original code was replaced with 0.
+        You may need to rewrite the program logic consuming the error code.
+        */
+        SYCL_CHECK(CHECK_TRY_ERROR(
+            *src0_extra->events[ctx.device][0] =
+                ctx.stream()->ext_oneapi_submit_barrier()));
+    }
+
+    const int64_t src1_col_stride = split && used_devices > 1 ? MUL_MAT_SRC1_COL_STRIDE : ne11;
+    for (int64_t src1_col_0 = 0; src1_col_0 < ne11; src1_col_0 += src1_col_stride) {
+        const int64_t is = split ? (src1_col_0/src1_col_stride) % GGML_SYCL_MAX_STREAMS : 0;
+        const int64_t src1_ncols = src1_col_0 + src1_col_stride > ne11 ? ne11 - src1_col_0 : src1_col_stride;
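+        // src1 is consumed in chunks of src1_col_stride columns; for split tensors each
+        // chunk is assigned round-robin to one of GGML_SYCL_MAX_STREAMS streams via `is`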
+
+        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+            if ((!split && i != ctx.device) || dev[i].row_low == dev[i].row_high) {
+                continue;
+            }
+
+            const bool src1_on_device = i == ctx.device;
+            const bool  dst_on_device = i == ctx.device;
+            const int64_t row_diff = dev[i].row_high - dev[i].row_low;
+
+            ggml_sycl_set_device(i);
+            queue_ptr stream = ctx.stream(i, is);
+
+            // wait for main GPU data if necessary
+            if (split && (i != ctx.device || is != 0)) {
+                /*
+                DPCT1009:163: SYCL uses exceptions to report errors and does not
+                use the error codes. The original code was commented out and a
+                warning string was inserted. You need to rewrite this code.
+                */
+                SYCL_CHECK(CHECK_TRY_ERROR(stream->ext_oneapi_submit_barrier(
+                    {*src0_extra->events[ctx.device][0]})));
+            }
+
+            for (int64_t i0 = 0; i0 < ne13*ne12; ++i0) {
+                const int64_t i03 = i0 / ne12;
+                const int64_t i02 = i0 % ne12;
+
+                const size_t src1_ddq_i_offset = (i0*ne11 + src1_col_0) * src1_padded_col_size*q8_1_ts/q8_1_bs;
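+                // chunk offset into the flattened q8_1 buffer: (i0*ne11 + src1_col_0)
+                // rows in, each padded row taking src1_padded_col_size*q8_1_ts/q8_1_bs bytes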
+
+                // for split tensors the data begins at i0 == i0_offset_low
+                char  *  src0_dd_i =  dev[i].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
+                float * src1_ddf_i = dev[i].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
+                char  * src1_ddq_i = dev[i].src1_ddq +  src1_ddq_i_offset;
+                float *   dst_dd_i =   dev[i].dst_dd + (i0*ne1  + src1_col_0) * (dst_on_device ? ne0 : row_diff);
+
+                // the main device memory buffer can be on VRAM scratch, with space for all partial results
+                // in that case an offset on dst_ddf_i is needed
+                if (i == ctx.device) {
+                    dst_dd_i += dev[i].row_low; // offset is 0 if no tensor split
+                }
+
+                // copy src0, src1 to device if necessary
+                if (src1_is_contiguous) {
+                    if (i != ctx.device) {
+                        if (convert_src1_to_q8_1) {
+                            char * src1_ddq_i_source = dev[ctx.device].src1_ddq + src1_ddq_i_offset;
+                            SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
+                                src1_ddq_i, src1_ddq_i_source,
+                                src1_ncols * src1_padded_col_size * q8_1_ts /
+                                    q8_1_bs).wait()));
+                        } else {
+
+                            float * src1_ddf_i_source = (float *) src1_extra->data_device[ctx.device];
+                            src1_ddf_i_source += (i0*ne11 + src1_col_0) * ne10;
+
+                            SYCL_CHECK(CHECK_TRY_ERROR(dev2dev_memcpy(*stream, *main_stream,
+                                src1_ddf_i, src1_ddf_i_source,
+                                src1_ncols * ne10 * sizeof(float))));
+                        }
+                    }
+                } else if (src1_on_device && !src1_is_contiguous) {
+                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(
+                                   src1_ddf_i, src1, i03, i02, src1_col_0, src1_col_0+src1_ncols, stream));
+                } else {
+                    GGML_ABORT("fatal error");
+                }
+
+                if (convert_src1_to_q8_1 && !src1_is_contiguous) {
+                    quantize_row_q8_1_sycl(src1_ddf_i, src1_ddq_i, ne10, src1_ncols, src1_padded_col_size, stream);
+                    /*
+                    DPCT1010:92: SYCL uses exceptions to report errors and does
+                    not use the error codes. The call was replaced with 0. You
+                    need to rewrite this code.
+                    */
+                    SYCL_CHECK(0);
+                }
+
+                if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
+                    SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[i].row_low, dev[i].row_high, stream));
+                }
+                if (src1->type == GGML_TYPE_F16) {
+                    src1_padded_col_size = (i0 * ne11 + src1_col_0) * ne10;
+                }
+                // do the computation
+                SYCL_CHECK(CHECK_TRY_ERROR(op(ctx, src0, src1, dst, src0_dd_i, src1_ddf_i, src1_ddq_i, dst_dd_i,
+                    dev[i].row_low, dev[i].row_high, src1_ncols, src1_padded_col_size, stream)));
+                /*
+                DPCT1010:93: SYCL uses exceptions to report errors and does not
+                use the error codes. The call was replaced with 0. You need to
+                rewrite this code.
+                */
+                SYCL_CHECK(0);
+
+                // copy dst to host or other device if necessary
+                if (!dst_on_device) {
+                    void * dst_off_device = dst->data;
+                    if (split) {
+                        // src0 = weight matrix is saved as a transposed matrix for better memory layout.
+                        // dst is NOT transposed.
+                        // The outputs of matrix matrix multiplications can therefore NOT simply be concatenated for >1 GPU.
+                        // Instead they need to be copied to the correct slice in ne0 = dst row index.
+                        // If dst is a vector with ne0 == 1 then you don't have to do this but it still produces correct results.
+                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+                        dhf_dst_i += src1_col_0*ne0 + dev[i].row_low;
+
+                        SYCL_CHECK(CHECK_TRY_ERROR(dpct::async_dpct_memcpy(
+                            dhf_dst_i, ne0 * sizeof(float), dst_dd_i,
+                            row_diff * sizeof(float), row_diff * sizeof(float),
+                            src1_ncols, dpct::device_to_device, *stream)));
+                    } else {
+                        float * dhf_dst_i = (float *) ((char *) dst_off_device + i02*nb2 + i03*nb3);
+                        GGML_ASSERT(dst->nb[1] == ne0*sizeof(float));
+                        dhf_dst_i += src1_col_0*ne0;
+                        SYCL_CHECK(CHECK_TRY_ERROR(
+                            stream->memcpy(dhf_dst_i, dst_dd_i,
+                                           src1_ncols * ne0 * sizeof(float)).wait()));
+                    }
+                }
+
+                // add event for the main device to wait on until other device is done
+                if (split && (i != ctx.device || is != 0)) {
+                    /*
+                    DPCT1024:94: The original code returned the error code that
+                    was further consumed by the program logic. This original
+                    code was replaced with 0. You may need to rewrite the
+                    program logic consuming the error code.
+                    */
+                    SYCL_CHECK(CHECK_TRY_ERROR(
+                        *src0_extra->events[i][is] =
+                            stream->ext_oneapi_submit_barrier()));
+                }
+            }
         }
+    }
 
+    // main device waits for all other devices to be finished
+    if (split && ggml_sycl_info().device_count > 1) {
+        int64_t is_max = (ne11 + MUL_MAT_SRC1_COL_STRIDE - 1) / MUL_MAT_SRC1_COL_STRIDE;
+        is_max = is_max <= GGML_SYCL_MAX_STREAMS ? is_max : GGML_SYCL_MAX_STREAMS;
 
-    ~ggml_backend_sycl_buffer_context() {
-        if (dev_ptr != nullptr) {
-            ggml_sycl_set_device(device);
-            SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(dev_ptr, *stream)));
+        ggml_sycl_set_device(ctx.device);
+        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
+            if (dev[i].row_low == dev[i].row_high) {
+                continue;
+            }
+            for (int64_t is = 0; is < is_max; ++is) {
+                SYCL_CHECK(CHECK_TRY_ERROR(
+                    ctx.stream()->ext_oneapi_submit_barrier(
+                        {*src0_extra->events[i][is]})));
+            }
         }
     }
-};
-
-static const char * ggml_backend_sycl_buffer_get_name(ggml_backend_buffer_t buffer) {
-    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
-    return ctx->name.c_str();
-}
-
-static bool ggml_backend_buffer_is_sycl(ggml_backend_buffer_t buffer) {
-    return buffer->iface.get_name == ggml_backend_sycl_buffer_get_name;
-}
-
-static void
-ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try {
-    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-    ggml_sycl_set_device(ctx->device);
-
-    delete ctx;
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -4157,139 +3913,152 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static void * ggml_backend_sycl_buffer_get_base(ggml_backend_buffer_t buffer) {
-    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
-    return ctx->dev_ptr;
+
+static void ggml_sycl_repeat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_repeat);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void
-ggml_backend_sycl_buffer_init_tensor(ggml_backend_buffer_t buffer,
-                                     ggml_tensor *tensor) try {
-    ggml_backend_sycl_buffer_context * ctx = (ggml_backend_sycl_buffer_context *)buffer->context;
+static void ggml_sycl_get_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_get_rows);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-    if (tensor->view_src != NULL && tensor->view_offs == 0) {
-        assert(tensor->view_src->buffer->buft == buffer->buft);
-        tensor->backend = tensor->view_src->backend;
-        tensor->extra = tensor->view_src->extra;
-        return;
-    }
+static void ggml_sycl_add(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_add);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
+static void ggml_sycl_acc(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_acc);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-    if (ggml_is_quantized(tensor->type)) {
-        // initialize padding to 0 to avoid possible NaN values
-        size_t original_size = ggml_nbytes(tensor);
-        size_t padded_size = ggml_backend_buft_get_alloc_size(buffer->buft, tensor);
+static void ggml_sycl_mul(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_mul);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-        if (padded_size > original_size && tensor->view_src == nullptr) {
-            SYCL_CHECK(CHECK_TRY_ERROR(ctx->stream->memset(
-                (char *)tensor->data + original_size, 0,
-                padded_size - original_size).wait()));
-        }
-    }
+static void ggml_sycl_div(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_div);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+
+static void ggml_sycl_gelu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_backend_sycl_buffer_set_tensor(ggml_backend_buffer_t buffer,
-                                                ggml_tensor *tensor,
-                                                const void *data, size_t offset,
-                                                size_t size) try {
+static void ggml_sycl_silu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_silu);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+static void ggml_sycl_gelu_quick(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_gelu_quick);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-    ggml_sycl_set_device(ctx->device);
-    auto stream = &(dpct::dev_mgr::instance().get_device(ctx->device).default_queue());
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(dpct::dev_mgr::instance().get_device(ctx->device).queues_wait_and_throw()));
-    char* host_buf = (char*)malloc(size);
-    memcpy(host_buf, data, size);
-    SYCL_CHECK(
-        CHECK_TRY_ERROR((*stream).memcpy((char *)tensor->data + offset, host_buf, size)
-                             .wait()));
-    free(host_buf);
+static void ggml_sycl_tanh(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_tanh);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+
+static void ggml_sycl_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_relu);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static void ggml_backend_sycl_buffer_get_tensor(ggml_backend_buffer_t buffer,
-                                                const ggml_tensor *tensor,
-                                                void *data, size_t offset,
-                                                size_t size) try {
+static void ggml_sycl_hardsigmoid(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardsigmoid);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-    ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+static void ggml_sycl_hardswish(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_hardswish);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-    ggml_sycl_set_device(ctx->device);
-    auto stream = dpct::dev_mgr::instance().get_device(ctx->device).default_queue();
+static void ggml_sycl_leaky_relu(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_leaky_relu);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-    SYCL_CHECK(CHECK_TRY_ERROR(
-        stream.memcpy(data, (const char *)tensor->data + offset, size)
-            .wait()));
+static void ggml_sycl_sqr(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sqr);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
+
+static void ggml_sycl_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_norm);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
 }
 
-static bool
-ggml_backend_sycl_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
-                                    const ggml_tensor *src,
-                                    ggml_tensor *dst) try {
-    if (ggml_backend_buffer_is_sycl(src->buffer)) {
-        ggml_backend_sycl_buffer_context * src_ctx = (ggml_backend_sycl_buffer_context *)src->buffer->context;
-        ggml_backend_sycl_buffer_context * dst_ctx = (ggml_backend_sycl_buffer_context *)dst->buffer->context;
+static void ggml_sycl_group_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_group_norm);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-        ggml_sycl_set_device(src_ctx->device);
-        /*
-        DPCT1009:198: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            dpct::dev_mgr::instance().get_device(src_ctx->device).queues_wait_and_throw()));
-        ggml_sycl_set_device(dst_ctx->device);
-        /*
-        DPCT1009:199: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
-        /*
-        DPCT1009:200: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
+static void ggml_sycl_upscale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_upscale);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-        queue_ptr stream_dst = dst_ctx->stream;
-        queue_ptr stream_src = src_ctx->stream;
-        size_t size = ggml_nbytes(src);
+static void ggml_sycl_pad(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pad);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-        //todo. it's dirty solutino to walkaroud known issue:device2device cross GPUs.
-        dev2dev_memcpy(*stream_dst, *stream_src, dst->data, src->data, size);
 
-//todo, it's known issue:error in device2device cross GPUs. reused when the issue is fixed. DON"T remove
-#if 0
-        SYCL_CHECK(CHECK_TRY_ERROR((*stream).memcpy(
-            (char *)dst->data, (const char *)src->data, size).wait()));
+static void ggml_sycl_rms_norm(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_SYCL_DEBUG("call %s\n", __func__);
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rms_norm);
+    GGML_SYCL_DEBUG("call %s done\n", __func__);
+}
 
-        /*
-        DPCT1009:201: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            dpct::dev_mgr::instance().get_device(dst_ctx->device).queues_wait_and_throw()));
-#endif
-        return true;
-    }
-    return false;
+static void ggml_sycl_mul_mat_vec_p021(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                       const ggml_tensor *src1,
+                                       ggml_tensor *dst) try {
+    GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation
+    GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+
+    const int64_t ne12 = src1->ne[2];
+
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    queue_ptr main_stream = ctx.stream();
+
+    void  * src0_ddq = src0->data;
+    float * src1_ddf = (float *) src1->data;
+    float * dst_ddf  = (float *) dst->data;
+
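+    // launch the custom kernel for the 0213-permuted f16 x f32 case (KQ single-batch)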
+    ggml_mul_mat_p021_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream);
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -4297,19 +4066,36 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
+static void ggml_sycl_mul_mat_vec_nc(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                     const ggml_tensor *src1,
+                                     ggml_tensor *dst) try {
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(!ggml_is_permuted(src0));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
 
-static void ggml_backend_sycl_buffer_clear(ggml_backend_buffer_t buffer,
-                                           uint8_t value) try {
-     ggml_backend_sycl_buffer_context * ctx = ( ggml_backend_sycl_buffer_context *)buffer->context;
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
 
-    ggml_sycl_set_device(ctx->device);
-    queue_ptr stream = ctx->stream;
-    SYCL_CHECK(
-        CHECK_TRY_ERROR(dpct::get_current_device().queues_wait_and_throw()));
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2];
 
-    SYCL_CHECK(CHECK_TRY_ERROR((*stream)
-                                    .memset(ctx->dev_ptr, value, buffer->size)
-                                    .wait()));
+    const int64_t ne12 = src1->ne[2];
+
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    queue_ptr main_stream = ctx.stream();
+
+    void  * src0_ddq = src0->data;
+    float * src1_ddf = (float *) src1->data;
+    float * dst_ddf  = (float *) dst->data;
+
+    const int64_t row_stride_x = nb01 / sizeof(sycl::half);
+    const int64_t channel_stride_x = nb02 / sizeof(sycl::half);
+
+    ggml_mul_mat_vec_nc_f16_f32_sycl(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -4317,50 +4103,141 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static struct ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = {
-    /* .get_name        = */ ggml_backend_sycl_buffer_get_name,
-    /* .free_buffer     = */ ggml_backend_sycl_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_sycl_buffer_get_base,
-    /* .init_tensor     = */ ggml_backend_sycl_buffer_init_tensor,
-    /* .memset_tensor   = */ NULL,
-    /* .set_tensor      = */ ggml_backend_sycl_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_sycl_buffer_get_tensor,
-    /* .cpy_tensor      = */ ggml_backend_sycl_buffer_cpy_tensor,
-    /* .clear           = */ ggml_backend_sycl_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-// sycl buffer type
-struct ggml_backend_sycl_buffer_type_context {
-    int device;
-    std::string name;
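+// helper kernel for the batched GEMM path: one work-item per (i12, i13) batch
+// position computes the effective src0/src1/dst pointers, applying the
+// broadcast ratios r2/r3 across dims 2 and 3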
+static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
+                                   const sycl::half *src1_as_f16, char *dst,
+                                   const void **ptrs_src, void **ptrs_dst,
+                                   int64_t ne12, int64_t ne13, int64_t ne23,
+                                   size_t nb02, size_t nb03, size_t nb12,
+                                   size_t nb13, size_t nbd2, size_t nbd3,
+                                   int64_t r2, int64_t r3,
+                                   const sycl::nd_item<3> &item_ct1) {
+    int64_t i13 = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
+                  item_ct1.get_local_id(2);
+    int64_t i12 = item_ct1.get_group(1) * item_ct1.get_local_range(1) +
+                  item_ct1.get_local_id(1);
 
-    // each buffer type has its own stream
-    queue_ptr stream = nullptr;
-};
+    if (i13 >= ne13 || i12 >= ne12) {
+        return;
+    }
 
-static const char * ggml_backend_sycl_buffer_type_name(ggml_backend_buffer_type_t buft) {
-    ggml_backend_sycl_buffer_type_context * ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
+    int64_t i03 = i13 / r3;
+    int64_t i02 = i12 / r2;
 
-    return ctx->name.c_str();
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
 }
-static ggml_backend_buffer_t
-ggml_backend_sycl_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft,
-                                           size_t size) try {
-    ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
-    ggml_sycl_set_device(buft_ctx->device);
-    const queue_ptr stream = buft_ctx->stream;
-    size = std::max(size, (size_t)1); // syclMalloc returns null for size 0
 
-    void * dev_ptr;
-    SYCL_CHECK(CHECK_TRY_ERROR(dev_ptr = (void *)sycl::malloc_device(
-                                    size, *stream)));
-    if (!dev_ptr) {
-        fprintf(stderr, "%s: can't malloc %lu Bytes memory on device", __func__, size);
-        return nullptr;
+static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
+                                             const ggml_tensor *src0,
+                                             const ggml_tensor *src1,
+                                             ggml_tensor *dst) try {
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t ne_dst = ggml_nelements(dst);
+
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    queue_ptr main_stream = ctx.stream();
+
+    void * src0_ddq = src0->data;
+    sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
+    float * src1_ddf = (float *) src1->data;
+    float * dst_ddf = (float *) dst->data;
+
+    // convert src1 to fp16
+    ggml_sycl_pool_alloc<sycl::half> src1_f16_alloc(ctx.pool());
+    if (src1->type != GGML_TYPE_F16) {
+        const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+        const int64_t ne_src1 = ggml_nelements(src1);
+        src1_f16_alloc.alloc(ne_src1);
+        GGML_ASSERT(to_fp16_sycl != nullptr);
+        to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+    }
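+    // use src1 directly when it is already fp16, otherwise use the converted copy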
+    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
+                                                       : src1_f16_alloc.get();
+
+    char * dst_t;
+
+    dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
+    dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
+
+    // dst strides
+    size_t nbd2 = dst->nb[2];
+    size_t nbd3 = dst->nb[3];
+
+    const float alpha_f32 = 1.0f;
+    const float beta_f32 = 0.0f;
+
+    const void * alpha = &alpha_f32;
+    const void * beta  = &beta_f32;
+
+    dst_t = (char *) dst_ddf;
+
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    if (r2 == 1 && r3 == 1 && ggml_is_contiguous_2(src0) && ggml_is_contiguous_2(src1)) {
+        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+            *main_stream, oneapi::mkl::transpose::trans,
+            oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
+            (const char *)src0_as_f16, dpct::library_data_t::real_half,
+            nb01 / nb00, nb02 / nb00,
+            (const char *)src1_f16, dpct::library_data_t::real_half,
+            nb11 / nb10, nb12 / nb10, beta,
+            (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
+            ne12 * ne13, cu_compute_type)));
+    } else {
+        const int ne23 = ne12*ne13;
+
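+        // general case: build per-matrix pointer tables on the device and issue a
+        // single grouped GEMM over all ne23 = ne12*ne13 matrices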
+        ggml_sycl_pool_alloc<const void *> ptrs_src(ctx.pool(), 2*ne23);
+        ggml_sycl_pool_alloc<      void *> ptrs_dst(ctx.pool(), 1*ne23);
+
+        sycl::range<3> block_dims(1, ne12, ne13);
+        /*
+        DPCT1049:47: The work-group size passed to the SYCL kernel may exceed
+        the limit. To get the device limit, query
+        info::device::max_work_group_size. Adjust the work-group size if needed.
+        */
+        {
+            dpct::has_capability_or_fail(main_stream->get_device(),
+                                         {sycl::aspect::fp16});
+
+            main_stream->submit([&](sycl::handler &cgh) {
+                const void **ptrs_src_get = ptrs_src.get();
+                void **ptrs_dst_get = ptrs_dst.get();
+                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
+                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
+                cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
+                                 [=](sycl::nd_item<3> item_ct1) {
+                                     k_compute_batched_ptrs(
+                                         src0_as_f16, src1_f16,
+                                         dst_t, ptrs_src_get,
+                                         ptrs_dst_get, ne12, ne13, ne23,
+                                         nb02, nb03, nb12_scaled, nb13_scaled,
+                                         nbd2, nbd3, r2, r3, item_ct1);
+                                 });
+            });
+        }
+        SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
+            *main_stream, oneapi::mkl::transpose::trans,
+            oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
+            (const void **)(ptrs_src.get() + 0 * ne23),
+            dpct::library_data_t::real_half, nb01 / nb00,
+            (const void **)(ptrs_src.get() + 1 * ne23),
+            dpct::library_data_t::real_half, nb11 / nb10, beta,
+            (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
+            cu_compute_type)));
     }
-    ggml_backend_sycl_buffer_context * ctx = new  ggml_backend_sycl_buffer_context(buft_ctx->device, dev_ptr, buft_ctx->stream);
-    return ggml_backend_buffer_init(buft, ggml_backend_sycl_buffer_interface, ctx, size);
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -4368,303 +4245,320 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static size_t ggml_backend_sycl_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return 128;
-    UNUSED(buft);
+inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
+    // TODO: accuracy issues in MMQ
+    return false;
+}
+
+bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_F16:
+            return true;
+        default:
+            return false;
+    }
 }
 
-static size_t ggml_backend_sycl_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
-    return dpct::get_current_device().get_max_mem_alloc_size();
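+// top-level mul_mat dispatch: routes to the special f16 single/multi-batch
+// kernels, the dequantize/quantized vector kernels (dmmv/mmvq), quantized
+// mul_mat (mmq), or the generic oneMKL path, based on types, shapes and
+// device capabilities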
+static void ggml_sycl_mul_mat(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const bool split = ggml_backend_buffer_is_sycl_split(src0->buffer);
+    int64_t min_compute_capability = INT_MAX;
+
+    if (split) {
+        ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *) src0->buffer->buft->context;
+        auto & tensor_split = buft_ctx->tensor_split;
+        for (int id = 0; id < ggml_sycl_info().device_count; ++id) {
+            // skip devices that are not going to do any work:
+            if (tensor_split[id] >= (id + 1 < ggml_sycl_info().device_count ? tensor_split[id + 1] : 1.0f)) {
+                continue;
+            }
+
+            if (min_compute_capability > ggml_sycl_info().devices[id].cc) {
+                min_compute_capability = ggml_sycl_info().devices[id].cc;
+            }
+        }
+    } else {
+        min_compute_capability    = ggml_sycl_info().devices[ctx.device].cc;
+    }
+
+    // check data types and tensor shapes for custom matrix multiplication kernels:
+    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
 
-    UNUSED(buft);
-}
+    bool use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
 
-static size_t ggml_backend_sycl_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    size_t size = ggml_nbytes(tensor);
-    int64_t ne0 = tensor->ne[0];
+    bool use_mul_mat_q =  ggml_sycl_supports_mmq(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
 
-    if (ggml_is_quantized(tensor->type)) {
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
-    }
+    // mmvq and mmq need the __dp4a instruction which is available for gen12+
+    // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
+    use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
+#ifdef SYCL_USE_XMX
+    use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
+#endif // SYCL_USE_XMX
 
-    return size;
+    // mmvq path is faster in the CUDA backend.
+    if (ctx.stream()->get_backend() == sycl::backend::ext_oneapi_cuda)
+        use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q;
 
-    UNUSED(buft);
+    if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // KQ single-batch
+        ggml_sycl_mul_mat_vec_p021(ctx, src0, src1, dst);
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // KQV single-batch
+        ggml_sycl_mul_mat_vec_nc(ctx, src0, src1, dst);
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        // KQ + KQV multi-batch
+        ggml_sycl_mul_mat_batched_sycl(ctx, src0, src1, dst);
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+    } else if (use_mul_mat_vec_q) {
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+    } else if (use_mul_mat_q) {
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
+    } else {
+        ggml_sycl_op_mul_mat(ctx, src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
+    }
 }
 
-static ggml_backend_buffer_type_i ggml_backend_sycl_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_sycl_buffer_type_name,
-    /* .alloc_buffer     = */ ggml_backend_sycl_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_sycl_buffer_type_get_alignment,
-    /* .get_max_size     = */ ggml_backend_sycl_buffer_type_get_max_size,
-    /* .get_alloc_size   = */ ggml_backend_sycl_buffer_type_get_alloc_size,
-    /* .is_host          = */ nullptr,
+
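+// records, for each row packed into the contiguous src1 buffer, the original
+// (expert id, token) indices so the results can be scattered back afterwards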
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
 };
 
-ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int device) {
-    static std::mutex mutex;
-    std::lock_guard<std::mutex> lock(mutex);
+__dpct_inline__ static void k_copy_src1_to_contiguous(
+    const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
+    int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
+    const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+    int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
+    const sycl::nd_item<3> &item_ct1, int &src1_row) {
+    int32_t iid1 = item_ct1.get_group(2);
+    int32_t id = item_ct1.get_group(1);
 
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
 
-    if (device>=ggml_sycl_info().device_count or device<0) {
-        printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
-            device, ggml_sycl_info().device_count-1);
-        GGML_ASSERT(device(
+                cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
     }
-    return &ggml_backend_sycl_buffer_types[device];
-}
+    /*
+    DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
+    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
+    performance if there is no access to global memory.
+    */
+    item_ct1.barrier();
 
-ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(ggml_backend_sycl_context * ctx) {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_buffer_type\n");
+    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
 
-    int device = ctx->device;
-    if (device>=ggml_sycl_info().device_count or device<0) {
-        printf("ggml_backend_sycl_buffer_type error: device_index:%d is out of range [0, %d], miss to call ggml_backend_sycl_set_single_device()\n",
-            device, ggml_sycl_info().device_count-1);
-        GGML_ASSERT(device<ggml_sycl_info().device_count);
-    }
-    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_types[GGML_SYCL_MAX_DEVICES];
-
-    static bool ggml_backend_sycl_buffer_type_initialized = false;
+#pragma unroll
+    for (int j = item_ct1.get_local_id(2); j < ne10;
+         j += item_ct1.get_local_range(2)) {
+        src1_row_contiguous[j] = src1_row_original[j];
+    }
+}
+
+__dpct_inline__ static void k_copy_dst_from_contiguous(
+    char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
+    const mmid_row_mapping *__restrict__ row_mapping,
+    int64_t ne0, size_t nb1, size_t nb2,
+    const sycl::nd_item<3> &item_ct1) {
+    int32_t i = item_ct1.get_group(2);
 
-    if (!ggml_backend_sycl_buffer_type_initialized) {
-        for (int i = 0; i < ggml_sycl_info().device_count; i++) {
-            ggml_backend_sycl_buffer_types[i] = {
-                /* .iface    = */ ggml_backend_sycl_buffer_type_interface,
-                /* .device   = */ nullptr,
-                /* .context  = */ new ggml_backend_sycl_buffer_type_context{i, GGML_SYCL_NAME + std::to_string(i), ctx->stream(i, 0)},
-            };
-        }
-        ggml_backend_sycl_buffer_type_initialized = true;
-    }
-    return &ggml_backend_sycl_buffer_types[device];
-}
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
 
-// sycl split buffer type
-static void get_row_split(int64_t * row_low, int64_t * row_high, const ggml_tensor * tensor, const std::array<float, GGML_SYCL_MAX_DEVICES> & tensor_split, int id) {
-    const int64_t nrows = ggml_nrows(tensor);
-    const int64_t rounding = get_row_rounding(tensor->type, tensor_split);
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
 
-    *row_low = id == 0 ? 0 : nrows*tensor_split[id];
-    *row_low -= *row_low % rounding;
-    if (id == ggml_sycl_info().device_count - 1) {
-        *row_high = nrows;
-    } else {
-        *row_high = nrows*tensor_split[id + 1];
-        *row_high -= *row_high % rounding;
+#pragma unroll
+    for (int j = item_ct1.get_local_id(2); j < ne0;
+         j += item_ct1.get_local_range(2)) {
+        dst_row_original[j] = dst_row_contiguous[j];
     }
 }
 
-struct ggml_backend_sycl_split_buffer_context {
-    ~ggml_backend_sycl_split_buffer_context() try {
-        for (ggml_tensor_extra_gpu * extra : tensor_extras) {
-            for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-                for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
-                    if (extra->events[i][is] != nullptr) {
-                        /*
-                        DPCT1009:206: SYCL uses exceptions to report errors and
-                        does not use the error codes. The original code was
-                        commented out and a warning string was inserted. You
-                        need to rewrite this code.
-                        */
-                        SYCL_CHECK(CHECK_TRY_ERROR(
-                            dpct::destroy_event(extra->events[i][is])));
-                    }
-                }
-                if (extra->data_device[i] != nullptr) {
-                    /*
-                    DPCT1009:207: SYCL uses exceptions to report errors and does
-                    not use the error codes. The original code was commented out
-                    and a warning string was inserted. You need to rewrite this
-                    code.
-                    */
-                    ggml_sycl_set_device(i);
-                    SYCL_CHECK(CHECK_TRY_ERROR(sycl::free(
-                        extra->data_device[i], *(streams[i]))));
-                }
-            }
-            delete extra;
-        }
-    }
-    catch (sycl::exception const &exc) {
-      std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-                << ", line:" << __LINE__ << std::endl;
-      std::exit(1);
-    }
+static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_tensor *src0,
+                                 const ggml_tensor *src1,
+                                 ggml_tensor *dst) try {
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
 
-    std::vector<ggml_tensor_extra_gpu *> tensor_extras;
-    std::vector<queue_ptr> streams;
-};
+    const ggml_tensor *ids = dst->src[2];
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backend_buffer_t buffer) {
-    return GGML_SYCL_NAME "_Split";
+    const queue_ptr stream = ctx.stream();
 
-    UNUSED(buffer);
-}
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];
 
-static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
-   return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
-}
+    std::vector<char> ids_host(ggml_nbytes(ids));
+    const char * ids_dev = (const char *) ids->data;
 
-static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
-    delete ctx;
-}
+    SYCL_CHECK(CHECK_TRY_ERROR(
+        stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
+    SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
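+    // the expert-selection ids are now resident on the host; they drive the
+    // row gather/scatter below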
 
-static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) {
-    // the pointers are stored in the tensor extras, this is just a dummy address and never dereferenced
-    return (void *)0x1000;
+    ggml_tensor src0_row = *src0;
+    ggml_tensor src1_row = *src1;
+    ggml_tensor dst_row = *dst;
 
-    UNUSED(buffer);
-}
+    char *src0_original = (char *)src0->data;
+    char *src1_original = (char *)src1->data;
+    char *dst_original = (char *)dst->data;
+
+    src0_row.ne[2] = 1;
+    src0_row.ne[3] = 1;
+    src0_row.nb[3] = nb02;
+
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
+
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
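+    // fast path for ne12 == 1: each selected row pair is multiplied in place;
+    // otherwise rows are first gathered into contiguous buffers per expert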
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
 
-static void
-ggml_backend_sycl_split_buffer_init_tensor(ggml_backend_buffer_t buffer,
-                                           ggml_tensor *tensor) try {
-    GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
+            src0_row.data = src0_original + i02*nb02;
+            src1_row.data = src1_original + i11*nb11 + i12*nb12;
+            dst_row.data = dst_original + i1*nb1   + i2*nb2;
 
-    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
-    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
+            ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+            }
+        }
+    } else {
+        ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
+        ggml_sycl_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
 
-    const int64_t ne0 = tensor->ne[0];
+        src1_row.data = src1_contiguous.get();
+        dst_row.data  =  dst_contiguous.get();
 
-    ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu{};
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
+            int64_t num_src1_rows = 0;
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-    ctx->tensor_extras.push_back(extra);
-        ctx->streams.push_back(&(dpct::get_current_device().default_queue()));
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
 
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
+                    if (row_id_i != i02) {
+                        continue;
+                    }
 
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
+                    num_src1_rows++;
+                }
+            }
 
-        size_t size = ggml_nbytes_split(tensor, nrows_split);
-        const size_t original_size = size;
+            if (num_src1_rows == 0) {
+                continue;
+            }
 
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
 
-        // FIXME: do not crash if cudaMalloc fails
-        // currently, init_tensor cannot fail, it needs to be fixed in ggml-backend first
-        ggml_sycl_set_device(i);
-        const queue_ptr stream = ctx->streams[i];
-        char * buf;
-        /*
-        DPCT1009:208: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        SYCL_CHECK(CHECK_TRY_ERROR(buf = (char *)sycl::malloc_device(
-                                        size, *stream)));
-        if (!buf) {
-            char err_buf[1024];
-            snprintf(err_buf, 1023, "%s: can't malloc %lu Bytes memory on device", __func__, size);
-            throw std::runtime_error(err_buf);
-        }
-        // set padding to 0 to avoid possible NaN values
-        if (size > original_size) {
-            /*
-            DPCT1009:209: SYCL uses exceptions to report errors and does not use
-            the error codes. The original code was commented out and a warning
-            string was inserted. You need to rewrite this code.
-            */
+            ggml_sycl_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+            ggml_sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
             SYCL_CHECK(CHECK_TRY_ERROR(
-                (*stream)
-                    .memset(buf + original_size, 0, size - original_size)
-                    .wait()));
-        }
+                stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
 
-        extra->data_device[i] = buf;
+            {
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
+                sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
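+                // one work-group per (expert id, token) pair; its work-items pack one
+                // matching src1 row and record the row mapping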
+                stream->submit([&](sycl::handler &cgh) {
+                    sycl::local_accessor<int, 0> src1_row_acc(cgh);
 
-        for (int64_t is = 0; is < GGML_SYCL_MAX_STREAMS; ++is) {
-            /*
-            DPCT1009:210: SYCL uses exceptions to report errors and does not use
-            the error codes. The original code was commented out and a warning
-            string was inserted. You need to rewrite this code.
-            */
-            SYCL_CHECK(
-                CHECK_TRY_ERROR(extra->events[i][is] = new sycl::event()));
-        }
-    }
-    tensor->backend = GGML_BACKEND_TYPE_GPU_SPLIT;
-    tensor->extra = extra;
-}
-catch (sycl::exception const &exc) {
-  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
-            << ", line:" << __LINE__ << std::endl;
-  std::exit(1);
-}
+                    char *__restrict src1_contiguous_get =
+                        src1_contiguous.get();
+                    int *__restrict dev_cur_src1_row_get =
+                        dev_cur_src1_row.get();
+                    mmid_row_mapping *__restrict dev_row_mapping_get =
+                        dev_row_mapping.get();
+                    size_t ids_nb_ct6 = ids->nb[1];
+                    size_t ids_nb_ct7 = ids->nb[0];
 
-static void
-ggml_backend_sycl_split_buffer_set_tensor(ggml_backend_buffer_t buffer,
-                                          ggml_tensor *tensor, const void *data,
-                                          size_t offset, size_t size) try {
-    // split tensors must always be set in their entirety at once
-    GGML_ASSERT(offset == 0);
-    GGML_ASSERT(size == ggml_nbytes(tensor));
+                    cgh.parallel_for(
+                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_copy_src1_to_contiguous(
+                                src1_original, src1_contiguous_get,
+                                dev_cur_src1_row_get,
+                                dev_row_mapping_get, ids_dev, i02,
+                                ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
+                                item_ct1, src1_row_acc);
+                        });
+                });
+            }
 
-    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
-    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
+            src0_row.data = src0_original + i02*nb02;
 
-    const int64_t ne0 = tensor->ne[0];
-    const size_t nb1 = tensor->nb[1];
-    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
+            src1_row.ne[1] = num_src1_rows;
 
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
+            src1_row.nb[1] = nb11;
+            src1_row.nb[2] = num_src1_rows*nb11;
+            src1_row.nb[3] = num_src1_rows*nb11;
 
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
+            dst_row.ne[1] = num_src1_rows;
+            dst_row.nb[1] = nb1;
+            dst_row.nb[2] = num_src1_rows*nb1;
+            dst_row.nb[3] = num_src1_rows*nb1;
 
-        const size_t offset_split = row_low*nb1;
-        size_t size = ggml_nbytes_split(tensor, nrows_split);
-        const size_t original_size = size;
+            ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
 
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
+            {
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
+                sycl::range<3> grid_dims(1, 1, num_src1_rows);
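+                // scatter: one work-group per packed row copies it back to its
+                // original position in dst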
+                stream->submit([&](sycl::handler &cgh) {
+                    const char *__restrict dst_contiguous_get =
+                        dst_contiguous.get();
+                    const mmid_row_mapping *__restrict dev_row_mapping_get =
+                        dev_row_mapping.get();
 
-        const char * buf_host = (const char *)data + offset_split;
-        /*
-        DPCT1009:211: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        ggml_sycl_set_device(i);
-        const queue_ptr stream = ctx->streams[i];
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            (*stream)
-                .memcpy(extra->data_device[i], buf_host, original_size)
-                .wait()));
+                    cgh.parallel_for(
+                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_copy_dst_from_contiguous(dst_original,
+                                                       dst_contiguous_get,
+                                                       dev_row_mapping_get,
+                                                       ne0, nb1, nb2, item_ct1);
+                        });
+                });
+            }
+        }
     }
 }
 catch (sycl::exception const &exc) {
@@ -4673,52 +4567,55 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static void
-ggml_backend_sycl_split_buffer_get_tensor(ggml_backend_buffer_t buffer,
-                                          const ggml_tensor *tensor, void *data,
-                                          size_t offset, size_t size) try {
-    // split tensors must always be set in their entirety at once
-    GGML_ASSERT(offset == 0);
-    GGML_ASSERT(size == ggml_nbytes(tensor));
+static void ggml_sycl_scale(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_scale);
+}
 
-    ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;
-    ggml_backend_sycl_split_buffer_type_context * buft_ctx = (ggml_backend_sycl_split_buffer_type_context *)buffer->buft->context;
+static void ggml_sycl_clamp(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_clamp);
+}
 
-    const int64_t ne0 = tensor->ne[0];
-    const size_t nb1 = tensor->nb[1];
-    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *)tensor->extra;
+static void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor *src0, const ggml_tensor *src1,
+                          ggml_tensor *dst) try {
+    const int64_t ne = ggml_nelements(src0);
+    GGML_ASSERT(ne == ggml_nelements(src1));
 
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, buft_ctx->tensor_split, i);
+    GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
+    GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
 
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
+    GGML_TENSOR_BINARY_OP_LOCALS01;
 
-        const size_t offset_split = row_low*nb1;
-        size_t size = ggml_nbytes_split(tensor, nrows_split);
-        const size_t original_size = size;
+    SYCL_CHECK(ggml_sycl_set_device(ctx.device));
+    queue_ptr main_stream = ctx.stream();
 
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
+    char * src0_ddc = (char *) src0->data;
+    char * src1_ddc = (char *) src1->data;
 
-        char * buf_host = (char *)data + offset_split;
-        /*
-        DPCT1009:212: SYCL uses exceptions to report errors and does not use the
-        error codes. The original code was commented out and a warning string
-        was inserted. You need to rewrite this code.
-        */
-        ggml_sycl_set_device(i);
-        const queue_ptr stream = ctx->streams[i];
-        SYCL_CHECK(CHECK_TRY_ERROR(
-            (*stream)
-                .memcpy(buf_host, extra->data_device[i], original_size)
-                .wait()));
+    if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f32_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f32_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
+        ggml_cpy_f32_q8_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
+        ggml_cpy_f32_q4_0_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
+        ggml_cpy_f32_q4_1_sycl(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
+        ggml_cpy_f16_f32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
+        ggml_cpy_f16_f16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16) {
+        ggml_cpy_i16_i16_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+        ggml_cpy_i32_i32_sycl (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
+    } else {
+        fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
+                ggml_type_name(src0->type), ggml_type_name(src1->type));
+        GGML_ABORT("fatal error");
     }
+
+    (void) dst;
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -4726,183 +4623,259 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
-static void ggml_backend_sycl_split_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
-    UNUSED(buffer);
-    UNUSED(value);
+static void ggml_sycl_dup(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    // TODO: why do we pass dst as src1 here?
+    ggml_sycl_cpy(ctx, src0, dst, nullptr);
+    (void) src1;
 }
 
-static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = {
-    /* .get_name        = */ ggml_backend_sycl_split_buffer_get_name,
-    /* .free_buffer     = */ ggml_backend_sycl_split_buffer_free_buffer,
-    /* .get_base        = */ ggml_backend_sycl_split_buffer_get_base,
-    /* .init_tensor     = */ ggml_backend_sycl_split_buffer_init_tensor,
-    /* .memset_tensor   = */ NULL,
-    /* .set_tensor      = */ ggml_backend_sycl_split_buffer_set_tensor,
-    /* .get_tensor      = */ ggml_backend_sycl_split_buffer_get_tensor,
-    /* .cpy_tensor      = */ NULL,
-    /* .clear           = */ ggml_backend_sycl_split_buffer_clear,
-    /* .reset           = */ NULL,
-};
-
-static const char * ggml_backend_sycl_split_buffer_type_name(ggml_backend_buffer_type_t buft) {
-    return GGML_SYCL_NAME "_Split";
-
-    UNUSED(buft);
+static void ggml_sycl_diag_mask_inf(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_diag_mask_inf);
 }
 
-static ggml_backend_buffer_t ggml_backend_sycl_split_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    // since we don't know the exact split after rounding, we cannot allocate the device buffers at this point
-    // instead, we allocate them for each tensor separately in init_tensor
-    // however, the size still represents the maximum cumulative size of all the device buffers after the tensors are allocated,
-    // as returned by get_alloc_size. this limit is enforced during tensor allocation by ggml-alloc, so it must be correct.
-    ggml_backend_sycl_split_buffer_context * ctx = new ggml_backend_sycl_split_buffer_context();
-
-    return ggml_backend_buffer_init(buft, ggml_backend_sycl_split_buffer_interface, ctx, size);
+static void ggml_sycl_soft_max(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_soft_max);
 }
 
-static size_t ggml_backend_sycl_split_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
-    return 128;
-    UNUSED(buft);
+static void ggml_sycl_rope(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_rope);
 }
 
-static size_t ggml_backend_sycl_split_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
-    ggml_backend_sycl_split_buffer_type_context * ctx = (ggml_backend_sycl_split_buffer_type_context *)buft->context;
+static void ggml_sycl_pool2d(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_pool2d);
+}
 
-    size_t total_size = 0;
+static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
+}
 
-    const int64_t ne0 = tensor->ne[0];
+static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
+}
 
-    for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-        int64_t row_low, row_high;
-        get_row_split(&row_low, &row_high, tensor, ctx->tensor_split, i);
+static void ggml_sycl_argsort(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    GGML_ASSERT(ggml_is_contiguous(src0));
+    ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_argsort);
+}
 
-        int64_t nrows_split = row_high - row_low;
-        if (nrows_split == 0) {
-            continue;
-        }
+static void ggml_sycl_nop(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    (void) src0;
+    (void) src1;
+    (void) dst;
+}
 
-        total_size += ggml_nbytes_split(tensor, nrows_split);
+void ggml_sycl_set_main_device(const int main_device) try {
+    if (dpct::get_current_device_id() == main_device) return;
+    check_allow_gpu_index(main_device);
+    dpct::select_device(main_device);
 
-        // pad last row to a multiple of 512 elements to avoid out-of-bounds memory accesses
-        if (ne0 % MATRIX_ROW_PADDING != 0) {
-            total_size += ggml_row_size(tensor->type, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING);
-        }
+    if (g_ggml_sycl_debug) {
+        dpct::device_info prop;
+        SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
+            prop, dpct::dev_mgr::instance().get_device(main_device))));
+        fprintf(stderr, "Using device %d (%s) as main device\n",
+                main_device, prop.get_name());
     }
-
-    return total_size;
 }
-
-static bool ggml_backend_sycl_split_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
-    return false;
-
-    UNUSED(buft);
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
 }
 
-static ggml_backend_buffer_type_i ggml_backend_sycl_split_buffer_type_interface = {
-    /* .get_name         = */ ggml_backend_sycl_split_buffer_type_name,
-    /* .alloc_buffer     = */ ggml_backend_sycl_split_buffer_type_alloc_buffer,
-    /* .get_alignment    = */ ggml_backend_sycl_split_buffer_type_get_alignment,
-    /* .get_max_size     = */ NULL, // defaults to SIZE_MAX
-    /* .get_alloc_size   = */ ggml_backend_sycl_split_buffer_type_get_alloc_size,
-    /* .is_host          = */ ggml_backend_sycl_split_buffer_type_is_host,
-};
-
-ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split) {
-    static std::mutex mutex;
-    std::lock_guard lock(mutex);
-
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_split_buffer_type\n");
-    ggml_check_sycl();
-    // FIXME: this is not thread safe
-    static std::map<std::array<float, GGML_SYCL_MAX_DEVICES>, struct ggml_backend_buffer_type> buft_map;
+bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tensor * tensor) {
+    if (!g_sycl_loaded) return false;
 
-    std::array<float, GGML_SYCL_MAX_DEVICES> tensor_split_arr = {};
+    ggml_sycl_func_t func;
 
-    bool all_zero = tensor_split == nullptr || std::all_of(tensor_split, tensor_split + GGML_SYCL_MAX_DEVICES, [](float x) { return x == 0.0f; });
-    if (all_zero) {
-        tensor_split_arr = ggml_sycl_info().default_tensor_split;
-    } else {
-        float split_sum = 0.0f;
-        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-            tensor_split_arr[i] = split_sum;
-            split_sum += tensor_split[i];
-        }
-        for (int i = 0; i < ggml_sycl_info().device_count; ++i) {
-            tensor_split_arr[i] /= split_sum;
-        }
+    switch (tensor->op) {
+        case GGML_OP_CONV_TRANSPOSE_1D:
+            func = ggml_sycl_op_conv_transpose_1d;
+            break;
+        case GGML_OP_REPEAT:
+            func = ggml_sycl_repeat;
+            break;
+        case GGML_OP_GET_ROWS:
+            func = ggml_sycl_get_rows;
+            break;
+        case GGML_OP_DUP:
+            func = ggml_sycl_dup;
+            break;
+        case GGML_OP_ADD:
+            func = ggml_sycl_add;
+            break;
+        case GGML_OP_ACC:
+            func = ggml_sycl_acc;
+            break;
+        case GGML_OP_MUL:
+            func = ggml_sycl_mul;
+            break;
+        case GGML_OP_DIV:
+            func = ggml_sycl_div;
+            break;
+        case GGML_OP_UNARY:
+            switch (ggml_get_unary_op(tensor)) {
+                case GGML_UNARY_OP_GELU:
+                    func = ggml_sycl_gelu;
+                    break;
+                case GGML_UNARY_OP_SILU:
+                    func = ggml_sycl_silu;
+                    break;
+                case GGML_UNARY_OP_GELU_QUICK:
+                    func = ggml_sycl_gelu_quick;
+                    break;
+                case GGML_UNARY_OP_TANH:
+                    func = ggml_sycl_tanh;
+                    break;
+                case GGML_UNARY_OP_RELU:
+                    func = ggml_sycl_relu;
+                    break;
+                case GGML_UNARY_OP_HARDSIGMOID:
+                    func = ggml_sycl_hardsigmoid;
+                    break;
+                case GGML_UNARY_OP_HARDSWISH:
+                    func = ggml_sycl_hardswish;
+                    break;
+                default:
+                    return false;
+            }
+            break;
+        case GGML_OP_NORM:
+            func = ggml_sycl_norm;
+            break;
+        case GGML_OP_GROUP_NORM:
+            func = ggml_sycl_group_norm;
+            break;
+        case GGML_OP_CONCAT:
+            func = ggml_sycl_op_concat;
+            break;
+        case GGML_OP_UPSCALE:
+            func = ggml_sycl_upscale;
+            break;
+        case GGML_OP_PAD:
+            func = ggml_sycl_pad;
+            break;
+        case GGML_OP_LEAKY_RELU:
+            func = ggml_sycl_leaky_relu;
+            break;
+        case GGML_OP_RMS_NORM:
+            func = ggml_sycl_rms_norm;
+            break;
+        case GGML_OP_MUL_MAT:
+            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+                return false;
+            }
+            func = ggml_sycl_mul_mat;
+            break;
+        case GGML_OP_MUL_MAT_ID:
+            if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
+                return false;
+            }
+            func = ggml_sycl_mul_mat_id;
+            break;
+        case GGML_OP_SCALE:
+            func = ggml_sycl_scale;
+            break;
+        case GGML_OP_SQR:
+            func = ggml_sycl_sqr;
+            break;
+        case GGML_OP_CLAMP:
+            func = ggml_sycl_clamp;
+            break;
+        case GGML_OP_CPY:
+            func = ggml_sycl_cpy;
+            break;
+        case GGML_OP_CONT:
+            func = ggml_sycl_dup;
+            break;
+        case GGML_OP_NONE:
+        case GGML_OP_RESHAPE:
+        case GGML_OP_VIEW:
+        case GGML_OP_PERMUTE:
+        case GGML_OP_TRANSPOSE:
+            func = ggml_sycl_nop;
+            break;
+        case GGML_OP_DIAG_MASK_INF:
+            func = ggml_sycl_diag_mask_inf;
+            break;
+        case GGML_OP_SOFT_MAX:
+            func = ggml_sycl_soft_max;
+            break;
+        case GGML_OP_ROPE:
+            func = ggml_sycl_rope;
+            break;
+        case GGML_OP_IM2COL:
+            func = ggml_sycl_im2col;
+            break;
+        case GGML_OP_POOL_2D:
+            func = ggml_sycl_pool2d;
+            break;
+        case GGML_OP_SUM_ROWS:
+            func = ggml_sycl_sum_rows;
+            break;
+        case GGML_OP_ARGSORT:
+            func = ggml_sycl_argsort;
+            break;
+        case GGML_OP_TIMESTEP_EMBEDDING:
+            func = ggml_sycl_op_timestep_embedding;
+            break;
+        default:
+            return false;
     }
 
-    auto it = buft_map.find(tensor_split_arr);
-    if (it != buft_map.end()) {
-        return &it->second;
+    if (tensor->src[0] != nullptr && ggml_backend_buffer_is_sycl_split(tensor->src[0]->buffer)) {
+        ggml_sycl_set_peer_access(tensor->src[1]->ne[1], ctx.device);
     }
 
-    struct ggml_backend_buffer_type buft {
-        /* .iface   = */ ggml_backend_sycl_split_buffer_type_interface,
-        /* .device  = */ nullptr,
-        /* .context = */ new ggml_backend_sycl_split_buffer_type_context{tensor_split_arr},
-    };
-
-    auto result = buft_map.emplace(tensor_split_arr, buft);
-    return &result.first->second;
-}
-
-// host buffer type
-
-static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_type_t buft) {
-    return GGML_SYCL_NAME "_Host";
-
-    UNUSED(buft);
+    func(ctx, tensor->src[0], tensor->src[1], tensor);
+    return true;
 }
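// The switch above is the SYCL op-dispatch table: it resolves a ggml op to
// one kernel wrapper, returns false for unsupported ops/shapes, and invokes
// the chosen wrapper through the single `func` pointer. A hypothetical new
// op would be wired in the same shape (GGML_OP_FOO and ggml_sycl_foo are
// placeholders, not real symbols):
//
//     case GGML_OP_FOO:
//         func = ggml_sycl_foo;
//         break;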
 
-static const char * ggml_backend_sycl_host_buffer_name(ggml_backend_buffer_t buffer) {
-    return GGML_SYCL_NAME "_Host";
-
-    UNUSED(buffer);
+GGML_API void ggml_backend_sycl_get_device_description(int device, char *description,
+                                      size_t description_size) try {
+    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_description\n");
+    dpct::device_info prop;
+    SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
+        prop, dpct::dev_mgr::instance().get_device(device))));
+    snprintf(description, description_size, "%s", prop.get_name());
 }
-
-static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) {
-    ggml_sycl_host_free(buffer->context);
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
 }
 
-static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
-    void * ptr = ggml_sycl_host_malloc(size);
-
-    if (ptr == nullptr) {
-        // fallback to cpu buffer
-        return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size);
-    }
-
-    // FIXME: this is a hack to avoid having to implement a new buffer type
-    ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size);
-    buffer->buft = buft;
-    buffer->iface.get_name = ggml_backend_sycl_host_buffer_name;
-    buffer->iface.free_buffer = ggml_backend_sycl_host_buffer_free_buffer;
+void ggml_backend_sycl_get_device_memory(int device, size_t *free,
+                                                   size_t *total) try {
+    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_memory\n");
+    ggml_sycl_set_device(device);
 
-    return buffer;
+    /*
+    DPCT1009:218: SYCL uses exceptions to report errors and does not use the
+    error codes. The original code was commented out and a warning string was
+    inserted. You need to rewrite this code.
+    */
+    /*
+    DPCT1106:217: 'cudaMemGetInfo' was migrated with the Intel extensions for
+    device information which may not be supported by all compilers or runtimes.
+    You may need to adjust the code.
+    */
+    SYCL_CHECK(CHECK_TRY_ERROR(
+        dpct::dev_mgr::instance().get_device(device).get_memory_info(*free, *total)));
 }
-
-ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type() {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_host_buffer_type\n");
-    static struct ggml_backend_buffer_type ggml_backend_sycl_buffer_type_host = {
-        /* .iface    = */ {
-            /* .get_name         = */ ggml_backend_sycl_host_buffer_type_name,
-            /* .alloc_buffer     = */ ggml_backend_sycl_host_buffer_type_alloc_buffer,
-            /* .get_alignment    = */ ggml_backend_cpu_buffer_type()->iface.get_alignment,
-            /* .get_max_size     = */ NULL, // TODO: return device.maxBufferLength
-            /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
-            /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
-        },
-        /* .device   = */ nullptr,
-        /* .context  = */ nullptr,
-    };
-
-    return &ggml_backend_sycl_buffer_type_host;
+catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
 }
 
+////////////////////////////////////////////////////////////////////////////////
+
 // backend
 
-static const char * ggml_backend_sycl_name(ggml_backend_t backend) {
+static const char * ggml_backend_sycl_get_name(ggml_backend_t backend) {
 
     ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
 
@@ -4931,8 +4904,8 @@ static void ggml_backend_sycl_set_tensor_async(ggml_backend_t backend,
 
     GGML_ASSERT(buf->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device) && "unsupported buffer type");
     const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
-    SYCL_CHECK(CHECK_TRY_ERROR((stream)->memcpy(
-        (char *)tensor->data + offset, data, size).wait()));
+    SYCL_CHECK(CHECK_TRY_ERROR(
+        (stream)->memcpy((char *)tensor->data + offset, data, size)));
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -4987,7 +4960,7 @@ static void ggml_backend_sycl_synchronize(ggml_backend_t backend) try {
     const queue_ptr stream = sycl_ctx->stream(sycl_ctx->device, 0);
     SYCL_CHECK(CHECK_TRY_ERROR((stream)->wait()));
 
-    UNUSED(backend);
+    GGML_UNUSED(backend);
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -5023,7 +4996,151 @@ static ggml_status ggml_backend_sycl_graph_compute(ggml_backend_t backend, ggml_
     return GGML_STATUS_SUCCESS;
 }
 
-static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
+static void ggml_backend_sycl_event_record(ggml_backend_t backend, ggml_backend_event_t event)
+try
+{
+    ggml_backend_sycl_context *sycl_ctx =
+        (ggml_backend_sycl_context *)backend->context;
+    sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
+
+    const queue_ptr &stream = sycl_ctx->stream(sycl_ctx->device, 0);
+    // Record the current state of the queue
+    SYCL_CHECK(CHECK_TRY_ERROR(*sycl_event = stream->ext_oneapi_submit_barrier()));
+}
+catch (sycl::exception const &exc)
+{
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+              << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
+}
+
+static void ggml_backend_sycl_event_wait(ggml_backend_t backend, ggml_backend_event_t event) try {
+    ggml_backend_sycl_context* sycl_ctx = static_cast<ggml_backend_sycl_context*>(backend->context);
+    sycl::event* sycl_event = static_cast<sycl::event*>(event->context);
+
+    if (ggml_backend_is_sycl(backend)) {
+        SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
+    } else
+        GGML_ABORT("fatal error");
+} catch (sycl::exception const& exc) {
+    std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+              << ", line:" << __LINE__ << std::endl;
+    std::exit(1);
+}
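// The two hooks above pair up across backends: event_record submits a
// barrier on the SYCL queue and stores it in the event, event_wait blocks
// on that stored barrier. A usage sketch, assuming the post-registry
// ggml-backend.h wrappers keep these signatures:
//
//     ggml_backend_event_t ev = ggml_backend_event_new(dev); // dev = SYCL device
//     ggml_backend_event_record(ev, backend_a);  // barrier on backend_a's queue
//     ggml_backend_event_wait(backend_b, ev);    // backend_b waits for it
//     ggml_backend_event_free(ev);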
+
+static ggml_backend_i ggml_backend_sycl_interface = {
+    /* .get_name                = */ ggml_backend_sycl_get_name,
+    /* .free                    = */ ggml_backend_sycl_free,
+    /* .get_default_buffer_type = */ ggml_backend_sycl_get_default_buffer_type,
+    /* .set_tensor_async        = */ ggml_backend_sycl_set_tensor_async,
+    /* .get_tensor_async        = */ ggml_backend_sycl_get_tensor_async,
+    /* .cpy_tensor_async        = */ NULL, // ggml_backend_sycl_cpy_tensor_async,
+                                           // // TODO: update for the new
+                                           // interface
+    /* .synchronize             = */ ggml_backend_sycl_synchronize,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_sycl_graph_compute,
+    /* .supports_op             = */ NULL, // moved to device
+    /* .supports_buft           = */ NULL, // moved to device
+    /* .offload_op              = */ NULL, // moved to device
+    /* .event_record            = */ ggml_backend_sycl_event_record,
+    /* .event_wait              = */ ggml_backend_sycl_event_wait,
+};
+
+static ggml_guid_t ggml_backend_sycl_guid() {
+    static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 };
+    return &guid;
+}
+
+bool ggml_backend_is_sycl(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid());
+}
+
+int ggml_backend_sycl_get_device_count() {
+    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
+    return ggml_sycl_info().device_count;
+}
+
+
+// backend device
+
+struct ggml_backend_sycl_device_context {
+    int device;
+    std::string name;
+    std::string description;
+};
+
+static const char * ggml_backend_sycl_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_sycl_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_sycl_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) {
+    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
+    ggml_sycl_set_device(ctx->device);
+    SYCL_CHECK(CHECK_TRY_ERROR(
+    dpct::dev_mgr::instance().get_device(ctx->device).get_memory_info(*free, *total)));
+}
+
+static enum ggml_backend_dev_type ggml_backend_sycl_device_get_type(ggml_backend_dev_t dev) {
+    GGML_UNUSED(dev);
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+}
+
+static void ggml_backend_sycl_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_sycl_device_get_name(dev);
+    props->description = ggml_backend_sycl_device_get_description(dev);
+    props->type        = ggml_backend_sycl_device_get_type(dev);
+    ggml_backend_sycl_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+    bool host_buffer = getenv("GGML_SYCL_NO_PINNED") == nullptr;
+#ifdef GGML_SYCL_NO_PEER_COPY
+    bool events = false;
+#else
+    bool events = true;
+#endif
+
+    props->caps = {
+        /* .async                 = */ true,
+        /* .host_buffer           = */ host_buffer,
+        /* .buffer_from_host_ptr  = */ false,
+        /* .events                = */ events,
+    };
+}
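// Client-side view of these props, assuming the public
// ggml_backend_dev_get_props() wrapper from ggml-backend.h at this revision:
//
//     ggml_backend_dev_props props;
//     ggml_backend_dev_get_props(dev, &props);
//     printf("%s (%s): %zu/%zu bytes free, async=%d\n",
//            props.name, props.description,
//            props.memory_free, props.memory_total, props.caps.async);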
+
+static ggml_backend_t ggml_backend_sycl_device_init(ggml_backend_dev_t dev, const char * params) {
+    GGML_UNUSED(params);
+    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
+    return ggml_backend_sycl_init(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_sycl_device_context * ctx = (ggml_backend_sycl_device_context *)dev->context;
+    return ggml_backend_sycl_buffer_type(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_sycl_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+    GGML_UNUSED(dev);
+    return ggml_backend_sycl_host_buffer_type();
+}
+
+static ggml_backend_buffer_t ggml_backend_sycl_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
+    GGML_UNUSED(dev);
+    GGML_UNUSED(ptr);
+    GGML_UNUSED(size);
+    GGML_UNUSED(max_tensor_size);
+    return nullptr;
+}
+
+static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_CONV_TRANSPOSE_1D:
             {
@@ -5167,47 +5284,173 @@ static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, const ggml_ten
             return false;
     }
 
-    UNUSED(backend);
+    GGML_UNUSED(dev);
 }
 
-static bool ggml_backend_sycl_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
+static bool ggml_backend_sycl_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
+    if (buft->iface.get_name != ggml_backend_sycl_buffer_type_get_name) {
+        return false;
+    }
+    ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
+    ggml_backend_sycl_device_context * sycl_ctx = (ggml_backend_sycl_device_context *)dev->context;
+    return buft_ctx->device == sycl_ctx->device;
+}
+
+static bool ggml_backend_sycl_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     const int min_batch_size = 32;
     return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS && op->op != GGML_OP_MUL_MAT_ID;
-    GGML_UNUSED(backend);
+    GGML_UNUSED(dev);
 }
 
-static bool ggml_backend_sycl_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
-    if (buft->iface.get_name != ggml_backend_sycl_buffer_type_name) {
-        return false;
+static ggml_backend_event_t
+ggml_backend_sycl_device_event_new(ggml_backend_dev_t dev) {
+
+#ifdef GGML_SYCL_NO_PEER_COPY
+    return nullptr;
+#else
+  sycl::event *event_ptr = new sycl::event();
+
+  return new ggml_backend_event{
+      /* .device = */ dev,
+      /* .context = */ event_ptr,
+  };
+#endif
+}
+
+static void ggml_backend_sycl_device_event_free(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
+  GGML_UNUSED(dev);
+  if (event == nullptr) {
+    return;
+  }
+
+  if (event->context != nullptr) {
+    sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
+    delete sycl_event;
+    event->context = nullptr;
+  }
+
+  delete event;
+} catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+
+static void ggml_backend_sycl_device_event_synchronize(ggml_backend_dev_t dev, ggml_backend_event_t event) try {
+  GGML_UNUSED(dev);
+
+  sycl::event *sycl_event = static_cast<sycl::event *>(event->context);
+  SYCL_CHECK(CHECK_TRY_ERROR(sycl_event->wait()));
+} catch (sycl::exception const &exc) {
+  std::cerr << exc.what() << "Exception caught at file:" << __FILE__
+            << ", line:" << __LINE__ << std::endl;
+  std::exit(1);
+}
+
+static const ggml_backend_device_i ggml_backend_sycl_device_interface = {
+    /* .get_name                = */ ggml_backend_sycl_device_get_name,
+    /* .get_description         = */ ggml_backend_sycl_device_get_description,
+    /* .get_memory              = */ ggml_backend_sycl_device_get_memory,
+    /* .get_type                = */ ggml_backend_sycl_device_get_type,
+    /* .get_props               = */ ggml_backend_sycl_device_get_props,
+    /* .init_backend            = */ ggml_backend_sycl_device_init,
+    /* .get_buffer_type         = */ ggml_backend_sycl_device_get_buffer_type,
+    /* .get_host_buffer_type    = */ ggml_backend_sycl_device_get_host_buffer_type,
+    /* .buffer_from_host_ptr    = */ ggml_backend_sycl_device_buffer_from_host_ptr,
+    /* .supports_op             = */ ggml_backend_sycl_device_supports_op,
+    /* .supports_buft           = */ ggml_backend_sycl_device_supports_buft,
+    /* .offload_op              = */ ggml_backend_sycl_device_offload_op,
+    /* .event_new               = */ ggml_backend_sycl_device_event_new,
+    /* .event_free              = */ ggml_backend_sycl_device_event_free,
+    /* .event_synchronize       = */ ggml_backend_sycl_device_event_synchronize,
+};
+
+// backend reg
+
+struct ggml_backend_sycl_reg_context {
+    std::vector<ggml_backend_dev_t> devices;
+};
+
+static const char * ggml_backend_sycl_reg_get_name(ggml_backend_reg_t reg) {
+    GGML_UNUSED(reg);
+    return GGML_SYCL_NAME;
+}
+
+static size_t ggml_backend_sycl_reg_get_device_count(ggml_backend_reg_t reg) {
+    ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
+    return ctx->devices.size();
+}
+
+static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t reg, size_t index) {
+    ggml_backend_sycl_reg_context * ctx = (ggml_backend_sycl_reg_context *)reg->context;
+    GGML_ASSERT(index < ctx->devices.size());
+    return ctx->devices[index];
+}
+
+static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name)
+{
+    GGML_UNUSED(reg);
+    if (strcmp(name, "ggml_backend_split_buffer_type") == 0) {
+        return (void *)ggml_backend_sycl_split_buffer_type;
     }
-    ggml_backend_sycl_buffer_type_context * buft_ctx = (ggml_backend_sycl_buffer_type_context *)buft->context;
-    ggml_backend_sycl_context * sycl_ctx = (ggml_backend_sycl_context *)backend->context;
-    return buft_ctx->device == sycl_ctx->device;
+    // SYCL doesn't support registering host memory, left here for reference
+    // "ggml_backend_register_host_buffer"
+    // "ggml_backend_unregister_host_buffer"
+    return nullptr;
 }
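// Optional entry points are resolved by name through the registry; only the
// split-buffer-type factory is exported here. A sketch of the lookup (the
// function-pointer type mirrors ggml_backend_sycl_split_buffer_type's
// `const float * tensor_split` parameter; passing nullptr for an even split
// is an assumption of this example):
//
//     auto fn = (ggml_backend_buffer_type_t (*)(const float *))
//         ggml_backend_reg_get_proc_address(ggml_backend_sycl_reg(),
//                                           "ggml_backend_split_buffer_type");
//     if (fn != nullptr) {
//         ggml_backend_buffer_type_t split_buft = fn(nullptr);
//     }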
 
-static ggml_backend_i ggml_backend_sycl_interface = {
-    /* .get_name                = */ ggml_backend_sycl_name,
-    /* .free                    = */ ggml_backend_sycl_free,
-    /* .get_default_buffer_type = */ ggml_backend_sycl_get_default_buffer_type,
-    /* .set_tensor_async        = */ ggml_backend_sycl_set_tensor_async,
-    /* .get_tensor_async        = */ ggml_backend_sycl_get_tensor_async,
-    /* .cpy_tensor_async        = */ NULL, //ggml_backend_sycl_cpy_tensor_async, // TODO: update for the new interface
-    /* .synchronize             = */ ggml_backend_sycl_synchronize,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_sycl_graph_compute,
-    /* .supports_op             = */ ggml_backend_sycl_supports_op,
-    /* .supports_buft           = */ ggml_backend_sycl_supports_buft,
-    /* .offload_op              = */ ggml_backend_sycl_offload_op,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
+static const ggml_backend_reg_i ggml_backend_sycl_reg_interface = {
+    /* .get_name          = */ ggml_backend_sycl_reg_get_name,
+    /* .get_device_count  = */ ggml_backend_sycl_reg_get_device_count,
+    /* .get_device        = */ ggml_backend_sycl_reg_get_device,
+    /* .get_proc_address  = */ ggml_backend_sycl_reg_get_proc_address,
 };
 
-static ggml_guid_t ggml_backend_sycl_guid() {
-    static ggml_guid guid = { 0x58, 0x05, 0x13, 0x8f, 0xcd, 0x3a, 0x61, 0x9d, 0xe7, 0xcd, 0x98, 0xa9, 0x03, 0xfd, 0x7c, 0x53 };
-    return &guid;
+
+// backend registry
+
+ggml_backend_reg_t ggml_backend_sycl_reg() {
+    static ggml_backend_reg reg;
+    static bool initialized = false;
+
+    {
+        static std::mutex mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!initialized) {
+            ggml_backend_sycl_reg_context * ctx = new ggml_backend_sycl_reg_context;
+
+            for (int i = 0; i < ggml_sycl_info().device_count; i++) {
+                ggml_backend_sycl_device_context * dev_ctx = new ggml_backend_sycl_device_context;
+                dev_ctx->device = i;
+                dev_ctx->name = GGML_SYCL_NAME + std::to_string(i);
+
+                ggml_sycl_set_device(i);
+
+                dpct::device_info prop;
+                SYCL_CHECK(CHECK_TRY_ERROR(dpct::get_device_info(
+                    prop, dpct::dev_mgr::instance().get_device(i))));
+
+                dev_ctx->description = prop.get_name();
+
+                ggml_backend_dev_t dev = new ggml_backend_device {
+                    /* .interface = */ ggml_backend_sycl_device_interface,
+                    /* .reg       = */ &reg,
+                    /* .context   = */ dev_ctx
+                };
+                ctx->devices.push_back(dev);
+            }
+
+            reg = ggml_backend_reg {
+                /* .interface = */ ggml_backend_sycl_reg_interface,
+                /* .context   = */ ctx
+            };
+        }
+
+        initialized = true;
+    }
+
+    return &reg;
 }
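// ggml_backend_sycl_reg() above builds the registry and its device list
// exactly once behind a mutex and hands out a pointer to a function-local
// static that lives for the whole process. A self-contained sketch of the
// same pattern (demo_* names are illustrative; an int stands in for the
// per-device context):

#include <cstdio>
#include <mutex>
#include <vector>

struct demo_registry {
    std::vector<int> devices;
};

static demo_registry * demo_reg(int device_count) {
    static demo_registry reg;
    static bool initialized = false;
    static std::mutex mutex;

    std::lock_guard<std::mutex> lock(mutex);
    if (!initialized) {
        for (int i = 0; i < device_count; i++) {
            reg.devices.push_back(i); // stand-in for per-device context setup
        }
        initialized = true;
    }
    return &reg;
}

int main() {
    printf("devices: %zu\n", demo_reg(4)->devices.size());
    return 0;
}

// A function-local static of a type whose constructor does the enumeration
// would get the same thread-safety from C++11 magic statics; the explicit
// mutex + flag keeps the initialization point visible and debuggable.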
 
 ggml_backend_t ggml_backend_sycl_init(int device) {
@@ -5225,18 +5468,10 @@ ggml_backend_t ggml_backend_sycl_init(int device) {
     ggml_backend_t sycl_backend = new ggml_backend {
         /* .guid      = */ ggml_backend_sycl_guid(),
         /* .interface = */ ggml_backend_sycl_interface,
-        /* .device    = */ nullptr,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_sycl_reg(), device),
         /* .context   = */ ctx
     };
 
     return sycl_backend;
 }
 
-bool ggml_backend_is_sycl(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_sycl_guid());
-}
-
-int ggml_backend_sycl_get_device_count() {
-    GGML_SYCL_DEBUG("[SYCL] call ggml_backend_sycl_get_device_count\n");
-    return ggml_sycl_info().device_count;
-}
diff --git a/ggml/src/ggml-sycl/mmvq.cpp b/ggml/src/ggml-sycl/mmvq.cpp
index 1b96925e14e..7b10cf68814 100644
--- a/ggml/src/ggml-sycl/mmvq.cpp
+++ b/ggml/src/ggml-sycl/mmvq.cpp
@@ -1,6 +1,6 @@
 #include "mmvq.hpp"
 #include "vecdotq.hpp"
-
+#include <cassert>
 
 template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_sycl_t vec_dot_q_sycl>
 static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows,
@@ -13,7 +13,8 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 
 // partial sum for each thread
     float tmp = 0.0f;
@@ -37,7 +38,7 @@ static void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict_
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
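// The loop above is a butterfly (XOR) reduction across the sub-group: each
// round a lane adds the partial sum of its partner lane `i ^ mask`, so after
// log2(QK_WARP_SIZE) rounds every lane holds the row's full dot product. The
// new assert(blocks_per_warp > 0) guards the degenerate case where qi exceeds
// vdr * QK_WARP_SIZE. A CPU-only, self-contained sketch of the reduction:

#include <cstdio>

int main() {
    const int N = 8;                       // stand-in for QK_WARP_SIZE
    float lane[N] = {1, 2, 3, 4, 5, 6, 7, 8};

    for (int mask = N / 2; mask > 0; mask >>= 1) {
        float next[N];
        for (int i = 0; i < N; i++) {
            next[i] = lane[i] + lane[i ^ mask]; // exchange with XOR partner
        }
        for (int i = 0; i < N; i++) {
            lane[i] = next[i];
        }
    }

    // every lane now holds the total (36), so lane 0 can write the result
    printf("sum = %g\n", lane[0]);
    return 0;
}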
@@ -61,7 +62,8 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 
 // partial sum for each thread
     float tmp = 0.0f;
@@ -85,7 +87,7 @@ static void mul_mat_vec_q_iq2_xxs_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -109,8 +111,8 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 // partial sum for each thread
     float tmp = 0.0f;
 
@@ -133,7 +135,7 @@ static void mul_mat_vec_q_iq2_xs_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -157,8 +159,8 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 // partial sum for each thread
     float tmp = 0.0f;
 
@@ -181,7 +183,7 @@ static void mul_mat_vec_q_iq2_s_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -205,8 +207,8 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 // partial sum for each thread
     float tmp = 0.0f;
 
@@ -229,7 +231,7 @@ static void mul_mat_vec_q_iq3_xxs_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -253,8 +255,8 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 // partial sum for each thread
     float tmp = 0.0f;
 
@@ -277,7 +279,7 @@ static void mul_mat_vec_q_iq3_s_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -301,8 +303,8 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 // partial sum for each thread
     float tmp = 0.0f;
 
@@ -325,7 +327,7 @@ static void mul_mat_vec_q_iq1_s_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -349,8 +351,8 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 // partial sum for each thread
     float tmp = 0.0f;
 
@@ -373,7 +375,7 @@ static void mul_mat_vec_q_iq1_m_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -397,8 +399,8 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 // partial sum for each thread
     float tmp = 0.0f;
 
@@ -421,7 +423,7 @@ static void mul_mat_vec_q_iq4_nl_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -446,8 +448,8 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
     }
 
     const int blocks_per_row = ncols / qk;
-    const int blocks_per_warp = vdr * WARP_SIZE / qi;
-
+    const int blocks_per_warp = vdr * QK_WARP_SIZE / qi;
+    assert(blocks_per_warp>0);
 // partial sum for each thread
     float tmp = 0.0f;
 
@@ -470,7 +472,7 @@ static void mul_mat_vec_q_iq4_xs_q8_1(const void *__restrict__ vx,
 
     // sum up partial sums and write back result
 #pragma unroll
-    for (int mask = WARP_SIZE / 2; mask > 0; mask >>= 1) {
+    for (int mask = QK_WARP_SIZE / 2; mask > 0; mask >>= 1) {
         tmp +=
             dpct::permute_sub_group_by_xor(item_ct1.get_sub_group(), tmp, mask);
     }
@@ -487,7 +489,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK4_0 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -495,7 +497,7 @@ static void mul_mat_vec_q4_0_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -511,7 +513,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK4_1 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -519,7 +521,7 @@ static void mul_mat_vec_q4_1_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -535,7 +537,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK5_0 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -543,7 +545,7 @@ static void mul_mat_vec_q5_0_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -559,7 +561,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK5_1 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -567,7 +569,7 @@ static void mul_mat_vec_q5_1_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -583,7 +585,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK8_0 == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -591,7 +593,7 @@ static void mul_mat_vec_q8_0_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -607,7 +609,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -615,7 +617,7 @@ static void mul_mat_vec_q2_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -631,7 +633,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -639,7 +641,7 @@ static void mul_mat_vec_q3_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -655,7 +657,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -663,7 +665,7 @@ static void mul_mat_vec_q4_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -679,7 +681,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -687,7 +689,7 @@ static void mul_mat_vec_q5_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -703,7 +705,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -711,7 +713,7 @@ static void mul_mat_vec_q6_K_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
@@ -728,13 +730,13 @@ static void mul_mat_vec_iq2_xxs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq2_xxs_q8_1<QK_K, QI2_XXS, block_iq2_xxs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -749,7 +751,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -759,7 +761,7 @@ static void mul_mat_vec_iq2_xs_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq2_xs_q8_1<QK_K, QI2_XS, block_iq2_xs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -774,7 +776,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -784,7 +786,7 @@ static void mul_mat_vec_iq2_s_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq2_s_q8_1<QK_K, QI2_S, block_iq2_s, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -799,7 +801,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -809,7 +811,7 @@ static void mul_mat_vec_iq3_xxs_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq3_xxs_q8_1<QK_K, QI3_XXS, block_iq3_xxs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -824,7 +826,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -833,7 +835,7 @@ static void mul_mat_vec_iq3_s_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq3_s_q8_1<QK_K, QI3_S, block_iq3_s, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -848,7 +850,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
@@ -858,7 +860,7 @@ static void mul_mat_vec_iq1_s_q8_1_sycl(const void *vx, const void *vy,
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq1_s_q8_1<QK_K, QI1_S, block_iq1_s, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -873,13 +875,13 @@ static void mul_mat_vec_iq1_m_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq1_m_q8_1<QK_K, QI1_M, block_iq1_m, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -894,14 +896,14 @@ static void mul_mat_vec_iq4_nl_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK4_NL == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq4_nl_q8_1<QK4_NL, QI4_NL, block_iq4_nl, 2>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
@@ -916,14 +918,14 @@ static void mul_mat_vec_iq4_xs_q8_1_sycl(const void *vx, const void *vy,
     GGML_ASSERT(ncols % QK_K == 0);
     const int block_num_y = (nrows + GGML_SYCL_MMV_Y - 1) / GGML_SYCL_MMV_Y;
     const sycl::range<3> block_nums(1, 1, block_num_y);
-    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, WARP_SIZE);
+    const sycl::range<3> block_dims(1, GGML_SYCL_MMV_Y, QK_WARP_SIZE);
     {
 
         stream->submit([&](sycl::handler &cgh) {
             cgh.parallel_for(
                 sycl::nd_range<3>(block_nums * block_dims, block_dims),
                 [=](sycl::nd_item<3> item_ct1)
-                    [[intel::reqd_sub_group_size(WARP_SIZE)]] {
+                    [[intel::reqd_sub_group_size(QK_WARP_SIZE)]] {
                         mul_mat_vec_q_iq4_xs_q8_1<QK_K, QI4_XS, block_iq4_xs, 1>(
                             vx, vy, dst, ncols, nrows, item_ct1);
                     });
diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp
index 30bd376da61..e749bbe7047 100644
--- a/ggml/src/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan.cpp
@@ -1941,7 +1941,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         if (device->fp16) {
             device_extensions.push_back("VK_KHR_shader_float16_int8");
         }
-        device->name = device->properties.deviceName.data();
+        device->name = GGML_VK_NAME + std::to_string(idx);
 
         device_create_info = {
             vk::DeviceCreateFlags(),
@@ -1968,7 +1968,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->buffer_type = {
             /* .iface    = */ ggml_backend_vk_buffer_type_interface,
-            /* .device   = */ nullptr,
+            /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx),
             /* .context  = */ new ggml_backend_vk_buffer_type_context{ device->name, device },
         };
 
@@ -5287,9 +5287,9 @@ static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, gg
         return;
     }
 
-    ggml_type_traits_t tt = ggml_internal_get_type_traits(quant);
+    const auto * tt = ggml_get_type_traits(quant);
 
-    ggml_to_float_t dequant_fn = tt.to_float;
+    ggml_to_float_t dequant_fn = tt->to_float;
 
     dequant_fn(from, to, ne);
 }
@@ -6378,7 +6378,7 @@ ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() {
             /* .get_alloc_size   = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size,
             /* .is_host          = */ ggml_backend_cpu_buffer_type()->iface.is_host,
         },
-        /* .device   = */ nullptr,
+        /* .device   = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0),
         /* .context  = */ nullptr,
     };
 
@@ -6581,9 +6581,135 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     UNUSED(backend);
 }
 
-static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
-    // ggml_backend_vk_context * ctx = (ggml_backend_vk_context *) backend->context;
+// TODO: enable async and synchronize
+static ggml_backend_i ggml_backend_vk_interface = {
+    /* .get_name                = */ ggml_backend_vk_name,
+    /* .free                    = */ ggml_backend_vk_free,
+    /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
+    /* .set_tensor_async        = */ NULL,  // ggml_backend_vk_set_tensor_async,
+    /* .get_tensor_async        = */ NULL,  // ggml_backend_vk_get_tensor_async,
+    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
+    /* .synchronize             = */ NULL,  // ggml_backend_vk_synchronize,
+    /* .graph_plan_create       = */ NULL,
+    /* .graph_plan_free         = */ NULL,
+    /* .graph_plan_update       = */ NULL,
+    /* .graph_plan_compute      = */ NULL,
+    /* .graph_compute           = */ ggml_backend_vk_graph_compute,
+    /* .supports_op             = */ NULL,
+    /* .supports_buft           = */ NULL,
+    /* .offload_op              = */ NULL,
+    /* .event_record            = */ NULL,
+    /* .event_wait              = */ NULL,
+};
+
+static ggml_guid_t ggml_backend_vk_guid() {
+    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
+    return &guid;
+}
+
+ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
+    VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
+
+    ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
+    ggml_vk_init(ctx, dev_num);
+
+    ggml_backend_t vk_backend = new ggml_backend {
+        /* .guid      = */ ggml_backend_vk_guid(),
+        /* .interface = */ ggml_backend_vk_interface,
+        /* .device    = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num),
+        /* .context   = */ ctx,
+    };
+
+    return vk_backend;
+}
+
+bool ggml_backend_is_vk(ggml_backend_t backend) {
+    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
+}
+
+int ggml_backend_vk_get_device_count() {
+    return ggml_vk_get_device_count();
+}
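// With the registry in place, bring-up still goes through the public entry
// points defined above; ggml_backend_free() is the standard teardown call:
//
//     if (ggml_backend_vk_get_device_count() > 0) {
//         ggml_backend_t backend = ggml_backend_vk_init(0);
//         // ... allocate buffers, compute graphs ...
//         ggml_backend_free(backend);
//     }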
+
+void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
+    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+    int dev_idx = vk_instance.device_indices[device];
+    ggml_vk_get_device_description(dev_idx, description, description_size);
+}
+
+void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
+    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+
+    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
+
+    vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+
+    for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
+        if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
+            *total = heap.size;
+            *free = heap.size;
+            break;
+        }
+    }
+}
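// The helper above reports the first device-local heap and, with no budget
// query wired up here, returns the heap size as both total and free. A
// self-contained sketch of that selection logic over hand-mirrored structs
// (fake_heap stands in for vk::MemoryHeap):

#include <cstddef>
#include <cstdint>

struct fake_heap { uint64_t size; bool device_local; };

static void pick_device_local(const fake_heap * heaps, size_t n,
                              size_t * free, size_t * total) {
    for (size_t i = 0; i < n; i++) {
        if (heaps[i].device_local) {         // same test as the eDeviceLocal flag
            *total = (size_t) heaps[i].size;
            *free  = (size_t) heaps[i].size; // no budget info: report all free
            break;
        }
    }
}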
+
+//////////////////////////
+
+struct ggml_backend_vk_device_context {
+    int device;
+    std::string name;
+    std::string description;
+};
+
+static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ctx->name.c_str();
+}
+
+static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ctx->description.c_str();
+}
+
+static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context;
+    ggml_backend_vk_get_device_memory(ctx->device, free, total);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) {
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ggml_backend_vk_buffer_type(ctx->device);
+}
+
+static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) {
+    UNUSED(dev);
+    return ggml_backend_vk_host_buffer_type();
+}
 
+static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
+    UNUSED(dev);
+    return GGML_BACKEND_DEVICE_TYPE_GPU_FULL;
+}
+
+static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) {
+    props->name        = ggml_backend_vk_device_get_name(dev);
+    props->description = ggml_backend_vk_device_get_description(dev);
+    props->type        = ggml_backend_vk_device_get_type(dev);
+    ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
+    props->caps = {
+        /* async       */ false,
+        /* host_buffer */ true,
+        /* events      */ false,
+    };
+}
+
+static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) {
+    UNUSED(params);
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
+    return ggml_backend_vk_init(ctx->device);
+}
+
+static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
     switch (op->op) {
         case GGML_OP_UNARY:
             switch (ggml_get_unary_op(op)) {
@@ -6701,97 +6827,101 @@ static bool ggml_backend_vk_supports_op(ggml_backend_t backend, const ggml_tenso
             return false;
     }
 
-    UNUSED(backend);
-}
-
-static bool ggml_backend_vk_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
-    const int min_batch_size = 32;
-
-    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
-           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
-
-    UNUSED(backend);
+    UNUSED(dev);
 }
 
-static bool ggml_backend_vk_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
+static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) {
     if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) {
         return false;
     }
 
+    ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
     ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context;
-    ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
-
-    return buft_ctx->device == ctx->device;
-}
-
-// TODO: enable async and synchronize
-static ggml_backend_i ggml_backend_vk_interface = {
-    /* .get_name                = */ ggml_backend_vk_name,
-    /* .free                    = */ ggml_backend_vk_free,
-    /* .get_default_buffer_type = */ ggml_backend_vk_get_default_buffer_type,
-    /* .set_tensor_async        = */ NULL,  // ggml_backend_vk_set_tensor_async,
-    /* .get_tensor_async        = */ NULL,  // ggml_backend_vk_get_tensor_async,
-    /* .cpy_tensor_async        = */ NULL,  // ggml_backend_vk_cpy_tensor_async,
-    /* .synchronize             = */ NULL,  // ggml_backend_vk_synchronize,
-    /* .graph_plan_create       = */ NULL,
-    /* .graph_plan_free         = */ NULL,
-    /* .graph_plan_update       = */ NULL,
-    /* .graph_plan_compute      = */ NULL,
-    /* .graph_compute           = */ ggml_backend_vk_graph_compute,
-    /* .supports_op             = */ ggml_backend_vk_supports_op,
-    /* .supports_buft           = */ ggml_backend_vk_supports_buft,
-    /* .offload_op              = */ ggml_backend_vk_offload_op,
-    /* .event_record            = */ NULL,
-    /* .event_wait              = */ NULL,
-};
 
-static ggml_guid_t ggml_backend_vk_guid() {
-    static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b };
-    return &guid;
+    return buft_ctx->device->idx == ctx->device;
 }
 
-ggml_backend_t ggml_backend_vk_init(size_t dev_num) {
-    VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")");
+static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
+    const int min_batch_size = 32;
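+    // heuristic: only offload when the batch is large enough to amortize the host->device transfer
+    // (>= 32 rows for regular ops, or >= 32 experts for MUL_MAT_ID); GET_ROWS never pays off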
 
-    ggml_backend_vk_context * ctx = new ggml_backend_vk_context;
-    ggml_vk_init(ctx, dev_num);
+    return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+           (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
 
-    ggml_backend_t vk_backend = new ggml_backend {
-        /* .guid      = */ ggml_backend_vk_guid(),
-        /* .interface = */ ggml_backend_vk_interface,
-        /* .device    = */ nullptr,
-        /* .context   = */ ctx,
-    };
+    UNUSED(dev);
+}
+
+static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
+    /* .get_name             = */ ggml_backend_vk_device_get_name,
+    /* .get_description      = */ ggml_backend_vk_device_get_description,
+    /* .get_memory           = */ ggml_backend_vk_device_get_memory,
+    /* .get_type             = */ ggml_backend_vk_device_get_type,
+    /* .get_props            = */ ggml_backend_vk_device_get_props,
+    /* .init_backend         = */ ggml_backend_vk_device_init,
+    /* .get_buffer_type      = */ ggml_backend_vk_device_get_buffer_type,
+    /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
+    /* .buffer_from_host_ptr = */ NULL,
+    /* .supports_op          = */ ggml_backend_vk_device_supports_op,
+    /* .supports_buft        = */ ggml_backend_vk_device_supports_buft,
+    /* .offload_op           = */ ggml_backend_vk_device_offload_op,
+    /* .event_new            = */ NULL,
+    /* .event_free           = */ NULL,
+    /* .event_synchronize    = */ NULL,
+};
 
-    return vk_backend;
+static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) {
+    UNUSED(reg);
+    return GGML_VK_NAME;
 }
 
-bool ggml_backend_is_vk(ggml_backend_t backend) {
-    return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid());
+static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) {
+    UNUSED(reg);
+    return ggml_backend_vk_get_device_count();
 }
 
-int ggml_backend_vk_get_device_count() {
-    return ggml_vk_get_device_count();
-}
+static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) {
+    static std::vector<ggml_backend_dev_t> devices;
 
-void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) {
-    ggml_vk_get_device_description(device, description, description_size);
-}
+    static bool initialized = false;
 
-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) {
-    GGML_ASSERT(device < (int) vk_instance.device_indices.size());
+    {
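+        // build the device list lazily, exactly once; the static mutex makes concurrent registry queries safe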
+        static std::mutex mutex;
+        std::lock_guard<std::mutex> lock(mutex);
+        if (!initialized) {
+            for (size_t i = 0; i < ggml_backend_vk_get_device_count(); i++) {
+                ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context;
+                char desc[256];
+                ggml_backend_vk_get_device_description(i, desc, sizeof(desc));
+                ctx->device = i;
+                ctx->name = GGML_VK_NAME + std::to_string(i);
+                ctx->description = desc;
+                devices.push_back(new ggml_backend_device {
+                    /* .iface   = */ ggml_backend_vk_device_i,
+                    /* .reg     = */ reg,
+                    /* .context = */ ctx,
+                });
+            }
+            initialized = true;
+        }
+    }
 
-    vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]];
+    GGML_ASSERT(device < devices.size());
+    return devices[device];
+}
 
-    vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties();
+static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = {
+    /* .get_name         = */ ggml_backend_vk_reg_get_name,
+    /* .get_device_count = */ ggml_backend_vk_reg_get_device_count,
+    /* .get_device       = */ ggml_backend_vk_reg_get_device,
+    /* .get_proc_address = */ NULL,
+};
 
-    for (const vk::MemoryHeap& heap : memprops.memoryHeaps) {
-        if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) {
-            *total = heap.size;
-            *free = heap.size;
-            break;
-        }
-    }
+ggml_backend_reg_t ggml_backend_vk_reg() {
+    static ggml_backend_reg reg = {
+        /* .iface   = */ ggml_backend_vk_reg_i,
+        /* .context = */ nullptr,
+    };
+
+    return &reg;
 }
 
 // Extension availability
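
Editor's note: a minimal usage sketch for the registry path added above, assuming the matching declarations in ggml-backend.h and ggml-vulkan.h; the helper name example_init_first_vulkan_device is hypothetical, while the ggml_backend_* calls are the generic device API this patch targets. Not part of the patch.

#include "ggml-backend.h" // ggml_backend_reg_dev_count, ggml_backend_dev_init, ...
#include "ggml-vulkan.h"  // ggml_backend_vk_reg
#include <stdio.h>

static ggml_backend_t example_init_first_vulkan_device(void) {
    ggml_backend_reg_t reg = ggml_backend_vk_reg();

    const size_t n_dev = ggml_backend_reg_dev_count(reg);
    for (size_t i = 0; i < n_dev; i++) {
        ggml_backend_dev_t dev = ggml_backend_reg_dev_get(reg, i);

        struct ggml_backend_dev_props props;
        ggml_backend_dev_get_props(dev, &props); // filled by ggml_backend_vk_device_get_props()

        fprintf(stderr, "device %zu: %s (%s), free/total = %zu/%zu bytes\n",
                i, props.name, props.description, props.memory_free, props.memory_total);
    }

    // creates the backend for device 0 via ggml_backend_vk_device_init()
    return n_dev > 0 ? ggml_backend_dev_init(ggml_backend_reg_dev_get(reg, 0), NULL) : NULL;
}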
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index 264ffb5195e..a4359e7dd05 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -35,10 +35,6 @@
 #include <syscall.h>
 #endif
 
-#ifdef GGML_USE_METAL
-#include <unistd.h>
-#endif
-
 #if defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
 #undef GGML_USE_LLAMAFILE
 #endif
@@ -189,6 +185,8 @@ typedef pthread_t ggml_thread_t;
 #endif
 
 #if defined(__APPLE__)
+#include <unistd.h>
+#include <mach/mach.h>
 #include <TargetConditionals.h>
 #endif
 
@@ -327,8 +325,9 @@ struct ggml_logger_state {
 static struct ggml_logger_state g_logger_state = {ggml_log_callback_default, NULL};
 
 static void ggml_log_internal_v(enum ggml_log_level level, const char * format, va_list args) {
-    if (format == NULL)
+    if (format == NULL) {
         return;
+    }
     va_list args_copy;
     va_copy(args_copy, args);
     char buffer[128];
@@ -387,22 +386,40 @@ void ggml_log_callback_default(enum ggml_log_level level, const char * text, void * user_data) {
 //#define GGML_SOFT_MAX_ACCELERATE
 #endif
 
+
+void * ggml_aligned_malloc(size_t size) {
 #if defined(_MSC_VER) || defined(__MINGW32__)
-#define GGML_ALIGNED_MALLOC(size) _aligned_malloc(size, GGML_MEM_ALIGN)
-#define GGML_ALIGNED_FREE(ptr)    _aligned_free(ptr)
+    return _aligned_malloc(size, TENSOR_ALIGNMENT);
 #else
-inline static void * ggml_aligned_malloc(size_t size) {
     if (size == 0) {
         GGML_LOG_WARN("Behavior may be unexpected when allocating 0 bytes for ggml_aligned_malloc!\n");
         return NULL;
     }
     void * aligned_memory = NULL;
 #ifdef GGML_USE_CPU_HBM
-    int result = hbw_posix_memalign(&aligned_memory, 16, size);
+    int result = hbw_posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size);
+#elif TARGET_OS_OSX
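+    // vm_allocate returns zero-filled, page-aligned memory; the Mach status is mapped
+    // to an errno-style result so the shared error path below can report it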
+    kern_return_t alloc_status = vm_allocate((vm_map_t) mach_task_self(), (vm_address_t *) &aligned_memory, size, VM_FLAGS_ANYWHERE);
+    int result = EFAULT;
+    switch (alloc_status) {
+        case KERN_SUCCESS:
+            result = 0;
+            break;
+        case KERN_INVALID_ADDRESS:
+            result = EINVAL;
+            break;
+        case KERN_NO_SPACE:
+            result = ENOMEM;
+            break;
+        default:
+            result = EFAULT;
+            break;
+    }
 #elif GGML_USE_METAL
-    int result = posix_memalign(&aligned_memory, sysconf(_SC_PAGESIZE), size);
+    const long page_size = sysconf(_SC_PAGESIZE);
+    int result = posix_memalign(&aligned_memory, MAX(TENSOR_ALIGNMENT, page_size), size);
 #else
-    int result = posix_memalign(&aligned_memory, GGML_MEM_ALIGN, size);
+    int result = posix_memalign(&aligned_memory, TENSOR_ALIGNMENT, size);
 #endif
     if (result != 0) {
         // Handle allocation failure
@@ -420,14 +437,26 @@ inline static void * ggml_aligned_malloc(size_t size) {
         return NULL;
     }
     return aligned_memory;
+#endif
 }
-#define GGML_ALIGNED_MALLOC(size) ggml_aligned_malloc(size)
-#ifdef GGML_USE_CPU_HBM
-#define GGML_ALIGNED_FREE(ptr)    if(NULL != ptr) hbw_free(ptr)
+
+void ggml_aligned_free(void * ptr, size_t size) {
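+    // the size parameter is only needed on the macOS path, where vm_deallocate requires the region length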
+    GGML_UNUSED(size);
+#if defined(_MSC_VER) || defined(__MINGW32__)
+    _aligned_free(ptr);
+#elif GGML_USE_CPU_HBM
+    if (ptr != NULL) {
+        hbw_free(ptr);
+    }
+#elif TARGET_OS_OSX
+    if (ptr != NULL) {
+        vm_deallocate((vm_map_t)mach_task_self(), (vm_address_t)ptr, size);
+    }
 #else
-#define GGML_ALIGNED_FREE(ptr)    free(ptr)
-#endif
+    free(ptr);
 #endif
+}
+
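+// Usage sketch (editorial annotation, not part of the patch): the allocation
+// size must be passed back to ggml_aligned_free, since the macOS path frees
+// with vm_deallocate, which needs the region length:
+//
+//     void * buf = ggml_aligned_malloc(nbytes);
+//     if (buf != NULL) {
+//         // ... use buf ...
+//         ggml_aligned_free(buf, nbytes);
+//     }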
 
 inline static void * ggml_malloc(size_t size) {
     if (size == 0) {
@@ -730,7 +759,7 @@ static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
 static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
 static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t * restrict x, size_t bx, ggml_bf16_t * restrict y, size_t by, int nrc);
 
-static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
+static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = {
     [GGML_TYPE_I8] = {
         .type_name                = "i8",
         .blck_size                = 1,
@@ -1152,9 +1181,9 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
 };
 
 // For internal test use
-ggml_type_traits_t ggml_internal_get_type_traits(enum ggml_type type) {
+const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) {
     GGML_ASSERT(type < GGML_TYPE_COUNT);
-    return type_traits[type];
+    return &type_traits[type];
 }
 
 //
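
Editor's note: a short migration sketch for the renamed accessor, not part of the patch. Call sites move from the old by-value `ggml_internal_get_type_traits(type)` to the pointer-returning `ggml_get_type_traits(type)`; the helper name example_row_to_float is illustrative.

#include "ggml.h"
#include <stdint.h>

static void example_row_to_float(enum ggml_type type, const void * src, float * dst, int64_t n) {
    const struct ggml_type_traits * traits = ggml_get_type_traits(type);

    // to_float is NULL for types that have no float conversion
    if (traits->to_float != NULL) {
        traits->to_float(src, dst, n);
    }
}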
@@ -3435,7 +3464,7 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     size_t nbytes;
-    size_t blck_size = ggml_blck_size(tensor->type);
+    const size_t blck_size = ggml_blck_size(tensor->type);
     if (blck_size == 1) {
         nbytes = ggml_type_size(tensor->type);
         for (int i = 0; i < GGML_MAX_DIMS; ++i) {
@@ -3847,7 +3876,7 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
 
     *ctx = (struct ggml_context) {
         /*.mem_size           =*/ mem_size,
-        /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : GGML_ALIGNED_MALLOC(mem_size),
+        /*.mem_buffer         =*/ params.mem_buffer ? params.mem_buffer : ggml_aligned_malloc(mem_size),
         /*.mem_buffer_owned   =*/ params.mem_buffer ? false : true,
         /*.no_alloc           =*/ params.no_alloc,
         /*.no_alloc_save      =*/ params.no_alloc,
@@ -3885,7 +3914,7 @@ void ggml_free(struct ggml_context * ctx) {
     }
 
     if (ctx->mem_buffer_owned) {
-        GGML_ALIGNED_FREE(ctx->mem_buffer);
+        ggml_aligned_free(ctx->mem_buffer, ctx->mem_size);
     }
 
     GGML_FREE(ctx);
@@ -15662,6 +15691,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     ggml_vec_dot_t    const kq_vec_dot     = type_traits[k->type].vec_dot;
     ggml_to_float_t   const v_to_float     = type_traits[v->type].to_float;
 
+    GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
+    GGML_ASSERT(v_to_float   && "fattn: unsupported V-type");
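+    // a NULL handler means there is no conversion path for this K/V type;
+    // assert up front instead of calling a NULL pointer deep in the compute loop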
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -19575,9 +19607,10 @@ static void ggml_thread_cpumask_next(const bool * global_mask, bool * local_mask, bool strict, int32_t* iter) {
 void ggml_threadpool_free(struct ggml_threadpool* threadpool) {
     if (!threadpool) return;
 
+    const int n_threads = threadpool->n_threads_max;
+
 #ifndef GGML_USE_OPENMP
     struct ggml_compute_state* workers = threadpool->workers;
-    const int n_threads = threadpool->n_threads_max;
 
     ggml_mutex_lock(&threadpool->mutex);
 
@@ -19597,8 +19630,9 @@ void ggml_threadpool_free(struct ggml_threadpool* threadpool) {
     ggml_cond_destroy(&threadpool->cond);
 #endif // GGML_USE_OPENMP
 
-    GGML_ALIGNED_FREE(threadpool->workers);
-    GGML_ALIGNED_FREE(threadpool);
+    const size_t workers_size = sizeof(struct ggml_compute_state) * n_threads;
+    ggml_aligned_free(threadpool->workers, workers_size);
+    ggml_aligned_free(threadpool, sizeof(struct ggml_threadpool));
 }
 
 #ifndef GGML_USE_OPENMP
@@ -20030,7 +20064,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
                 struct ggml_cplan * cplan) {
 
     struct ggml_threadpool * threadpool =
-        GGML_ALIGNED_MALLOC(sizeof(struct ggml_threadpool));
+        ggml_aligned_malloc(sizeof(struct ggml_threadpool));
     {
         threadpool->cgraph           = cgraph;
         threadpool->cplan            = cplan;
@@ -20051,7 +20085,7 @@ static struct ggml_threadpool * ggml_threadpool_new_impl(
 
     // Allocate and init workers state
     const size_t workers_size = sizeof(struct ggml_compute_state) * tpp->n_threads;
-    struct ggml_compute_state * workers = GGML_ALIGNED_MALLOC(workers_size);
+    struct ggml_compute_state * workers = ggml_aligned_malloc(workers_size);
 
     memset(workers, 0, workers_size);
     for (int j = 0; j < tpp->n_threads; j++) {
@@ -23189,6 +23223,14 @@ int ggml_cpu_has_avx512_bf16(void) {
 #endif
 }
 
+int ggml_cpu_has_amx_int8(void) {
+#if defined(__AMX_INT8__)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
 int ggml_cpu_has_fma(void) {
 #if defined(__FMA__)
     return 1;
diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last
index b4fdf95db0f..da40927e196 100644
--- a/scripts/sync-ggml.last
+++ b/scripts/sync-ggml.last
@@ -1 +1 @@
-e7fd7deec20ef1ced3eebe38802f3c2126fddfa4
+162e232411ee98ceb0cccfa84886118d917d2123
diff --git a/src/whisper.cpp b/src/whisper.cpp
index 6e62d103b17..08f1ef01752 100644
--- a/src/whisper.cpp
+++ b/src/whisper.cpp
@@ -3667,6 +3667,9 @@ struct whisper_context * whisper_init_with_params_no_state(struct whisper_model_loader * loader, struct whisper_context_params params) {
     WHISPER_LOG_INFO("%s: gpu_device = %d\n", __func__, params.gpu_device);
     WHISPER_LOG_INFO("%s: dtw        = %d\n", __func__, params.dtw_token_timestamps);
 
+    // TODO: temporary call to force backend registry initialization
+    WHISPER_LOG_INFO("%s: backends   = %zu\n", __func__, ggml_backend_reg_count());
+
     whisper_context * ctx = new whisper_context;
     ctx->params = params;