diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 3e3db999c92ed..859c1443f5f84 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -8943,6 +8943,13 @@ def set_vocab(self): class GptOssModel(TextModel): model_arch = gguf.MODEL_ARCH.GPT_OSS + # TODO: remove once MXFP4 is supported more generally + def dequant_model(self): + quant_config = self.hparams.get("quantization_config") + if quant_config is not None and quant_config.get("quant_method") == "mxfp4": + return + return super().dequant_model() + def transform_nibble_layout(self, tensor): assert tensor.dtype == torch.uint8 assert tensor.shape[-1] == 16 diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 21bd052255564..94d76c7ea8891 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -96,8 +96,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define GGML_VK_MAX_NODES 8192 -#define MAX_VK_BUFFERS 256 - #define VK_CHECK(err, msg) \ do { \ vk::Result err_ = (err); \ @@ -1311,7 +1309,6 @@ struct ggml_vk_garbage_collector { std::vector tl_semaphores; std::vector semaphores; std::vector events; - std::vector temp_buffers; std::vector contexts; }; @@ -1482,8 +1479,6 @@ struct ggml_backend_vk_context { // and set to true after the buffer contents are consumed. bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync; - vk_buffer buffer_pool[MAX_VK_BUFFERS]; - vk_context_ref compute_ctx; vk_context_ref transfer_ctx; @@ -3623,8 +3618,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1); - ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1); + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true); + } ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1); @@ -5144,71 +5144,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; } -static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) { - VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")"); - VK_LOG_MEMORY("ggml_vk_pool_malloc"); - - int best_i = -1; - size_t best_size = std::numeric_limits::max(); //smallest unused buffer that fits our needs - int worst_i = -1; - size_t worst_size = 0; //largest unused buffer seen so far - for (int i = 0; i < MAX_VK_BUFFERS; ++i) { - vk_buffer &b = ctx->buffer_pool[i]; - if (b != nullptr && b->size >= size && b->size < best_size) { - best_i = i; - best_size = b->size; - } - if (b != nullptr && b->size > worst_size) { - worst_i = i; - worst_size = b->size; - } - } - if(best_i != -1) { - //found the smallest buffer that fits our needs - vk_buffer b = ctx->buffer_pool[best_i]; - ctx->buffer_pool[best_i].reset(); - return b; - } - if(worst_i != -1) { - //no buffer that fits our needs, resize largest one to save memory - vk_buffer& b = ctx->buffer_pool[worst_i]; - ggml_vk_destroy_buffer(b); - } - - return ggml_vk_create_buffer_device(ctx->device, size); -} - -static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) { - VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")"); - for (int i = 0; i < MAX_VK_BUFFERS; ++i) { - vk_buffer& b = ctx->buffer_pool[i]; - if (b == nullptr) { - b = buffer; - return; - } - } - std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl; - ggml_vk_destroy_buffer(buffer); -} - -// Returns an available temporary buffer that may only be used temporarily, it will be reused -static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) { - // Try to find existing temp buffer with enough capacity - for (auto& buffer : ctx->gc.temp_buffers) { - if (buffer->size >= size) { - return buffer; - } - } - - VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")"); - - // Otherwise create new buffer - vk_buffer buf = ggml_vk_pool_malloc(ctx, size); - ctx->gc.temp_buffers.push_back(buf); - - return buf; -} - static void * ggml_vk_host_malloc(vk_device& device, size_t size) { VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")"); vk_buffer buf = ggml_vk_create_buffer(device, size, @@ -11789,10 +11724,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * // Clean up after graph processing is done static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); - for (auto& buffer : ctx->gc.temp_buffers) { - ggml_vk_pool_free(ctx, buffer); - } - ctx->gc.temp_buffers.clear(); ctx->prealloc_y_last_pipeline_used = {}; ctx->unsynced_nodes_written.clear(); @@ -11835,10 +11766,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { ggml_vk_destroy_buffer(ctx->prealloc_split_k); ctx->prealloc_y_last_pipeline_used = nullptr; - for (auto& buffer : ctx->buffer_pool) { - ggml_vk_destroy_buffer(buffer); - } - ctx->prealloc_size_x = 0; ctx->prealloc_size_y = 0; ctx->prealloc_size_split_k = 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp index 12bd174579052..8f67be9799518 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/ssm_scan.comp @@ -1,6 +1,9 @@ #version 450 #extension GL_EXT_control_flow_attributes : require +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif #include "types.glsl" @@ -84,35 +87,47 @@ void main() { } barrier(); - for (uint w = D_STATE; w > SUBGROUP_SIZE; w >>= 1) { - [[unroll]] for (uint j = 0; j < ((w >> 1) * SPLIT_H + D_STATE - 1) / D_STATE; j++) { - const uint k = (tid % (w >> 1)) + - (D_STATE * (tid / (w >> 1))) + - j * D_STATE * (D_STATE / (w >> 1)); - if (k < SPLIT_H * D_STATE && (k + (w >> 1)) < SPLIT_H * D_STATE) { - stateC[k] += stateC[k + (w >> 1)]; + [[unroll]] + for (uint w = D_STATE / 2; w >= SUBGROUP_SIZE; w >>= 1) { + [[unroll]] for (uint j = 0; j < (w * SPLIT_H + D_STATE - 1) / D_STATE; j++) { + const uint k = (tid % w) + (D_STATE * (tid / w)) + j * D_STATE * (D_STATE / w); + if (k < SPLIT_H * D_STATE && (k + w) < SPLIT_H * D_STATE) { + stateC[k] += stateC[k + w]; } } barrier(); } - [[unroll]] for (uint j = 0; j <= SPLIT_H / (D_STATE / SUBGROUP_SIZE); j++) { + [[unroll]] for (uint j = 0; j < max(1, SPLIT_H / (D_STATE / SUBGROUP_SIZE)); j++) { const uint idx = (tid % SUBGROUP_SIZE) + D_STATE * (tid / SUBGROUP_SIZE) + j * D_STATE * (D_STATE / SUBGROUP_SIZE); + const uint max_idx = SUBGROUP_SIZE - 1 + + D_STATE * ((D_STATE - 1) / SUBGROUP_SIZE) + + j * D_STATE * (D_STATE / SUBGROUP_SIZE); - uint lane = tid % SUBGROUP_SIZE; - - [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { - if (idx + offset < SPLIT_H * D_STATE) { - stateC[idx] += stateC[idx + offset]; + if (idx < SPLIT_H * D_STATE || + max_idx < SPLIT_H * D_STATE) { + float sc; +#if USE_SUBGROUP_ADD + sc = stateC[idx]; + sc = subgroupAdd(sc); +#else + [[unroll]] for (uint offset = SUBGROUP_SIZE / 2; offset > 0; offset >>= 1) { + if (idx + offset < SPLIT_H * D_STATE) { + stateC[idx] += stateC[idx + offset]; + } + barrier(); } - barrier(); - } + if (tid % SUBGROUP_SIZE == 0) { + sc = stateC[idx]; + } +#endif - if (idx < SPLIT_H * D_STATE && tid % SUBGROUP_SIZE == 0) { - const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); - d[y_base_idx + i * stride_y + k] = stateC[idx]; + if (tid % SUBGROUP_SIZE == 0) { + const uint k = tid / SUBGROUP_SIZE + j * (D_STATE / SUBGROUP_SIZE); + d[y_base_idx + i * stride_y + k] = sc; + } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 49bf6c764f726..0f25ba3453093 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -916,7 +916,8 @@ void process_shaders() { string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); - string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}}); + string_to_spv("ssm_scan_f32", "ssm_scan.comp", {{"A_TYPE", "float"}}); + string_to_spv("ssm_scan_subgroup_f32", "ssm_scan.comp", {{"A_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); string_to_spv("ssm_conv_f32", "ssm_conv.comp", {{"A_TYPE", "float"}}); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 7420a3176d930..2a83d66279b79 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -17965,6 +17965,8 @@ struct llm_build_plamo2 : public llm_graph_context_mamba { cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1); cb(cur, "result_norm", -1); + res->t_embd = cur; + // lm_head cur = build_lora_mm(model.output, cur); cb(cur, "result_output", -1);