11#include " nn-vulkan.hpp"
22
3+ #define DEBUG_VULKAN_BUFFERS false
4+ #define DEBUG_VULKAN_TRACE false
5+
// VULKAN_TRACE(fmt, ...): printf-style trace line prefixed with "VULKAN_TRACE: "
// and terminated with a newline. Expands to nothing unless DEBUG_VULKAN_TRACE
// is enabled.
//
// The enabled variant is wrapped in do { } while (0) so the macro always
// behaves as a single statement; without it only the first printf would be
// governed by an unbraced `if`, and an `if (...) VULKAN_TRACE(...); else ...`
// would not compile.
#if DEBUG_VULKAN_TRACE
    #define VULKAN_TRACE(...) do { printf("VULKAN_TRACE: "); printf(__VA_ARGS__); printf("\n"); } while (0)
#else
    #define VULKAN_TRACE(...)
#endif
811
9- #define DEBUG_VULKAN_BUFFERS false
10-
1112static bool hasPortabilityExtension () {
1213#ifdef __APPLE__
1314 const std::vector<vk::ExtensionProperties> extensionProperties = vk::enumerateInstanceExtensionProperties ();
@@ -122,7 +123,7 @@ void NnVulkanStagingCopy::executeCopyCommand() {
122123 context->device .freeCommandBuffers (context->commandPool , 1 , &commandBuffer);
123124}
124125
125- NnVulkanBuffer::NnVulkanBuffer (NnVulkanContext *context, const char *name, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess) {
126+ NnVulkanBuffer::NnVulkanBuffer (NnVulkanContext *context, const char *name, const NnSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess) {
126127 this ->context = context;
127128 this ->name = name;
128129 this ->bufferSize = bufferSize;
@@ -131,7 +132,7 @@ NnVulkanBuffer::NnVulkanBuffer(NnVulkanContext *context, const char *name, const
131132
132133 isHostVisible = false ;
133134
134- VULKAN_TRACE (" Creating buffer of size %zu (fastAccess=%d)" , (NnSize) bufferSize, fastAccess);
135+ VULKAN_TRACE (" Creating buffer of size %zu (fastAccess=%d)" , bufferSize, fastAccess);
135136
136137 uint32_t memoryTypeIndex = MEMORY_TYPE_INDEX_NOT_FOUND;
137138 if (fastAccess) {
@@ -164,43 +165,67 @@ NnVulkanBuffer::NnVulkanBuffer(NnVulkanContext *context, const char *name, const
164165 deviceMemory = b.second ;
165166 if (isHostVisible)
166167 hostPointer = context->device .mapMemory (deviceMemory, 0 , bufferSize);
167- VULKAN_TRACE (" Created buffer of size %zu (fastAccess=%d)" , (NnSize) bufferSize, fastAccess);
168+ VULKAN_TRACE (" Created buffer of size %zu (fastAccess=%d)" , bufferSize, fastAccess);
168169}
169170
170171NnVulkanBuffer::~NnVulkanBuffer () {
171172 if (hostPointer != nullptr )
172173 context->device .unmapMemory (deviceMemory);
173174 context->device .freeMemory (deviceMemory);
174175 context->device .destroyBuffer (deviceBuffer);
175- VULKAN_TRACE (" Destroyed buffer of size %zu" , (NnSize) bufferSize);
176+ VULKAN_TRACE (" Destroyed buffer of size %zu" , bufferSize);
176177}
177178
178179void NnVulkanBuffer::write (const NnByte *data) {
180+ write (data, bufferSize);
181+ }
182+
183+ void NnVulkanBuffer::write (const NnByte *data, const NnSize size) {
184+ assert (size <= bufferSize);
185+
179186 if (isHostVisible && hostPointer != nullptr ) {
180- std::memcpy (hostPointer, data, bufferSize );
181- context->device .flushMappedMemoryRanges ({ { deviceMemory, 0 , bufferSize } });
182- VULKAN_TRACE (" Wrote %zu bytes to host visible buffer" , (NnSize)bufferSize );
187+ std::memcpy (hostPointer, data, size );
188+ context->device .flushMappedMemoryRanges ({ { deviceMemory, 0 , (vk::DeviceSize)size } });
189+ VULKAN_TRACE (" Wrote %zu bytes to host visible buffer" , size );
183190 } else {
184- NnVulkanStagingCopy copy (context, deviceBuffer, bufferSize , COPY_TO_DEVICE);
191+ NnVulkanStagingCopy copy (context, deviceBuffer, size , COPY_TO_DEVICE);
185192 copy.copy ((NnByte *)data);
186193 copy.executeCopyCommand ();
187- VULKAN_TRACE (" Wrote %zu bytes to buffer" , (NnSize)bufferSize );
194+ VULKAN_TRACE (" Wrote %zu bytes to buffer" , size );
188195 }
189196}
190197
191198void NnVulkanBuffer::read (NnByte *data) {
199+ read (data, bufferSize);
200+ }
201+
202+ void NnVulkanBuffer::read (NnByte *data, const NnSize size) {
203+ assert (size <= bufferSize);
204+
192205 if (isHostVisible && hostPointer != nullptr ) {
193- context->device .invalidateMappedMemoryRanges ({ {deviceMemory, 0 , bufferSize } });
194- std::memcpy (data, hostPointer, bufferSize );
206+ context->device .invalidateMappedMemoryRanges ({ {deviceMemory, 0 , (vk::DeviceSize)size } });
207+ std::memcpy (data, hostPointer, size );
195208
196- VULKAN_TRACE (" Read %zu bytes from host visible buffer" , (NnSize)bufferSize );
209+ VULKAN_TRACE (" Read %zu bytes from host visible buffer" , size );
197210 } else {
198211 NnVulkanStagingCopy copy (context, deviceBuffer, bufferSize, COPY_FROM_DEVICE);
199212 copy.executeCopyCommand ();
200213 copy.copy (data);
201214
202- VULKAN_TRACE (" Read %zu bytes from buffer" , (NnSize)bufferSize);
215+ VULKAN_TRACE (" Read %zu bytes from buffer" , size);
216+ }
217+ }
218+
219+ NnSize NnVulkanBuffer::calcSliceSize (const NnSize nominator, const NnSize denominator) {
220+ assert (bufferSize % denominator == 0 );
221+
222+ NnSize size = (bufferSize / denominator) * nominator;
223+ if (context->nonCoherentAtomSize != 0 ) {
224+ // TODO: this alignment is not needed for coherent memory
225+ size += context->nonCoherentAtomSize - (size % context->nonCoherentAtomSize );
226+ size = std::min (size, bufferSize);
203227 }
228+ return size;
204229}
205230
206231static NnByte *findFirstOpConfig (NnNodeConfig *nodeConfig, NnOpCode opCode) {
@@ -332,11 +357,12 @@ NnVulkanDevice::NnVulkanDevice(NnUint gpuIndex, NnNetConfig *netConfig, NnNodeCo
332357 printf (" 🌋 Device: %s\n " , (char *)deviceProps.deviceName );
333358 printf (" 🌋 DeviceApiVersion: %d.%d.%d\n " , VK_VERSION_MAJOR (deviceProps.apiVersion ), VK_VERSION_MINOR (deviceProps.apiVersion ), VK_VERSION_PATCH (deviceProps.apiVersion ));
334359 printf (" 🌋 MaxComputeSharedMemory: %d kB\n " , deviceProps.limits .maxComputeSharedMemorySize / 1024 );
360+ printf (" 🌋 NonCoherentAtomSize: %lu bytes\n " , (NnSize)deviceProps.limits .nonCoherentAtomSize );
335361
336362 vk::PhysicalDeviceMemoryProperties memoryProperties = context.physicalDevice .getMemoryProperties ();
337363 for (unsigned int h = 0 ; h < memoryProperties.memoryHeapCount ; h++) {
338364 if (memoryProperties.memoryHeaps [h].flags & vk::MemoryHeapFlagBits::eDeviceLocal)
339- printf (" 🌋 Heap[%u]: %lu MB\n " , h, ((unsigned long )memoryProperties.memoryHeaps [h].size ) / (1024 * 1024 ));
365+ printf (" 🌋 Heap[%u]: %lu MB\n " , h, ((NnSize )memoryProperties.memoryHeaps [h].size ) / (1024 * 1024 ));
340366 }
341367
342368 vk::PhysicalDeviceFeatures deviceFeatures = context.physicalDevice .getFeatures ();
@@ -375,6 +401,7 @@ NnVulkanDevice::NnVulkanDevice(NnUint gpuIndex, NnNetConfig *netConfig, NnNodeCo
375401 vk::CommandPoolCreateInfo commandPoolCreateInfo (vk::CommandPoolCreateFlags (vk::CommandPoolCreateFlagBits::eTransient | vk::CommandPoolCreateFlagBits::eResetCommandBuffer), context.queueFamilyIndex );
376402 context.commandPool = context.device .createCommandPool (commandPoolCreateInfo);
377403 context.queue = context.device .getQueue (context.queueFamilyIndex , 0 );
404+ context.nonCoherentAtomSize = deviceProps.limits .nonCoherentAtomSize ;
378405
379406 VULKAN_TRACE (" Context created" );
380407 data = new NnVulkanDeviceData (&context, netConfig, nodeConfig);
@@ -790,15 +817,17 @@ void NnVulkanDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint thre
790817 for (NnUint i = 0 ; i < netConfig->nPreSyncs ; i++) {
791818 NnPreSyncConfig *preSyncConfig = &netConfig->preSyncs [i];
792819 NnByte *pipeData = netExecution->pipes [preSyncConfig->pipeIndex ];
793- data->pipes [preSyncConfig->pipeIndex ]->write (pipeData);
820+ NnVulkanBuffer *buffer = data->pipes [preSyncConfig->pipeIndex ].get ();
821+ buffer->write (pipeData, buffer->calcSliceSize (batchSize, netConfig->nBatches ));
794822 }
795823 }
796824
797825 for (NnUint opIndex = 0 ; opIndex < segmentConfig->nOps ; opIndex++) {
798826 NnOpConfig *opConfig = &segmentConfig->ops [opIndex];
799827 if (opConfig->input .source == SRC_PIPE) {
800828 NnByte *pipeData = netExecution->pipes [opConfig->input .pointerIndex ];
801- data->pipes [opConfig->input .pointerIndex ]->write (pipeData);
829+ NnVulkanBuffer *buffer = data->pipes [opConfig->input .pointerIndex ].get ();
830+ buffer->write (pipeData, buffer->calcSliceSize (batchSize, netConfig->nBatches ));
802831 }
803832 }
804833 }
@@ -856,7 +885,8 @@ void NnVulkanDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint thre
856885 NnOpConfig *opConfig = &segmentConfig->ops [opIndex];
857886 if (opConfig->output .source == SRC_PIPE) {
858887 NnByte *pipeData = netExecution->pipes [opConfig->output .pointerIndex ];
859- data->pipes [opConfig->output .pointerIndex ]->read (pipeData);
888+ NnVulkanBuffer *buffer = data->pipes [opConfig->output .pointerIndex ].get ();
889+ buffer->read (pipeData, buffer->calcSliceSize (batchSize, netConfig->nBatches ));
860890 }
861891 }
862892 }
0 commit comments