@@ -1808,8 +1808,8 @@ static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_pr
18081808 return UINT32_MAX;
18091809}
18101810
1811- static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0) ) {
1812- VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags) << ", " << to_string(fallback_flags ) << ")");
1811+ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list< vk::MemoryPropertyFlags> & req_flags_list ) {
1812+ VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1] ) << ")");
18131813 if (size > device->max_memory_allocation_size) {
18141814 throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device memory allocation limit");
18151815 }
@@ -1836,42 +1836,27 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
18361836
18371837 vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
18381838
1839- uint32_t memory_type_index = UINT32_MAX;
1839+ for (auto &req_flags : req_flags_list) {
1840+ uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
18401841
1841- memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
1842- buf->memory_property_flags = req_flags;
1842+ if (memory_type_index == UINT32_MAX) {
1843+ continue;
1844+ }
1845+ buf->memory_property_flags = req_flags;
18431846
1844- if (memory_type_index == UINT32_MAX && fallback_flags) {
1845- memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
1846- buf->memory_property_flags = fallback_flags;
1847+ try {
1848+ buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
1849+ break;
1850+ } catch (const vk::SystemError& e) {
1851+ // loop and retry
1852+ }
18471853 }
18481854
1849- if (memory_type_index == UINT32_MAX ) {
1855+ if (buf->device_memory == VK_NULL_HANDLE ) {
18501856 device->device.destroyBuffer(buf->buffer);
18511857 throw vk::OutOfDeviceMemoryError("No suitable memory type found");
18521858 }
18531859
1854- try {
1855- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
1856- } catch (const vk::SystemError& e) {
1857- if (buf->memory_property_flags != fallback_flags) {
1858- // Try again with fallback flags
1859- memory_type_index = find_properties(&mem_props, &mem_req, fallback_flags);
1860- buf->memory_property_flags = fallback_flags;
1861-
1862- try {
1863- buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index });
1864- }
1865- catch (const vk::SystemError& e) {
1866- device->device.destroyBuffer(buf->buffer);
1867- throw e;
1868- }
1869- } else {
1870- // Out of Host/Device memory, clean up buffer
1871- device->device.destroyBuffer(buf->buffer);
1872- throw e;
1873- }
1874- }
18751860 buf->ptr = nullptr;
18761861
18771862 if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
@@ -1892,7 +1877,7 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
18921877
18931878static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) {
18941879 try {
1895- return ggml_vk_create_buffer(device, size, req_flags, fallback_flags);
1880+ return ggml_vk_create_buffer(device, size, { req_flags, fallback_flags} );
18961881 } catch (const vk::SystemError& e) {
18971882 std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
18981883 std::cerr << "ggml_vulkan: " << e.what() << std::endl;
@@ -1904,15 +1889,20 @@ static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) {
19041889 vk_buffer buf;
19051890 try {
19061891 if (device->prefer_host_memory) {
1907- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
1892+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
1893+ vk::MemoryPropertyFlagBits::eDeviceLocal});
19081894 } else if (device->uma) {
19091895 // Fall back to host memory type
1910- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
1896+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
1897+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
19111898 } else if (device->disable_host_visible_vidmem) {
1912- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal, vk::MemoryPropertyFlagBits::eDeviceLocal);
1899+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal,
1900+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
19131901 } else {
19141902 // use rebar if available, otherwise fallback to device only visible memory
1915- buf = ggml_vk_create_buffer(device, size, vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, vk::MemoryPropertyFlagBits::eDeviceLocal);
1903+ buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent,
1904+ vk::MemoryPropertyFlagBits::eDeviceLocal,
1905+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent});
19161906 }
19171907 } catch (const vk::SystemError& e) {
19181908 std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
@@ -4774,8 +4764,8 @@ static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_
47744764static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
47754765 VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")");
47764766 vk_buffer buf = ggml_vk_create_buffer(device, size,
4777- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
4778- vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
4767+ { vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
4768+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent} );
47794769
47804770 if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
47814771 fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n",
@@ -9187,7 +9177,7 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
91879177 if (ctx->prealloc_split_k != nullptr) {
91889178 ggml_vk_destroy_buffer(ctx->prealloc_split_k);
91899179 }
9190- ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
9180+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
91919181 }
91929182 }
91939183
@@ -9197,9 +9187,9 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t
91979187
91989188 ggml_pipeline_allocate_descriptor_sets(ctx);
91999189
9200- vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
9201- vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
9202- vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, vk::MemoryPropertyFlagBits::eDeviceLocal);
9190+ vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
9191+ vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
9192+ vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
92039193
92049194 X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne);
92059195 Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne);
@@ -9425,8 +9415,8 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
94259415 const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant);
94269416 float * x = (float *) malloc(x_sz);
94279417 void * qx = malloc(qx_sz);
9428- vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
9429- vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, vk::MemoryPropertyFlagBits::eDeviceLocal);
9418+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
9419+ vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
94309420 float * x_ref = (float *) malloc(x_sz);
94319421 ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16);
94329422
@@ -9531,8 +9521,8 @@ static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_
95319521// float * x = (float *) malloc(x_sz);
95329522// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz);
95339523// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz);
9534- // vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
9535- // vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
9524+ // vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
9525+ // vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
95369526//
95379527// for (size_t i = 0; i < ne; i++) {
95389528// x[i] = rand() / (float)RAND_MAX;
@@ -9679,10 +9669,10 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
96799669 float * x = (float *) malloc(x_sz);
96809670 float * y = (float *) malloc(y_sz);
96819671 void * qx = malloc(qx_sz);
9682- vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
9683- vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
9684- vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
9685- vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, vk::MemoryPropertyFlagBits::eDeviceLocal);
9672+ vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
9673+ vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
9674+ vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
9675+ vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
96869676 float * d = (float *) malloc(d_sz);
96879677 float * d_chk = (float *) malloc(d_sz);
96889678
@@ -9709,7 +9699,7 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m,
97099699 if (ctx->prealloc_split_k != nullptr) {
97109700 ggml_vk_destroy_buffer(ctx->prealloc_split_k);
97119701 }
9712- ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, vk::MemoryPropertyFlagBits::eDeviceLocal);
9702+ ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, { vk::MemoryPropertyFlagBits::eDeviceLocal} );
97139703 }
97149704 }
97159705 if (mmq) {
0 commit comments