@@ -707,9 +707,21 @@ static void ggml_vk_queue_cleanup(ggml_backend_vk_context * ctx, vk_queue& q) {
707707 q.cmd_buffer_idx = 0 ;
708708}
709709
710- static vk_buffer ggml_vk_create_buffer (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
710+ static uint32_t find_properties (const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
711+ for (uint32_t i = 0 ; i < mem_props->memoryTypeCount ; ++i) {
712+ vk::MemoryType memory_type = mem_props->memoryTypes [i];
713+ if ((mem_req->memoryTypeBits & ((uint64_t )1 << i)) &&
714+ (flags & memory_type.propertyFlags ) == flags &&
715+ mem_props->memoryHeaps [memory_type.heapIndex ].size >= mem_req->size ) {
716+ return static_cast <int32_t >(i);
717+ }
718+ }
719+ return UINT32_MAX;
720+ }
721+
722+ static vk_buffer ggml_vk_create_buffer (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0 )) {
711723#ifdef GGML_VULKAN_DEBUG
712- std::cerr << " ggml_vk_create_buffer(" << size << " , " << to_string (req_flags) << " )" << std::endl;
724+ std::cerr << " ggml_vk_create_buffer(" << size << " , " << to_string (req_flags) << " , " << to_string (fallback_flags) << " )" << std::endl;
713725#endif
714726 vk_buffer buf = std::make_shared<vk_buffer_struct>();
715727
@@ -736,15 +748,15 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
736748
737749 uint32_t memory_type_index = UINT32_MAX;
738750
739- for ( uint32_t i = 0 ; i < mem_props. memoryTypeCount ; ++i) {
740- vk::MemoryType memory_type = mem_props. memoryTypes [i] ;
741- if ((mem_req. memoryTypeBits & (( uint64_t ) 1 << i)) && (req_flags & memory_type. propertyFlags ) == req_flags && mem_props. memoryHeaps [memory_type. heapIndex ]. size >= mem_req. size ) {
742- memory_type_index = i;
743- break ;
744- }
751+ memory_type_index = find_properties (& mem_props, &mem_req, req_flags);
752+ buf-> memory_property_flags = req_flags ;
753+
754+ if (memory_type_index == UINT32_MAX && fallback_flags) {
755+ memory_type_index = find_properties (&mem_props, &mem_req, fallback_flags) ;
756+ buf-> memory_property_flags = fallback_flags;
745757 }
746758
747- if (memory_type_index >= mem_props. memoryTypeCount ) {
759+ if (memory_type_index == UINT32_MAX ) {
748760 ctx->device .lock ()->device .destroyBuffer (buf->buffer );
749761 buf->size = 0 ;
750762 throw vk::OutOfDeviceMemoryError (" No suitable memory type found" );
@@ -758,10 +770,9 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
758770 buf->size = 0 ;
759771 throw e;
760772 }
761- buf->memory_property_flags = req_flags;
762773 buf->ptr = nullptr ;
763774
764- if (req_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
775+ if (buf-> memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
765776 buf->ptr = ctx->device .lock ()->device .mapMemory (buf->device_memory , 0 , VK_WHOLE_SIZE);
766777 }
767778
@@ -778,9 +789,9 @@ static vk_buffer ggml_vk_create_buffer(ggml_backend_vk_context * ctx, size_t siz
778789 return buf;
779790}
780791
781- static vk_buffer ggml_vk_create_buffer_check (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags) {
792+ static vk_buffer ggml_vk_create_buffer_check (ggml_backend_vk_context * ctx, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags( 0 ) ) {
782793 try {
783- return ggml_vk_create_buffer (ctx, size, req_flags);
794+ return ggml_vk_create_buffer (ctx, size, req_flags, fallback_flags );
784795 } catch (const vk::SystemError& e) {
785796 std::cerr << " ggml_vulkan: Memory allocation of size " << size << " failed." << std::endl;
786797 std::cerr << " ggml_vulkan: " << e.what () << std::endl;
@@ -791,16 +802,16 @@ static vk_buffer ggml_vk_create_buffer_check(ggml_backend_vk_context * ctx, size
791802static vk_buffer ggml_vk_create_buffer_device (ggml_backend_vk_context * ctx, size_t size) {
792803 vk_buffer buf;
793804 try {
794- buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
795- } catch (const vk::SystemError& e) {
796805 if (ctx->device .lock ()->uma ) {
797806 // Fall back to host memory type
798- buf = ggml_vk_create_buffer_check (ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
807+ buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
799808 } else {
800- std::cerr << " ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
801- std::cerr << " ggml_vulkan: " << e.what () << std::endl;
802- throw e;
809+ buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eDeviceLocal);
803810 }
811+ } catch (const vk::SystemError& e) {
812+ std::cerr << " ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl;
813+ std::cerr << " ggml_vulkan: " << e.what () << std::endl;
814+ throw e;
804815 }
805816
806817 return buf;
@@ -1422,7 +1433,9 @@ static void * ggml_vk_host_malloc(ggml_backend_vk_context * ctx, size_t size) {
14221433#ifdef GGML_VULKAN_DEBUG
14231434 std::cerr << " ggml_vk_host_malloc(" << size << " )" << std::endl;
14241435#endif
1425- vk_buffer buf = ggml_vk_create_buffer (ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1436+ vk_buffer buf = ggml_vk_create_buffer (ctx, size,
1437+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
1438+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
14261439
14271440 if (!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) {
14281441 fprintf (stderr, " WARNING: failed to allocate %.2f MB of pinned memory\n " ,
@@ -1568,7 +1581,9 @@ static void deferred_memcpy(void * dst, const void * src, size_t size, std::vect
15681581static void ggml_vk_ensure_sync_staging_buffer (ggml_backend_vk_context * ctx, size_t size) {
15691582 if (ctx->sync_staging == nullptr || ctx->sync_staging ->size < size) {
15701583 ggml_vk_destroy_buffer (ctx->sync_staging );
1571- ctx->sync_staging = ggml_vk_create_buffer_check (ctx, size, vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
1584+ ctx->sync_staging = ggml_vk_create_buffer_check (ctx, size,
1585+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
1586+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
15721587 }
15731588}
15741589
@@ -4082,7 +4097,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
40824097 std::cerr << " ggml_vk_preallocate_buffers(qx_size: " << ctx->prealloc_size_qx << " qy_size: " << ctx->prealloc_size_qy << " x_size: " << ctx->prealloc_size_x << " y_size: " << ctx->prealloc_size_y << " split_k_size: " << ctx->prealloc_size_split_k << " )" << std::endl;
40834098#endif
40844099#if defined(GGML_VULKAN_RUN_TESTS)
4085- ctx->staging = ggml_vk_create_buffer_check (ctx, 100ul * 1024ul * 1024ul , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4100+ ctx->staging = ggml_vk_create_buffer_check (ctx, 100ul * 1024ul * 1024ul ,
4101+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached
4102+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
40864103 ggml_vk_test_transfer (ctx, 8192 * 1000 , false );
40874104 ggml_vk_test_transfer (ctx, 8192 * 1000 , true );
40884105
@@ -4174,7 +4191,9 @@ static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) {
41744191 if (ctx->staging != nullptr ) {
41754192 ggml_vk_destroy_buffer (ctx->staging );
41764193 }
4177- ctx->staging = ggml_vk_create_buffer_check (ctx, ctx->staging_size , vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached);
4194+ ctx->staging = ggml_vk_create_buffer_check (ctx, ctx->staging_size ,
4195+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached,
4196+ vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent);
41784197 }
41794198}
41804199
0 commit comments