Skip to content

Commit 3716680

Browse files
authored
[ET-VK] Persistently map staging buffers
Differential Revision: D59706627 Pull Request resolved: #5021
1 parent f326ee1 commit 3716680

File tree

13 files changed

+76
-136
lines changed

13 files changed

+76
-136
lines changed

backends/vulkan/runtime/api/containers/StagingBuffer.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
#include <executorch/backends/vulkan/runtime/vk_api/memory/Buffer.h>
1616

17+
#include <cstring>
18+
1719
namespace vkcompute {
1820
namespace api {
1921

@@ -55,13 +57,41 @@ class StagingBuffer final {
5557
return vulkan_buffer_;
5658
}
5759

60+
inline void* data() {
61+
return vulkan_buffer_.allocation_info().pMappedData;
62+
}
63+
5864
inline size_t numel() {
5965
return numel_;
6066
}
6167

6268
inline size_t nbytes() {
6369
return nbytes_;
6470
}
71+
72+
inline void copy_from(const void* src, const size_t nbytes) {
73+
VK_CHECK_COND(nbytes <= nbytes_);
74+
memcpy(data(), src, nbytes);
75+
vmaFlushAllocation(
76+
vulkan_buffer_.vma_allocator(),
77+
vulkan_buffer_.allocation(),
78+
0u,
79+
VK_WHOLE_SIZE);
80+
}
81+
82+
inline void copy_to(void* dst, const size_t nbytes) {
83+
VK_CHECK_COND(nbytes <= nbytes_);
84+
vmaInvalidateAllocation(
85+
vulkan_buffer_.vma_allocator(),
86+
vulkan_buffer_.allocation(),
87+
0u,
88+
VK_WHOLE_SIZE);
89+
memcpy(dst, data(), nbytes);
90+
}
91+
92+
inline void set_staging_zeros() {
93+
memset(data(), 0, nbytes_);
94+
}
6595
};
6696

6797
} // namespace api

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,7 +401,7 @@ void ComputeGraph::copy_into_staging(
401401
const size_t numel) {
402402
StagingPtr staging = get_staging(idx);
403403
size_t nbytes = numel * vkapi::element_size(staging->dtype());
404-
copy_ptr_to_staging(data, *staging, nbytes);
404+
staging->copy_from(data, nbytes);
405405
}
406406

407407
void ComputeGraph::copy_from_staging(
@@ -410,7 +410,7 @@ void ComputeGraph::copy_from_staging(
410410
const size_t numel) {
411411
StagingPtr staging = get_staging(idx);
412412
size_t nbytes = numel * vkapi::element_size(staging->dtype());
413-
copy_staging_to_ptr(*staging, data, nbytes);
413+
staging->copy_to(data, nbytes);
414414
}
415415

416416
void ComputeGraph::prepare() {

backends/vulkan/runtime/graph/ops/PrepackNode.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,16 +53,15 @@ api::StagingBuffer PrepackNode::create_staging_buffer(ComputeGraph* graph) {
5353
if (graph->val_is_none(tref_)) {
5454
size_t numel = utils::multiply_integers(packed->sizes());
5555
api::StagingBuffer staging(graph->context(), packed->dtype(), numel);
56-
size_t nbytes = numel * vkapi::element_size(packed->dtype());
57-
set_staging_zeros(staging, nbytes);
56+
staging.set_staging_zeros();
5857
return staging;
5958
}
6059

6160
TensorRefPtr tref = graph->get_tref(tref_);
6261
size_t numel = utils::multiply_integers(tref->sizes);
6362
api::StagingBuffer staging(graph->context(), tref->dtype, numel);
6463
size_t nbytes = numel * vkapi::element_size(tref->dtype);
65-
copy_ptr_to_staging(tref->data, staging, nbytes);
64+
staging.copy_from(tref->data, nbytes);
6665
return staging;
6766
}
6867

backends/vulkan/runtime/graph/ops/utils/StagingUtils.cpp

Lines changed: 0 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -13,88 +13,8 @@
1313

1414
#include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/DimUtils.h>
1515

16-
#include <cstring>
17-
1816
namespace vkcompute {
1917

20-
template <typename T>
21-
void memcpy_to_mapping_impl(
22-
const void* src,
23-
vkapi::MemoryMap& dst_mapping,
24-
const size_t nbytes) {
25-
T* data_ptr = dst_mapping.template data<T>();
26-
memcpy(data_ptr, reinterpret_cast<const T*>(src), nbytes);
27-
}
28-
29-
template <typename T>
30-
void memcpy_from_mapping_impl(
31-
vkapi::MemoryMap& src_mapping,
32-
void* dst,
33-
const size_t nbytes) {
34-
T* data_ptr = src_mapping.template data<T>();
35-
memcpy(reinterpret_cast<T*>(dst), data_ptr, nbytes);
36-
}
37-
38-
void memcpy_to_mapping(
39-
const void* src,
40-
vkapi::MemoryMap& dst_mapping,
41-
const size_t nbytes,
42-
const vkapi::ScalarType dtype) {
43-
#define DTYPE_CASE(ctype, vkformat, name) \
44-
case vkapi::ScalarType::name: \
45-
memcpy_to_mapping_impl<ctype>(src, dst_mapping, nbytes); \
46-
break;
47-
48-
switch (dtype) {
49-
VK_FORALL_SCALAR_TYPES(DTYPE_CASE)
50-
default:
51-
VK_THROW("Unrecognized dtype!");
52-
}
53-
#undef DTYPE_CASE
54-
}
55-
56-
void memcpy_from_mapping(
57-
vkapi::MemoryMap& src_mapping,
58-
void* dst,
59-
const size_t nbytes,
60-
const vkapi::ScalarType dtype) {
61-
#define DTYPE_CASE(ctype, vkformat, name) \
62-
case vkapi::ScalarType::name: \
63-
memcpy_from_mapping_impl<ctype>(src_mapping, dst, nbytes); \
64-
break;
65-
66-
switch (dtype) {
67-
VK_FORALL_SCALAR_TYPES(DTYPE_CASE)
68-
default:
69-
VK_THROW("Unrecognized dtype!");
70-
}
71-
#undef DTYPE_CASE
72-
}
73-
74-
void copy_ptr_to_staging(
75-
const void* src,
76-
api::StagingBuffer& staging,
77-
const size_t nbytes) {
78-
vkapi::MemoryMap mapping(staging.buffer(), vkapi::MemoryAccessType::WRITE);
79-
mapping.invalidate();
80-
memcpy_to_mapping(src, mapping, nbytes, staging.dtype());
81-
}
82-
83-
void copy_staging_to_ptr(
84-
api::StagingBuffer& staging,
85-
void* dst,
86-
const size_t nbytes) {
87-
vkapi::MemoryMap mapping(staging.buffer(), vkapi::MemoryAccessType::READ);
88-
mapping.invalidate();
89-
memcpy_from_mapping(mapping, dst, nbytes, staging.dtype());
90-
}
91-
92-
void set_staging_zeros(api::StagingBuffer& staging, const size_t nbytes) {
93-
vkapi::MemoryMap mapping(staging.buffer(), vkapi::MemoryAccessType::WRITE);
94-
uint8_t* data_ptr = mapping.template data<uint8_t>();
95-
memset(data_ptr, 0, staging.nbytes());
96-
}
97-
9818
vkapi::ShaderInfo get_nchw_to_tensor_shader(
9919
const api::vTensor& v_dst,
10020
const bool int8_buffer_enabled) {

backends/vulkan/runtime/graph/ops/utils/StagingUtils.h

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,6 @@
1212

1313
namespace vkcompute {
1414

15-
//
16-
// Functions to copy data into and out of a staging buffer
17-
//
18-
19-
void copy_ptr_to_staging(
20-
const void* src,
21-
api::StagingBuffer& staging,
22-
const size_t nbytes);
23-
void copy_staging_to_ptr(
24-
api::StagingBuffer& staging,
25-
void* dst,
26-
const size_t nbytes);
27-
28-
void set_staging_zeros(api::StagingBuffer& staging, const size_t nbytes);
29-
30-
//
31-
// Functions to get shaders
32-
//
33-
3415
vkapi::ShaderInfo get_nchw_to_tensor_shader(
3516
const api::vTensor& v_dst,
3617
bool int8_buffer_enabled = true);

backends/vulkan/runtime/vk_api/memory/Allocation.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ Allocation::Allocation()
3030
create_info{},
3131
allocator(VK_NULL_HANDLE),
3232
allocation(VK_NULL_HANDLE),
33+
allocation_info({}),
3334
is_copy_(false) {}
3435

3536
Allocation::Allocation(
@@ -40,6 +41,7 @@ Allocation::Allocation(
4041
create_info(create_info),
4142
allocator(vma_allocator),
4243
allocation(VK_NULL_HANDLE),
44+
allocation_info({}),
4345
is_copy_(false) {
4446
VK_CHECK(vmaAllocateMemory(
4547
allocator, &memory_requirements, &create_info, &allocation, nullptr));
@@ -50,15 +52,18 @@ Allocation::Allocation(const Allocation& other) noexcept
5052
create_info(other.create_info),
5153
allocator(other.allocator),
5254
allocation(other.allocation),
55+
allocation_info(other.allocation_info),
5356
is_copy_(true) {}
5457

5558
Allocation::Allocation(Allocation&& other) noexcept
5659
: memory_requirements(other.memory_requirements),
5760
create_info(other.create_info),
5861
allocator(other.allocator),
5962
allocation(other.allocation),
63+
allocation_info(other.allocation_info),
6064
is_copy_(other.is_copy_) {
6165
other.allocation = VK_NULL_HANDLE;
66+
other.allocation_info = {};
6267
}
6368

6469
Allocation& Allocation::operator=(Allocation&& other) noexcept {
@@ -68,9 +73,11 @@ Allocation& Allocation::operator=(Allocation&& other) noexcept {
6873
create_info = other.create_info;
6974
allocator = other.allocator;
7075
allocation = other.allocation;
76+
allocation_info = other.allocation_info;
7177
is_copy_ = other.is_copy_;
7278

7379
other.allocation = tmp_allocation;
80+
other.allocation_info = {};
7481

7582
return *this;
7683
}

backends/vulkan/runtime/vk_api/memory/Allocation.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ struct Allocation final {
6262
VmaAllocator allocator;
6363
// Handles to the allocated memory
6464
VmaAllocation allocation;
65+
// Information about the allocated memory
66+
VmaAllocationInfo allocation_info;
6567

6668
private:
6769
// Indicates whether this class instance is a copy of another class instance,

backends/vulkan/runtime/vk_api/memory/Allocator.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,8 @@ VulkanBuffer Allocator::create_staging_buffer(const VkDeviceSize size) {
142142
// Staging buffers are accessed by both the CPU and GPU, so set the
143143
// appropriate flags to indicate that the host device will be accessing
144144
// the data from this buffer.
145-
alloc_create_info.flags |= VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT;
145+
alloc_create_info.flags |= VMA_ALLOCATION_CREATE_HOST_ACCESS_RANDOM_BIT |
146+
VMA_ALLOCATION_CREATE_MAPPED_BIT;
146147
alloc_create_info.usage = VMA_MEMORY_USAGE_AUTO_PREFER_HOST;
147148
alloc_create_info.requiredFlags = VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT;
148149
alloc_create_info.preferredFlags =

backends/vulkan/runtime/vk_api/memory/Buffer.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ VulkanBuffer::VulkanBuffer(
6767
&allocation_create_info,
6868
&handle_,
6969
&(memory_.allocation),
70-
nullptr));
70+
&(memory_.allocation_info)));
7171
} else {
7272
VmaAllocatorInfo allocator_info{};
7373
vmaGetAllocatorInfo(allocator_, &allocator_info);

backends/vulkan/runtime/vk_api/memory/Buffer.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,10 @@ class VulkanBuffer final {
114114
return memory_.allocation;
115115
}
116116

117+
inline VmaAllocationInfo allocation_info() const {
118+
return memory_.allocation_info;
119+
}
120+
117121
inline VmaAllocationCreateInfo allocation_create_info() const {
118122
return VmaAllocationCreateInfo(memory_.create_info);
119123
}

0 commit comments

Comments
 (0)