@@ -96,8 +96,6 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; }
9696
9797#define GGML_VK_MAX_NODES 8192
9898
99- #define MAX_VK_BUFFERS 256
100-
10199#define VK_CHECK(err, msg) \
102100 do { \
103101 vk::Result err_ = (err); \
@@ -1311,7 +1309,6 @@ struct ggml_vk_garbage_collector {
13111309 std::vector<vk_semaphore> tl_semaphores;
13121310 std::vector<vk_semaphore> semaphores;
13131311 std::vector<vk::Event> events;
1314- std::vector<vk_buffer> temp_buffers;
13151312 std::vector<vk_context> contexts;
13161313};
13171314
@@ -1482,8 +1479,6 @@ struct ggml_backend_vk_context {
14821479 // and set to true after the buffer contents are consumed.
14831480 bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
14841481
1485- vk_buffer buffer_pool[MAX_VK_BUFFERS];
1486-
14871482 vk_context_ref compute_ctx;
14881483 vk_context_ref transfer_ctx;
14891484
@@ -3623,8 +3618,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
36233618
36243619 ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
36253620
3626- ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1);
3627- ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1);
3621+ if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
3622+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
3623+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
3624+ } else {
3625+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size, 16}, 1, true, true);
3626+ ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
3627+ }
36283628
36293629 ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
36303630
@@ -5144,71 +5144,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
51445144 return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
51455145}
51465146
5147- static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) {
5148- VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")");
5149- VK_LOG_MEMORY("ggml_vk_pool_malloc");
5150-
5151- int best_i = -1;
5152- size_t best_size = std::numeric_limits<size_t>::max(); //smallest unused buffer that fits our needs
5153- int worst_i = -1;
5154- size_t worst_size = 0; //largest unused buffer seen so far
5155- for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
5156- vk_buffer &b = ctx->buffer_pool[i];
5157- if (b != nullptr && b->size >= size && b->size < best_size) {
5158- best_i = i;
5159- best_size = b->size;
5160- }
5161- if (b != nullptr && b->size > worst_size) {
5162- worst_i = i;
5163- worst_size = b->size;
5164- }
5165- }
5166- if(best_i != -1) {
5167- //found the smallest buffer that fits our needs
5168- vk_buffer b = ctx->buffer_pool[best_i];
5169- ctx->buffer_pool[best_i].reset();
5170- return b;
5171- }
5172- if(worst_i != -1) {
5173- //no buffer that fits our needs, resize largest one to save memory
5174- vk_buffer& b = ctx->buffer_pool[worst_i];
5175- ggml_vk_destroy_buffer(b);
5176- }
5177-
5178- return ggml_vk_create_buffer_device(ctx->device, size);
5179- }
5180-
5181- static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) {
5182- VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")");
5183- for (int i = 0; i < MAX_VK_BUFFERS; ++i) {
5184- vk_buffer& b = ctx->buffer_pool[i];
5185- if (b == nullptr) {
5186- b = buffer;
5187- return;
5188- }
5189- }
5190- std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl;
5191- ggml_vk_destroy_buffer(buffer);
5192- }
5193-
5194- // Returns an available temporary buffer that may only be used temporarily, it will be reused
5195- static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) {
5196- // Try to find existing temp buffer with enough capacity
5197- for (auto& buffer : ctx->gc.temp_buffers) {
5198- if (buffer->size >= size) {
5199- return buffer;
5200- }
5201- }
5202-
5203- VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")");
5204-
5205- // Otherwise create new buffer
5206- vk_buffer buf = ggml_vk_pool_malloc(ctx, size);
5207- ctx->gc.temp_buffers.push_back(buf);
5208-
5209- return buf;
5210- }
5211-
52125147static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
52135148 VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
52145149 vk_buffer buf = ggml_vk_create_buffer(device, size,
@@ -11789,10 +11724,6 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
1178911724// Clean up after graph processing is done
1179011725static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) {
1179111726 VK_LOG_DEBUG("ggml_vk_graph_cleanup()");
11792- for (auto& buffer : ctx->gc.temp_buffers) {
11793- ggml_vk_pool_free(ctx, buffer);
11794- }
11795- ctx->gc.temp_buffers.clear();
1179611727 ctx->prealloc_y_last_pipeline_used = {};
1179711728
1179811729 ctx->unsynced_nodes_written.clear();
@@ -11835,10 +11766,6 @@ static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) {
1183511766 ggml_vk_destroy_buffer(ctx->prealloc_split_k);
1183611767 ctx->prealloc_y_last_pipeline_used = nullptr;
1183711768
11838- for (auto& buffer : ctx->buffer_pool) {
11839- ggml_vk_destroy_buffer(buffer);
11840- }
11841-
1184211769 ctx->prealloc_size_x = 0;
1184311770 ctx->prealloc_size_y = 0;
1184411771 ctx->prealloc_size_split_k = 0;
0 commit comments