Skip to content

Commit 01305c9

Browse files
authored
feat: reduce gpu sync size. (#246)
1 parent e7c86f3 commit 01305c9

File tree

2 files changed

+55
-23
lines changed

2 files changed

+55
-23
lines changed

src/nn/nn-vulkan.cpp

Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
#include "nn-vulkan.hpp"
22

3+
#define DEBUG_VULKAN_BUFFERS false
4+
#define DEBUG_VULKAN_TRACE false
5+
36
#if DEBUG_VULKAN_TRACE
47
#define VULKAN_TRACE(...) printf("VULKAN_TRACE: "); printf(__VA_ARGS__); printf("\n");
58
#else
69
#define VULKAN_TRACE(...)
710
#endif
811

9-
#define DEBUG_VULKAN_BUFFERS false
10-
1112
static bool hasPortabilityExtension() {
1213
#ifdef __APPLE__
1314
const std::vector<vk::ExtensionProperties> extensionProperties = vk::enumerateInstanceExtensionProperties();
@@ -122,7 +123,7 @@ void NnVulkanStagingCopy::executeCopyCommand() {
122123
context->device.freeCommandBuffers(context->commandPool, 1, &commandBuffer);
123124
}
124125

125-
NnVulkanBuffer::NnVulkanBuffer(NnVulkanContext *context, const char *name, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess) {
126+
NnVulkanBuffer::NnVulkanBuffer(NnVulkanContext *context, const char *name, const NnSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess) {
126127
this->context = context;
127128
this->name = name;
128129
this->bufferSize = bufferSize;
@@ -131,7 +132,7 @@ NnVulkanBuffer::NnVulkanBuffer(NnVulkanContext *context, const char *name, const
131132

132133
isHostVisible = false;
133134

134-
VULKAN_TRACE("Creating buffer of size %zu (fastAccess=%d)", (NnSize)bufferSize, fastAccess);
135+
VULKAN_TRACE("Creating buffer of size %zu (fastAccess=%d)", bufferSize, fastAccess);
135136

136137
uint32_t memoryTypeIndex = MEMORY_TYPE_INDEX_NOT_FOUND;
137138
if (fastAccess) {
@@ -164,43 +165,67 @@ NnVulkanBuffer::NnVulkanBuffer(NnVulkanContext *context, const char *name, const
164165
deviceMemory = b.second;
165166
if (isHostVisible)
166167
hostPointer = context->device.mapMemory(deviceMemory, 0, bufferSize);
167-
VULKAN_TRACE("Created buffer of size %zu (fastAccess=%d)", (NnSize)bufferSize, fastAccess);
168+
VULKAN_TRACE("Created buffer of size %zu (fastAccess=%d)", bufferSize, fastAccess);
168169
}
169170

170171
NnVulkanBuffer::~NnVulkanBuffer() {
171172
if (hostPointer != nullptr)
172173
context->device.unmapMemory(deviceMemory);
173174
context->device.freeMemory(deviceMemory);
174175
context->device.destroyBuffer(deviceBuffer);
175-
VULKAN_TRACE("Destroyed buffer of size %zu", (NnSize)bufferSize);
176+
VULKAN_TRACE("Destroyed buffer of size %zu", bufferSize);
176177
}
177178

178179
void NnVulkanBuffer::write(const NnByte *data) {
180+
write(data, bufferSize);
181+
}
182+
183+
void NnVulkanBuffer::write(const NnByte *data, const NnSize size) {
184+
assert(size <= bufferSize);
185+
179186
if (isHostVisible && hostPointer != nullptr) {
180-
std::memcpy(hostPointer, data, bufferSize);
181-
context->device.flushMappedMemoryRanges({ { deviceMemory, 0, bufferSize } });
182-
VULKAN_TRACE("Wrote %zu bytes to host visible buffer", (NnSize)bufferSize);
187+
std::memcpy(hostPointer, data, size);
188+
context->device.flushMappedMemoryRanges({ { deviceMemory, 0, (vk::DeviceSize)size } });
189+
VULKAN_TRACE("Wrote %zu bytes to host visible buffer", size);
183190
} else {
184-
NnVulkanStagingCopy copy(context, deviceBuffer, bufferSize, COPY_TO_DEVICE);
191+
NnVulkanStagingCopy copy(context, deviceBuffer, size, COPY_TO_DEVICE);
185192
copy.copy((NnByte *)data);
186193
copy.executeCopyCommand();
187-
VULKAN_TRACE("Wrote %zu bytes to buffer", (NnSize)bufferSize);
194+
VULKAN_TRACE("Wrote %zu bytes to buffer", size);
188195
}
189196
}
190197

191198
void NnVulkanBuffer::read(NnByte *data) {
199+
read(data, bufferSize);
200+
}
201+
202+
void NnVulkanBuffer::read(NnByte *data, const NnSize size) {
203+
assert(size <= bufferSize);
204+
192205
if (isHostVisible && hostPointer != nullptr) {
193-
context->device.invalidateMappedMemoryRanges({ {deviceMemory, 0, bufferSize} });
194-
std::memcpy(data, hostPointer, bufferSize);
206+
context->device.invalidateMappedMemoryRanges({ {deviceMemory, 0, (vk::DeviceSize)size} });
207+
std::memcpy(data, hostPointer, size);
195208

196-
VULKAN_TRACE("Read %zu bytes from host visible buffer", (NnSize)bufferSize);
209+
VULKAN_TRACE("Read %zu bytes from host visible buffer", size);
197210
} else {
198211
NnVulkanStagingCopy copy(context, deviceBuffer, bufferSize, COPY_FROM_DEVICE);
199212
copy.executeCopyCommand();
200213
copy.copy(data);
201214

202-
VULKAN_TRACE("Read %zu bytes from buffer", (NnSize)bufferSize);
215+
VULKAN_TRACE("Read %zu bytes from buffer", size);
216+
}
217+
}
218+
219+
NnSize NnVulkanBuffer::calcSliceSize(const NnSize nominator, const NnSize denominator) {
220+
assert(bufferSize % denominator == 0);
221+
222+
NnSize size = (bufferSize / denominator) * nominator;
223+
if (context->nonCoherentAtomSize != 0) {
224+
// TODO: this alignment is not needed for coherent memory
225+
size += context->nonCoherentAtomSize - (size % context->nonCoherentAtomSize);
226+
size = std::min(size, bufferSize);
203227
}
228+
return size;
204229
}
205230

206231
static NnByte *findFirstOpConfig(NnNodeConfig *nodeConfig, NnOpCode opCode) {
@@ -332,11 +357,12 @@ NnVulkanDevice::NnVulkanDevice(NnUint gpuIndex, NnNetConfig *netConfig, NnNodeCo
332357
printf("🌋 Device: %s\n", (char*)deviceProps.deviceName);
333358
printf("🌋 DeviceApiVersion: %d.%d.%d\n", VK_VERSION_MAJOR(deviceProps.apiVersion), VK_VERSION_MINOR(deviceProps.apiVersion), VK_VERSION_PATCH(deviceProps.apiVersion));
334359
printf("🌋 MaxComputeSharedMemory: %d kB\n", deviceProps.limits.maxComputeSharedMemorySize / 1024);
360+
printf("🌋 NonCoherentAtomSize: %lu bytes\n", (NnSize)deviceProps.limits.nonCoherentAtomSize);
335361

336362
vk::PhysicalDeviceMemoryProperties memoryProperties = context.physicalDevice.getMemoryProperties();
337363
for (unsigned int h = 0; h < memoryProperties.memoryHeapCount; h++) {
338364
if (memoryProperties.memoryHeaps[h].flags & vk::MemoryHeapFlagBits::eDeviceLocal)
339-
printf("🌋 Heap[%u]: %lu MB\n", h, ((unsigned long)memoryProperties.memoryHeaps[h].size) / (1024 * 1024));
365+
printf("🌋 Heap[%u]: %lu MB\n", h, ((NnSize)memoryProperties.memoryHeaps[h].size) / (1024 * 1024));
340366
}
341367

342368
vk::PhysicalDeviceFeatures deviceFeatures = context.physicalDevice.getFeatures();
@@ -375,6 +401,7 @@ NnVulkanDevice::NnVulkanDevice(NnUint gpuIndex, NnNetConfig *netConfig, NnNodeCo
375401
vk::CommandPoolCreateInfo commandPoolCreateInfo(vk::CommandPoolCreateFlags(vk::CommandPoolCreateFlagBits::eTransient | vk::CommandPoolCreateFlagBits::eResetCommandBuffer), context.queueFamilyIndex);
376402
context.commandPool = context.device.createCommandPool(commandPoolCreateInfo);
377403
context.queue = context.device.getQueue(context.queueFamilyIndex, 0);
404+
context.nonCoherentAtomSize = deviceProps.limits.nonCoherentAtomSize;
378405

379406
VULKAN_TRACE("Context created");
380407
data = new NnVulkanDeviceData(&context, netConfig, nodeConfig);
@@ -790,15 +817,17 @@ void NnVulkanDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint thre
790817
for (NnUint i = 0; i < netConfig->nPreSyncs; i++) {
791818
NnPreSyncConfig *preSyncConfig = &netConfig->preSyncs[i];
792819
NnByte *pipeData = netExecution->pipes[preSyncConfig->pipeIndex];
793-
data->pipes[preSyncConfig->pipeIndex]->write(pipeData);
820+
NnVulkanBuffer *buffer = data->pipes[preSyncConfig->pipeIndex].get();
821+
buffer->write(pipeData, buffer->calcSliceSize(batchSize, netConfig->nBatches));
794822
}
795823
}
796824

797825
for (NnUint opIndex = 0; opIndex < segmentConfig->nOps; opIndex++) {
798826
NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
799827
if (opConfig->input.source == SRC_PIPE) {
800828
NnByte *pipeData = netExecution->pipes[opConfig->input.pointerIndex];
801-
data->pipes[opConfig->input.pointerIndex]->write(pipeData);
829+
NnVulkanBuffer *buffer = data->pipes[opConfig->input.pointerIndex].get();
830+
buffer->write(pipeData, buffer->calcSliceSize(batchSize, netConfig->nBatches));
802831
}
803832
}
804833
}
@@ -856,7 +885,8 @@ void NnVulkanDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint thre
856885
NnOpConfig *opConfig = &segmentConfig->ops[opIndex];
857886
if (opConfig->output.source == SRC_PIPE) {
858887
NnByte *pipeData = netExecution->pipes[opConfig->output.pointerIndex];
859-
data->pipes[opConfig->output.pointerIndex]->read(pipeData);
888+
NnVulkanBuffer *buffer = data->pipes[opConfig->output.pointerIndex].get();
889+
buffer->read(pipeData, buffer->calcSliceSize(batchSize, netConfig->nBatches));
860890
}
861891
}
862892
}

src/nn/nn-vulkan.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,14 @@
66
#include "nn-executor.hpp"
77
#include "nn-cpu-ops.hpp"
88

9-
#define DEBUG_VULKAN_TRACE false
10-
119
typedef struct {
1210
vk::Instance instance;
1311
vk::PhysicalDevice physicalDevice;
1412
vk::Device device;
1513
uint32_t queueFamilyIndex;
1614
vk::CommandPool commandPool;
1715
vk::Queue queue;
16+
NnSize nonCoherentAtomSize;
1817
} NnVulkanContext;
1918

2019
enum NnStagingVulkanCopyDirection {
@@ -47,13 +46,16 @@ class NnVulkanBuffer {
4746
void *hostPointer;
4847
public:
4948
const char *name;
50-
vk::DeviceSize bufferSize;
49+
NnSize bufferSize;
5150
vk::Buffer deviceBuffer;
5251
vk::BufferUsageFlags usageFlags;
53-
NnVulkanBuffer(NnVulkanContext *context, const char *name, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess);
52+
NnVulkanBuffer(NnVulkanContext *context, const char *name, const NnSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess);
5453
~NnVulkanBuffer();
5554
void write(const NnByte *data);
55+
void write(const NnByte *data, const NnSize size);
5656
void read(NnByte *data);
57+
void read(NnByte *data, const NnSize size);
58+
NnSize calcSliceSize(const NnSize nominator, const NnSize denominator);
5759
};
5860

5961
typedef struct {

0 commit comments

Comments
 (0)