@@ -148,6 +148,7 @@ struct vk_device_struct {
148148 vk::PhysicalDeviceProperties properties;
149149 std::string name;
150150 uint64_t max_memory_allocation_size;
151+ uint32_t force_heap_index;
151152 bool fp16;
152153 vk::Device device;
153154 uint32_t vendor_id;
@@ -1008,9 +1009,12 @@ static void ggml_vk_queue_cleanup(vk_device& device, vk_queue& q) {
10081009 q.cmd_buffer_idx = 0 ;
10091010}
10101011
1011- static uint32_t find_properties (const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) {
1012+ static uint32_t find_properties (const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags, uint32_t force_heap_index = UINT32_MAX ) {
10121013 for (uint32_t i = 0 ; i < mem_props->memoryTypeCount ; ++i) {
10131014 vk::MemoryType memory_type = mem_props->memoryTypes [i];
1015+ if (force_heap_index != UINT32_MAX && memory_type.heapIndex != force_heap_index) {
1016+ continue ;
1017+ }
10141018 if ((mem_req->memoryTypeBits & ((uint64_t )1 << i)) &&
10151019 (flags & memory_type.propertyFlags ) == flags &&
10161020 mem_props->memoryHeaps [memory_type.heapIndex ].size >= mem_req->size ) {
@@ -1053,11 +1057,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, vk::Memor
10531057
10541058 uint32_t memory_type_index = UINT32_MAX;
10551059
1056- memory_type_index = find_properties (&mem_props, &mem_req, req_flags);
1060+ memory_type_index = find_properties (&mem_props, &mem_req, req_flags, device-> force_heap_index );
10571061 buf->memory_property_flags = req_flags;
10581062
10591063 if (memory_type_index == UINT32_MAX && fallback_flags) {
1060- memory_type_index = find_properties (&mem_props, &mem_req, fallback_flags);
1064+ memory_type_index = find_properties (&mem_props, &mem_req, fallback_flags, device-> force_heap_index );
10611065 buf->memory_property_flags = fallback_flags;
10621066 }
10631067
@@ -1851,6 +1855,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
18511855 device->max_memory_allocation_size = props3.maxMemoryAllocationSize ;
18521856 }
18531857
1858+ const char * GGML_VK_FORCE_HEAP_INDEX = getenv (" GGML_VK_FORCE_HEAP_INDEX" );
1859+
1860+ if (GGML_VK_FORCE_HEAP_INDEX != nullptr ) {
1861+ device->force_heap_index = std::stoi (GGML_VK_FORCE_HEAP_INDEX);
1862+ } else {
1863+ device->force_heap_index = UINT32_MAX;
1864+ }
1865+
18541866 device->vendor_id = device->properties .vendorID ;
18551867 device->subgroup_size = subgroup_props.subgroupSize ;
18561868 device->uma = device->properties .deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
0 commit comments