diff --git a/backends/vulkan/runtime/api/Context.cpp b/backends/vulkan/runtime/api/Context.cpp
index df3d43ee901..6de700efce7 100644
--- a/backends/vulkan/runtime/api/Context.cpp
+++ b/backends/vulkan/runtime/api/Context.cpp
@@ -40,6 +40,10 @@ Context::Context(size_t adapter_i, const ContextConfig& config)
       cmd_mutex_{},
       cmd_(VK_NULL_HANDLE, 0u),
       submit_count_{0u},
+      // Custom memory pools: manager sized by the adapter's memory type count
+      custom_vma_pool_(
+          adapter_p_->vma().handle(),
+          adapter_p_->num_memory_types()),
       // Memory Management
       buffer_clearlist_mutex_{},
       buffers_to_clear_{},
@@ -60,6 +64,13 @@ Context::~Context() {
   }
 }
 
+vkapi::MemoryPoolManager* Context::get_custom_memory_pool_ptr() { // nullptr unless config_.use_custom_vma_pools is set
+  if (config_.use_custom_vma_pools) {
+    return &custom_vma_pool_;
+  }
+  return nullptr;
+}
+
 void Context::initialize_querypool() {
   querypool_.initialize(adapter_p_);
 }
diff --git a/backends/vulkan/runtime/api/Context.h b/backends/vulkan/runtime/api/Context.h
index 6cfbc64f141..0a684aa4c3a 100644
--- a/backends/vulkan/runtime/api/Context.h
+++ b/backends/vulkan/runtime/api/Context.h
@@ -13,6 +13,8 @@
 #include
 #include
 
+#include
+
 #include
 #include
 #include
@@ -29,6 +31,7 @@ struct ContextConfig final {
   vkapi::CommandPoolConfig cmd_pool_config;
   vkapi::DescriptorPoolConfig descriptor_pool_config;
   vkapi::QueryPoolConfig query_pool_config;
+  bool use_custom_vma_pools; // when false, Context::get_custom_memory_pool_ptr() returns nullptr
 };
 
 //
@@ -69,6 +72,9 @@ class Context final {
   std::mutex cmd_mutex_;
   vkapi::CommandBuffer cmd_;
   uint32_t submit_count_;
+  // Custom memory pool that can be used to allocate resources that are used in
+  // this command stream
+  vkapi::MemoryPoolManager custom_vma_pool_;
   // Memory Management
   std::mutex buffer_clearlist_mutex_;
   std::vector buffers_to_clear_;
@@ -78,6 +84,10 @@ class Context final {
   VkImageTiling preferred_image_tiling_;
 
  public:
+  inline const ContextConfig& config() const { // read-only access to config_
+    return config_;
+  }
+
   // Adapter access
 
   inline vkapi::Adapter* adapter_ptr() {
@@ -130,6 +140,8 @@ class Context final {
     return preferred_image_tiling_;
   }
 
+  vkapi::MemoryPoolManager* get_custom_memory_pool_ptr(); // may return nullptr; callers pass the result straight to the Allocator
+
   /*
    * By default, the querypool attached to a Context instance is uninitialized.
    * This function triggers the querypool to be created via vkCreateQueryPool.
diff --git a/backends/vulkan/runtime/api/api.h b/backends/vulkan/runtime/api/api.h
index b5d46b8bba4..5edf12b15e6 100644
--- a/backends/vulkan/runtime/api/api.h
+++ b/backends/vulkan/runtime/api/api.h
@@ -30,3 +30,4 @@
 #include
 #include
 #include
+#include
diff --git a/backends/vulkan/runtime/api/containers/ParamsBuffer.cpp b/backends/vulkan/runtime/api/containers/ParamsBuffer.cpp
index 482a5c50be6..be9bfa36d07 100644
--- a/backends/vulkan/runtime/api/containers/ParamsBuffer.cpp
+++ b/backends/vulkan/runtime/api/containers/ParamsBuffer.cpp
@@ -36,7 +36,8 @@ ParamsBuffer::ParamsBuffer(const ParamsBuffer& other)
     : context_p_(other.context_p_), vulkan_buffer_{} {
   if (other.vulkan_buffer_) {
     vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer(
-        other.vulkan_buffer_.mem_size());
+        other.vulkan_buffer_.mem_size(),
+        context_p_->get_custom_memory_pool_ptr()); // nullptr when custom pools are disabled
 
     memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_);
   }
@@ -55,7 +56,8 @@ ParamsBuffer& ParamsBuffer::operator=(const ParamsBuffer& other) {
 
   if (other.vulkan_buffer_) {
     vulkan_buffer_ = context_p_->adapter_ptr()->vma().create_uniform_buffer(
-        other.vulkan_buffer_.mem_size());
+        other.vulkan_buffer_.mem_size(),
+        context_p_->get_custom_memory_pool_ptr()); // nullptr when custom pools are disabled
 
     memcpy_to_buffer(other.vulkan_buffer_, vulkan_buffer_);
   }
diff --git a/backends/vulkan/runtime/api/containers/ParamsBuffer.h b/backends/vulkan/runtime/api/containers/ParamsBuffer.h
index ecc07892cf7..a5919732e74 100644
--- a/backends/vulkan/runtime/api/containers/ParamsBuffer.h
+++ b/backends/vulkan/runtime/api/containers/ParamsBuffer.h
@@ -35,8 +35,9 @@ class ParamsBuffer final {
   // constructor from the one above.
   ParamsBuffer(Context* context_p, const VkDeviceSize nbytes, const bool unused)
       : context_p_(context_p),
-        vulkan_buffer_(
-            context_p_->adapter_ptr()->vma().create_uniform_buffer(nbytes)) {}
+        vulkan_buffer_(context_p_->adapter_ptr()->vma().create_uniform_buffer(
+            nbytes,
+            context_p->get_custom_memory_pool_ptr())) {} // uses the constructor argument, not the member; both hold the same pointer here
 
   ParamsBuffer(const ParamsBuffer&);
   ParamsBuffer& operator=(const ParamsBuffer&);
diff --git a/backends/vulkan/runtime/api/containers/StagingBuffer.h b/backends/vulkan/runtime/api/containers/StagingBuffer.h
index 1e9f569fc4a..a9fe28ca575 100644
--- a/backends/vulkan/runtime/api/containers/StagingBuffer.h
+++ b/backends/vulkan/runtime/api/containers/StagingBuffer.h
@@ -35,7 +35,8 @@ class StagingBuffer final {
       : context_p_(context_p),
         dtype_(dtype),
         vulkan_buffer_(context_p_->adapter_ptr()->vma().create_staging_buffer(
-            element_size(dtype_) * numel)),
+            element_size(dtype_) * numel,
+            context_p_->get_custom_memory_pool_ptr())), // nullptr when custom pools are disabled
         mapped_data_(nullptr) {}
 
   StagingBuffer(const StagingBuffer&) = delete;
diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp
index 62b53f9a76c..61c4ecfe68d 100644
--- a/backends/vulkan/runtime/api/containers/Tensor.cpp
+++ b/backends/vulkan/runtime/api/containers/Tensor.cpp
@@ -292,7 +292,8 @@ vkapi::VulkanImage allocate_image(
       sampler_props,
       sampler,
       /*allow_transfer = */ true,
-      /*allocate_memory = */ allocate_memory);
+      /*allocate_memory = */ allocate_memory,
+      /*pool_manager = */ context_ptr->get_custom_memory_pool_ptr());
 }
 
 vkapi::VulkanBuffer allocate_buffer(
@@ -301,8 +302,6 @@ vkapi::VulkanBuffer allocate_buffer(
     const utils::StorageType storage_type,
     const vkapi::ScalarType dtype,
     const bool allocate_memory) {
-  vkapi::Adapter* adapter_ptr = context_ptr->adapter_ptr();
-
   switch (storage_type) {
     case utils::kBuffer:
       break;
@@ -313,8 +312,10 @@ vkapi::VulkanBuffer allocate_buffer(
 
   VK_CHECK_COND(numel <= context_ptr->adapter_ptr()->max_buffer_numel());
 
-  return adapter_ptr->vma().create_storage_buffer(
-      element_size(dtype) * numel, allocate_memory);
+  return context_ptr->adapter_ptr()->vma().create_storage_buffer(
+      element_size(dtype) * numel,
+      allocate_memory,
+      context_ptr->get_custom_memory_pool_ptr()); // nullptr when custom pools are disabled
 }
 
 vTensorStorage::vTensorStorage(
diff --git a/backends/vulkan/runtime/graph/GraphConfig.cpp b/backends/vulkan/runtime/graph/GraphConfig.cpp
index 887b46c002a..7bc87899782 100644
--- a/backends/vulkan/runtime/graph/GraphConfig.cpp
+++ b/backends/vulkan/runtime/graph/GraphConfig.cpp
@@ -40,6 +40,7 @@ GraphConfig::GraphConfig() {
       cmd_config,
       descriptor_pool_config,
       query_pool_config,
+      false, // use_custom_vma_pools
   };
 
   // Empirically selected safety factor. If descriptor pools start running out
diff --git a/backends/vulkan/runtime/graph/containers/SharedObject.cpp b/backends/vulkan/runtime/graph/containers/SharedObject.cpp
index 10ddd6f2ca3..65fd7e64d68 100644
--- a/backends/vulkan/runtime/graph/containers/SharedObject.cpp
+++ b/backends/vulkan/runtime/graph/containers/SharedObject.cpp
@@ -39,7 +39,9 @@ void SharedObject::allocate(ComputeGraph* const graph) {
       graph->context()->adapter_ptr()->vma().gpuonly_resource_create_info();
 
   allocation = graph->context()->adapter_ptr()->vma().create_allocation(
-      aggregate_memory_requirements, alloc_create_info);
+      aggregate_memory_requirements,
+      alloc_create_info,
+      graph->context()->get_custom_memory_pool_ptr()); // nullptr when custom pools are disabled
 }
 
 void SharedObject::bind_users(ComputeGraph* const graph) {
diff --git a/backends/vulkan/runtime/vk_api/Adapter.h b/backends/vulkan/runtime/vk_api/Adapter.h
index 8ae61095be8..e9b2c92134e 100644
--- a/backends/vulkan/runtime/vk_api/Adapter.h
+++ b/backends/vulkan/runtime/vk_api/Adapter.h
@@ -110,6 +110,10 @@ class Adapter final {
     return physical_device_.has_unified_memory;
   }
 
+  inline uint32_t num_memory_types() const { // memory type count recorded in physical_device_.memory_properties
+    return physical_device_.memory_properties.memoryTypeCount;
+  }
+
   inline uint32_t num_compute_queues() const {
     return physical_device_.num_compute_queues;
   }
diff --git a/backends/vulkan/runtime/vk_api/memory/Allocator.cpp b/backends/vulkan/runtime/vk_api/memory/Allocator.cpp
index 7976d0ddee5..b6c1c1fe21c 100644
--- a/backends/vulkan/runtime/vk_api/memory/Allocator.cpp
+++ b/backends/vulkan/runtime/vk_api/memory/Allocator.cpp
@@ -67,7 +67,8 @@ VmaAllocationCreateInfo Allocator::gpuonly_resource_create_info() {
 
 Allocation Allocator::create_allocation(
     const VkMemoryRequirements& memory_requirements,
-    const VmaAllocationCreateInfo& create_info) {
+    const VmaAllocationCreateInfo& create_info,
+    MemoryPoolManager* pool_manager) { // optional; when null no custom pool is assigned
   VmaAllocationCreateInfo alloc_create_info = create_info;
   // Protect against using VMA_MEMORY_USAGE_AUTO_* flags when allocating memory
   // directly, since those usage flags require that VkBufferCreateInfo and/or
@@ -90,6 +91,11 @@ Allocation Allocator::create_allocation(
     default:
       break;
   }
+  if (pool_manager) { // NOTE(review): the lookup below is not given memory_requirements.memoryTypeBits - confirm the selected pool's memory type is always permitted
+    uint32_t memory_type_idx =
+        pool_manager->get_memory_type_idx(alloc_create_info);
+    alloc_create_info.pool = pool_manager->get_memory_pool(memory_type_idx);
+  }
 
   return Allocation(allocator_, memory_requirements, alloc_create_info);
 }
@@ -104,7 +110,8 @@ VulkanImage Allocator::create_image(
     const VulkanImage::SamplerProperties& sampler_props,
     VkSampler sampler,
     const bool allow_transfer,
-    const bool allocate_memory) {
+    const bool allocate_memory,
+    MemoryPoolManager* pool_manager) { // optional; when null no custom pool is assigned
   VkImageUsageFlags usage =
       VK_IMAGE_USAGE_SAMPLED_BIT | VK_IMAGE_USAGE_STORAGE_BIT;
   if (allow_transfer) {
@@ -129,6 +136,15 @@ VulkanImage Allocator::create_image(
 
   const VkImageLayout initial_layout = VK_IMAGE_LAYOUT_UNDEFINED;
 
+  if (pool_manager) { // pick the pool from a create info built with the same image parameters passed to VulkanImage below
+    VkImageCreateInfo image_create_info = vkapi::generate_image_create_info(
+        image_type, image_format, extents, image_tiling, usage, initial_layout);
+
+    uint32_t memory_type_idx =
+        pool_manager->get_memory_type_idx(alloc_create_info, image_create_info);
+    alloc_create_info.pool = pool_manager->get_memory_pool(memory_type_idx);
+  }
+
   return VulkanImage(
       device,
       allocator_,
@@ -141,7 +157,26 @@ VulkanImage Allocator::create_image(
       allocate_memory);
 }
 
-VulkanBuffer Allocator::create_staging_buffer(const VkDeviceSize size) {
+VmaPool get_memory_pool_for_buffer( // resolves the VmaPool for a buffer of this size/usage, or VK_NULL_HANDLE without a manager
+    MemoryPoolManager* pool_manager,
+    const VkDeviceSize size,
+    const VkBufferUsageFlags buffer_usage,
+    const VmaAllocationCreateInfo& alloc_create_info) {
+  if (pool_manager) {
+    VkBufferCreateInfo buffer_create_info =
+        vkapi::generate_buffer_create_info(size, buffer_usage);
+
+    uint32_t memory_type_idx = pool_manager->get_memory_type_idx(
+        alloc_create_info, buffer_create_info);
+
+    return pool_manager->get_memory_pool(memory_type_idx);
+  }
+  return VK_NULL_HANDLE; // Return a default value if pool_manager is null
+}
+
+VulkanBuffer Allocator::create_staging_buffer(
+    const VkDeviceSize size,
+    MemoryPoolManager* pool_manager) { // optional; when null alloc_create_info.pool is set to VK_NULL_HANDLE
   const VkBufferUsageFlags buffer_usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
 
   VmaAllocationCreateInfo alloc_create_info = {};
@@ -159,26 +194,37 @@ VulkanBuffer Allocator::create_staging_buffer(const VkDeviceSize size) {
   alloc_create_info.preferredFlags =
       VK_MEMORY_PROPERTY_HOST_COHERENT_BIT | VK_MEMORY_PROPERTY_HOST_CACHED_BIT;
 
+  alloc_create_info.pool = get_memory_pool_for_buffer( // done last so the lookup sees the flags set above
+      pool_manager, size, buffer_usage, alloc_create_info);
+
   return VulkanBuffer(allocator_, size, alloc_create_info, buffer_usage);
 }
 
 VulkanBuffer Allocator::create_storage_buffer(
     const VkDeviceSize size,
-    const bool allocate_memory) {
+    const bool allocate_memory,
+    MemoryPoolManager* pool_manager) { // optional; when null alloc_create_info.pool is set to VK_NULL_HANDLE
   const VkBufferUsageFlags buffer_usage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
 
   VmaAllocationCreateInfo alloc_create_info = gpuonly_resource_create_info();
+  alloc_create_info.pool = get_memory_pool_for_buffer(
+      pool_manager, size, buffer_usage, alloc_create_info);
+
   return VulkanBuffer(
       allocator_, size, alloc_create_info, buffer_usage, allocate_memory);
 }
 
-VulkanBuffer Allocator::create_uniform_buffer(const VkDeviceSize size) {
+VulkanBuffer Allocator::create_uniform_buffer(
+    const VkDeviceSize size,
+    MemoryPoolManager* pool_manager) { // optional; when null alloc_create_info.pool is set to VK_NULL_HANDLE
   VmaAllocationCreateInfo alloc_create_info = {};
   alloc_create_info.flags = DEFAULT_ALLOCATION_STRATEGY |
       VMA_ALLOCATION_CREATE_HOST_ACCESS_SEQUENTIAL_WRITE_BIT;
   alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO;
 
   VkBufferUsageFlags buffer_usage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
+  alloc_create_info.pool = get_memory_pool_for_buffer(
+      pool_manager, size, buffer_usage, alloc_create_info);
   return VulkanBuffer(allocator_, size, alloc_create_info, buffer_usage);
 }
 
diff --git a/backends/vulkan/runtime/vk_api/memory/Allocator.h b/backends/vulkan/runtime/vk_api/memory/Allocator.h
index 8f76ca932b7..56f6ae90cbf 100644
--- a/backends/vulkan/runtime/vk_api/memory/Allocator.h
+++ b/backends/vulkan/runtime/vk_api/memory/Allocator.h
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include
 
 namespace vkcompute {
 namespace vkapi {
@@ -48,11 +49,16 @@ class Allocator final {
   VmaAllocator allocator_;
 
  public:
+  inline VmaAllocator handle() { // exposes the underlying VmaAllocator
+    return allocator_;
+  }
+
   VmaAllocationCreateInfo gpuonly_resource_create_info();
 
   Allocation create_allocation(
       const VkMemoryRequirements& memory_requirements,
-      const VmaAllocationCreateInfo& create_info);
+      const VmaAllocationCreateInfo& create_info,
+      MemoryPoolManager* pool_manager = nullptr); // defaulted so existing call sites keep compiling
 
   VulkanImage create_image(
       const VkDevice,
@@ -64,18 +70,24 @@ class Allocator final {
       const VulkanImage::SamplerProperties&,
       VkSampler,
       const bool allow_transfer = false,
-      const bool allocate_memory = true);
+      const bool allocate_memory = true,
+      MemoryPoolManager* pool_manager = nullptr);
 
-  VulkanBuffer create_staging_buffer(const VkDeviceSize);
+  VulkanBuffer create_staging_buffer(
+      const VkDeviceSize,
+      MemoryPoolManager* pool_manager = nullptr);
 
   VulkanBuffer create_storage_buffer(
       const VkDeviceSize,
-      const bool allocate_memory = true);
+      const bool allocate_memory = true,
+      MemoryPoolManager* pool_manager = nullptr);
 
   /*
    * Create a uniform buffer with a specified size
    */
-  VulkanBuffer create_uniform_buffer(const VkDeviceSize);
+  VulkanBuffer create_uniform_buffer(
+      const VkDeviceSize,
+      MemoryPoolManager* pool_manager = nullptr);
 
   /*
    * Create a uniform buffer containing the data in an arbitrary struct
diff --git a/backends/vulkan/runtime/vk_api/memory/Buffer.cpp b/backends/vulkan/runtime/vk_api/memory/Buffer.cpp
index 4f58e07b146..522ba573aaf 100644
--- a/backends/vulkan/runtime/vk_api/memory/Buffer.cpp
+++ b/backends/vulkan/runtime/vk_api/memory/Buffer.cpp
@@ -11,6 +11,21 @@
 namespace vkcompute {
 namespace vkapi {
 
+VkBufferCreateInfo generate_buffer_create_info( // exclusive-sharing create info for the given size and usage; shared with the pool lookup
+    VkDeviceSize size,
+    VkBufferUsageFlags usage) {
+  return VkBufferCreateInfo{
+      VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, // sType
+      nullptr, // pNext
+      0u, // flags
+      size, // size
+      usage, // usage
+      VK_SHARING_MODE_EXCLUSIVE, // sharingMode
+      0u, // queueFamilyIndexCount
+      nullptr, // pQueueFamilyIndices
+  };
+}
+
 //
 // VulkanBuffer
 //
@@ -42,16 +57,8 @@ VulkanBuffer::VulkanBuffer(
     buffer_properties_.mem_range = 1u;
   }
 
-  const VkBufferCreateInfo buffer_create_info{
-      VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO, // sType
-      nullptr, // pNext
-      0u, // flags
-      buffer_properties_.size, // size
-      usage, // usage
-      VK_SHARING_MODE_EXCLUSIVE, // sharingMode
-      0u, // queueFamilyIndexCount
-      nullptr, // pQueueFamilyIndices
-  };
+  const VkBufferCreateInfo buffer_create_info =
+      generate_buffer_create_info(buffer_properties_.size, usage);
 
   if (allocate_memory) {
     VK_CHECK(vmaCreateBuffer(
diff --git a/backends/vulkan/runtime/vk_api/memory/Buffer.h b/backends/vulkan/runtime/vk_api/memory/Buffer.h
index 0ef9f7e95e4..a4411e16166 100644
--- a/backends/vulkan/runtime/vk_api/memory/Buffer.h
+++ b/backends/vulkan/runtime/vk_api/memory/Buffer.h
@@ -27,6 +27,10 @@ class vTensorStorage;
 
 namespace vkapi {
 
+VkBufferCreateInfo generate_buffer_create_info(
+    VkDeviceSize size,
+    VkBufferUsageFlags usage);
+
 using MemoryAccessFlags = uint8_t;
 
 enum MemoryAccessType : MemoryAccessFlags {
diff --git a/backends/vulkan/runtime/vk_api/memory/Image.cpp b/backends/vulkan/runtime/vk_api/memory/Image.cpp
index da6ff76bccd..70a300b6638 100644
--- a/backends/vulkan/runtime/vk_api/memory/Image.cpp
+++ b/backends/vulkan/runtime/vk_api/memory/Image.cpp
@@ -11,6 +11,32 @@
 namespace vkcompute {
 namespace vkapi {
 
+VkImageCreateInfo generate_image_create_info( // single-mip, single-layer, 1-sample, exclusive-sharing create info; shared with the pool lookup
+    VkImageType image_type,
+    VkFormat image_format,
+    VkExtent3D image_extents,
+    VkImageTiling image_tiling,
+    VkImageUsageFlags image_usage,
+    VkImageLayout initial_layout) {
+  return VkImageCreateInfo{
+      VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType
+      nullptr, // pNext
+      0u, // flags
+      image_type, // imageType
+      image_format, // format
+      image_extents, // extents
+      1u, // mipLevels
+      1u, // arrayLayers
+      VK_SAMPLE_COUNT_1_BIT, // samples
+      image_tiling, // tiling
+      image_usage, // usage
+      VK_SHARING_MODE_EXCLUSIVE, // sharingMode
+      0u, // queueFamilyIndexCount
+      nullptr, // pQueueFamilyIndices
+      initial_layout, // initialLayout
+  };
+}
+
 //
 // ImageSampler
 //
@@ -146,23 +172,13 @@ VulkanImage::VulkanImage(
     image_properties_.image_extents.depth = 1u;
   }
 
-  const VkImageCreateInfo image_create_info{
-      VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO, // sType
-      nullptr, // pNext
-      0u, // flags
-      image_properties_.image_type, // imageType
-      image_properties_.image_format, // format
-      image_properties_.image_extents, // extents
-      1u, // mipLevels
-      1u, // arrayLayers
-      VK_SAMPLE_COUNT_1_BIT, // samples
-      image_properties_.image_tiling, // tiling
-      image_properties_.image_usage, // usage
-      VK_SHARING_MODE_EXCLUSIVE, // sharingMode
-      0u, // queueFamilyIndexCount
-      nullptr, // pQueueFamilyIndices
-      layout_, // initialLayout
-  };
+  const VkImageCreateInfo image_create_info = generate_image_create_info(
+      image_properties_.image_type,
+      image_properties_.image_format,
+      image_properties_.image_extents,
+      image_properties_.image_tiling,
+      image_properties_.image_usage,
+      layout_);
 
   if (allocate_memory) {
     VK_CHECK(vmaCreateImage(
diff --git a/backends/vulkan/runtime/vk_api/memory/Image.h b/backends/vulkan/runtime/vk_api/memory/Image.h
index 5bbdaf06b47..7760c021ffe 100644
--- a/backends/vulkan/runtime/vk_api/memory/Image.h
+++ b/backends/vulkan/runtime/vk_api/memory/Image.h
@@ -30,6 +30,14 @@ class vTensorStorage;
 
 namespace vkapi {
 
+VkImageCreateInfo generate_image_create_info(
+    VkImageType image_type,
+    VkFormat image_format,
+    VkExtent3D image_extents,
+    VkImageTiling image_tiling,
+    VkImageUsageFlags image_usage,
+    VkImageLayout initial_layout = VK_IMAGE_LAYOUT_UNDEFINED);
+
 class ImageSampler final {
  public:
   struct Properties final {
diff --git a/backends/vulkan/runtime/vk_api/memory/Pool.cpp b/backends/vulkan/runtime/vk_api/memory/Pool.cpp
new file mode 100644
index 00000000000..ebd9ab2db8c
--- /dev/null
+++ b/backends/vulkan/runtime/vk_api/memory/Pool.cpp
@@ -0,0 +1,129 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include
+
+#include
+
+namespace vkcompute {
+namespace vkapi {
+
+VmaPool create_memory_pool( // creates a VmaPool for one memory type; the vmaCreatePool result is passed through VK_CHECK
+    const VmaAllocator allocator,
+    const uint32_t mem_type_idx,
+    const size_t block_size = 0,
+    const size_t max_block_count = 0) {
+  VmaPoolCreateInfo create_info = {
+      mem_type_idx, // memoryTypeIndex
+      0u, // flags
+      block_size, // blockSize
+      0u, // minBlockCount
+      max_block_count, // maxBlockCount
+      0.0, // priority
+      0u, // minAllocationAlignment
+      nullptr};
+
+  VmaPool pool = VK_NULL_HANDLE;
+  VK_CHECK(vmaCreatePool(allocator, &create_info, &pool));
+  return pool;
+}
+
+MemoryPool::MemoryPool() // empty slot: no allocator, no pool handle
+    : allocator(VK_NULL_HANDLE),
+      memory_type_idx(0u),
+      block_size(0u),
+      max_block_count(0u),
+      handle(VK_NULL_HANDLE) {}
+
+MemoryPool::MemoryPool(MemoryPool&& other) noexcept // takes over other's handle and nulls it so it is destroyed only once
+    : allocator(other.allocator),
+      memory_type_idx(other.memory_type_idx),
+      block_size(other.block_size),
+      max_block_count(other.max_block_count),
+      handle(other.handle) {
+  other.handle = VK_NULL_HANDLE;
+}
+
+MemoryPool& MemoryPool::operator=(MemoryPool&& other) noexcept { // swaps allocator/handle so other's destructor releases the pool previously held here
+  VmaAllocator tmp_allocator = allocator;
+  VmaPool tmp_handle = handle;
+
+  allocator = other.allocator;
+  memory_type_idx = other.memory_type_idx;
+  block_size = other.block_size;
+  max_block_count = other.max_block_count;
+  handle = other.handle;
+
+  other.allocator = tmp_allocator;
+  other.handle = tmp_handle;
+
+  return *this;
+}
+
+void MemoryPool::initialize() { // creates the VmaPool; must not already hold one
+  VK_CHECK_COND(handle == VK_NULL_HANDLE);
+  handle = create_memory_pool(
+      allocator, memory_type_idx, block_size, max_block_count);
+}
+
+MemoryPool::~MemoryPool() { // destroys the pool only if one was created
+  if (handle != VK_NULL_HANDLE) {
+    vmaDestroyPool(allocator, handle);
+  }
+}
+
+MemoryPoolManager::MemoryPoolManager( // one MemoryPool slot per memory type; no VmaPool is created here
+    VmaAllocator vma_allocator,
+    const uint32_t num_memory_types)
+    : allocator{vma_allocator}, memory_pools(num_memory_types) {
+  for (uint32_t i = 0; i < num_memory_types; ++i) { // unsigned index: same type as num_memory_types and memory_type_idx
+    memory_pools.at(i).allocator = allocator;
+    memory_pools.at(i).memory_type_idx = i;
+  }
+}
+
+VmaPool MemoryPoolManager::get_memory_pool(const uint32_t memory_type_idx) {
+  VK_CHECK_COND(memory_type_idx < memory_pools.size());
+  MemoryPool& pool = memory_pools.at(memory_type_idx);
+  if (pool.handle == VK_NULL_HANDLE) { // first request for this memory type: create the pool now
+    pool.initialize();
+  }
+  return pool.handle;
+}
+
+uint32_t MemoryPoolManager::get_memory_type_idx(
+    const VmaAllocationCreateInfo& alloc_create_info) const {
+  uint32_t memory_type_idx = 0u;
+  VK_CHECK(vmaFindMemoryTypeIndex(
+      allocator,
+      UINT32_MAX, // memoryTypeBits - using all available memory types
+      &alloc_create_info,
+      &memory_type_idx));
+  return memory_type_idx;
+}
+
+uint32_t MemoryPoolManager::get_memory_type_idx(
+    const VmaAllocationCreateInfo& alloc_create_info,
+    const VkImageCreateInfo& image_create_info) const {
+  uint32_t memory_type_idx = 0u;
+  VK_CHECK(vmaFindMemoryTypeIndexForImageInfo(
+      allocator, &image_create_info, &alloc_create_info, &memory_type_idx));
+  return memory_type_idx;
+}
+
+uint32_t MemoryPoolManager::get_memory_type_idx(
+    const VmaAllocationCreateInfo& alloc_create_info,
+    const VkBufferCreateInfo& buffer_create_info) const {
+  uint32_t memory_type_idx = 0u;
+  VK_CHECK(vmaFindMemoryTypeIndexForBufferInfo(
+      allocator, &buffer_create_info, &alloc_create_info, &memory_type_idx));
+  return memory_type_idx;
+}
+
+} // namespace vkapi
+} // namespace vkcompute
diff --git a/backends/vulkan/runtime/vk_api/memory/Pool.h b/backends/vulkan/runtime/vk_api/memory/Pool.h
new file mode 100644
index 00000000000..2ac62d6d8dc
--- /dev/null
+++ b/backends/vulkan/runtime/vk_api/memory/Pool.h
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+// @lint-ignore-every CLANGTIDY facebook-hte-BadMemberName
+
+#include
+#include
+
+#include
+
+namespace vkcompute {
+namespace vkapi {
+
+/*
+ * Move-only owner of a single VmaPool. The pool is not created until
+ * initialize() is called; the destructor destroys it if it was created.
+ */
+struct MemoryPool final {
+  MemoryPool();
+
+  MemoryPool(const MemoryPool& other) = delete;
+  MemoryPool& operator=(const MemoryPool& other) = delete;
+
+  MemoryPool(MemoryPool&& other) noexcept;
+  MemoryPool& operator=(MemoryPool&& other) noexcept;
+
+  ~MemoryPool();
+
+  void initialize();
+
+  VmaAllocator allocator;
+  uint32_t memory_type_idx;
+  size_t block_size;
+  size_t max_block_count;
+  VmaPool handle;
+};
+
+/*
+ * Holds one MemoryPool slot per memory type index and creates each VmaPool
+ * the first time it is requested.
+ */
+class MemoryPoolManager final {
+ public:
+  explicit MemoryPoolManager(
+      VmaAllocator vma_allocator,
+      const uint32_t num_memory_types);
+
+  /*
+   * Returns the pool for memory_type_idx, creating it on first use.
+   */
+  VmaPool get_memory_pool(const uint32_t memory_type_idx);
+
+  /*
+   * Memory type lookups. The first overload searches every memory type; the
+   * other two resolve the type from the given image / buffer create info.
+   */
+  uint32_t get_memory_type_idx(
+      const VmaAllocationCreateInfo& alloc_create_info) const;
+
+  uint32_t get_memory_type_idx(
+      const VmaAllocationCreateInfo& alloc_create_info,
+      const VkImageCreateInfo& image_create_info) const;
+
+  uint32_t get_memory_type_idx(
+      const VmaAllocationCreateInfo& alloc_create_info,
+      const VkBufferCreateInfo& buffer_create_info) const;
+
+ private:
+  VmaAllocator allocator;
+  std::vector<MemoryPool> memory_pools;
+};
+
+} // namespace vkapi
+} // namespace vkcompute