Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/llm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,10 @@ LlmNet buildLlmNet(LlmHeader *h, NnUint nNodes, NnUint nBatches) {
const NnUint yBufferIndex = nodeBuilder.addBuffer("y", size2D(F_32, nBatches, h->dim));
const NnUint yqBufferIndex = h->syncType == F_32
? yBufferIndex
: nodeBuilder.addBuffer("yq", size2D(h->syncType, nBatches, h->dim));
: nodeBuilder.addBuffer("q_y", size2D(h->syncType, nBatches, h->dim));

const NnUint zBufferIndex = nodeBuilder.addBuffer("z", size2D(F_32, nBatches, h->qDim));
const NnUint zqSliceBufferIndex = nodeBuilder.addBuffer("zq_slice", size2D(h->syncType, nBatches, h->qDim / nNodes));
const NnUint zqSliceBufferIndex = nodeBuilder.addBuffer("q_z_slice", size2D(h->syncType, nBatches, h->qDim / nNodes));

const NnUint qBufferIndex = nodeBuilder.addBuffer("q", size2D(F_32, nBatches, n.qSlice.d0));
const NnUint kTempBufferIndex = nodeBuilder.addBuffer("k_temp", size2D(F_32, nBatches, n.kSlice.d0));
Expand All @@ -201,7 +201,7 @@ LlmNet buildLlmNet(LlmHeader *h, NnUint nNodes, NnUint nBatches) {
const NnUint dBufferIndex = nodeBuilder.addBuffer("d", size2D(F_32, nBatches, n.w1Slice.d0));
const NnUint dqBufferIndex = h->syncType == F_32
? dBufferIndex
: nodeBuilder.addBuffer("d", size2D(h->syncType, nBatches, n.w1Slice.d0));
: nodeBuilder.addBuffer("q_d", size2D(h->syncType, nBatches, n.w1Slice.d0));
const NnUint lBufferIndex = nodeBuilder.addBuffer("l", size2D(F_32, nBatches, n.w3Slice.d0));
const NnUint ropeCacheBufferIndex = nodeBuilder.addBuffer("rope_cache", ropeSlice.cacheSize);
const NnUint attBufferIndex = nodeBuilder.addBuffer("att", multiHeadAttSlice.attSize);
Expand Down
40 changes: 34 additions & 6 deletions src/nn/nn-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#define VULKAN_TRACE(...)
#endif

#define DEBUG_VULKAN_BUFFERS false

static bool hasPortabilityExtension() {
#ifdef __APPLE__
const std::vector<vk::ExtensionProperties> extensionProperties = vk::enumerateInstanceExtensionProperties();
Expand Down Expand Up @@ -120,8 +122,9 @@ void NnVulkanStagingCopy::executeCopyCommand() {
context->device.freeCommandBuffers(context->commandPool, 1, &commandBuffer);
}

NnVulkanBuffer::NnVulkanBuffer(NnVulkanContext *context, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess) {
NnVulkanBuffer::NnVulkanBuffer(NnVulkanContext *context, const char *name, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess) {
this->context = context;
this->name = name;
this->bufferSize = bufferSize;
this->usageFlags = usageFlags;
this->hostPointer = nullptr;
Expand Down Expand Up @@ -220,9 +223,9 @@ NnVulkanDeviceData::NnVulkanDeviceData(NnVulkanContext *context, NnNetConfig *ne
this->nodeConfig = nodeConfig;

for (NnUint i = 0; i < netConfig->nPipes; i++)
pipes[i].reset(new NnVulkanBuffer(context, netConfig->pipes[i].size.nBytes, vk::BufferUsageFlagBits::eStorageBuffer, true));
pipes[i].reset(new NnVulkanBuffer(context, netConfig->pipes[i].name, netConfig->pipes[i].size.nBytes, vk::BufferUsageFlagBits::eStorageBuffer, true));
for (NnUint i = 0; i < nodeConfig->nBuffers; i++)
buffers[i].reset(new NnVulkanBuffer(context, nodeConfig->buffers[i].size.nBytes, vk::BufferUsageFlagBits::eStorageBuffer, false));
buffers[i].reset(new NnVulkanBuffer(context, nodeConfig->buffers[i].name, nodeConfig->buffers[i].size.nBytes, vk::BufferUsageFlagBits::eStorageBuffer, false));

NnRopeOpConfig *ropeLlamaOpConfig = (NnRopeOpConfig *)findFirstOpConfig(nodeConfig, OP_ROPE);
if (ropeLlamaOpConfig != nullptr) {
Expand Down Expand Up @@ -561,18 +564,18 @@ NnVulkanDeviceSegmentData::NnVulkanDeviceSegmentData(NnVulkanContext *context, N

std::vector<NnVulkanBatchInfo> batchInfo = buildBatchInfo(opConfig, data, nBatches);
NnSize batchInfoSize = sizeof(NnVulkanBatchInfo) * batchInfo.size();
NnVulkanBuffer *batchInfoBuffer = new NnVulkanBuffer(context, batchInfoSize, vk::BufferUsageFlagBits::eStorageBuffer, false);
NnVulkanBuffer *batchInfoBuffer = new NnVulkanBuffer(context, "batchInfo", batchInfoSize, vk::BufferUsageFlagBits::eStorageBuffer, false);
data->internalBuffers.push_back(std::unique_ptr<NnVulkanBuffer>(batchInfoBuffer));
batchInfoBuffer->write((NnByte *)batchInfo.data());
batchInfoBufferIndex[opIndex] = data->internalBuffers.size() - 1;

if (opConfig->weightSize.nBytes > 0) {
NnVulkanBuffer *buffer = new NnVulkanBuffer(context, opConfig->weightSize.nBytes, vk::BufferUsageFlagBits::eStorageBuffer, false);
NnVulkanBuffer *buffer = new NnVulkanBuffer(context, "weights", opConfig->weightSize.nBytes, vk::BufferUsageFlagBits::eStorageBuffer, false);
data->internalBuffers.push_back(std::unique_ptr<NnVulkanBuffer>(buffer));
weightBufferIndex[opIndex] = data->internalBuffers.size() - 1;
}
if (opConfig->configSize > 0) {
NnVulkanBuffer *configBuffer = new NnVulkanBuffer(context, opConfig->configSize, vk::BufferUsageFlagBits::eUniformBuffer, false);
NnVulkanBuffer *configBuffer = new NnVulkanBuffer(context, "config", opConfig->configSize, vk::BufferUsageFlagBits::eUniformBuffer, false);
data->internalBuffers.push_back(std::unique_ptr<NnVulkanBuffer>(configBuffer));
configBuffer->write(opConfig->config);
configBufferIndex[opIndex] = data->internalBuffers.size() - 1;
Expand Down Expand Up @@ -854,4 +857,29 @@ void NnVulkanDeviceSegment::forward(NnUint opIndex, NnUint nThreads, NnUint thre
}
}
}

#if DEBUG_VULKAN_BUFFERS
    // Debug-only dump, compiled in when DEBUG_VULKAN_BUFFERS is true: prints one line
    // per entry of data->buffers, prefixed with the segment index, the buffer index
    // and the buffer name, followed by a short preview of the buffer contents.
    const NnUint nBuffers = data->buffers.size();
    for (NnUint i = 0; i < nBuffers; i++) {
        NnVulkanBuffer *buffer = data->buffers[i].get();
        printf("[%3d:%3d:%10s] ", segmentIndex, i, buffer->name);

        // Deliberately NOT named `data`: that would shadow the enclosing `data`
        // dereferenced just above and make this loop body fragile to reordering.
        std::vector<NnByte> contents(buffer->bufferSize);
        // NOTE(review): read() presumably copies the whole buffer into `contents` - confirm.
        buffer->read(contents.data());

        if (strncmp(buffer->name, "q_", 2) == 0) {
            // Names starting with "q_" are previewed as raw bytes in hex, at most the first 32.
            // NOTE(review): presumably these are the quantized buffers - confirm against the net builder.
            NnUint nBytes = 32;
            if (buffer->bufferSize < nBytes)
                nBytes = buffer->bufferSize;
            for (NnUint j = 0; j < nBytes; j++)
                printf(" %x", contents.data()[j]);
        } else {
            // Every other buffer is previewed as floats, at most the first 16 values.
            NnUint nNumbers = buffer->bufferSize / sizeof(float);
            if (nNumbers > 16)
                nNumbers = 16;
            float *nums = (float *)contents.data();
            for (NnUint j = 0; j < nNumbers; j++)
                printf(" %.4f", nums[j]);
        }
        printf("\n");
    }
#endif
}
3 changes: 2 additions & 1 deletion src/nn/nn-vulkan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@ class NnVulkanBuffer {
vk::DeviceMemory deviceMemory;
void *hostPointer;
public:
const char *name;
vk::DeviceSize bufferSize;
vk::Buffer deviceBuffer;
vk::BufferUsageFlags usageFlags;
NnVulkanBuffer(NnVulkanContext *context, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess);
NnVulkanBuffer(NnVulkanContext *context, const char *name, const vk::DeviceSize bufferSize, vk::BufferUsageFlags usageFlags, bool fastAccess);
~NnVulkanBuffer();
void write(const NnByte *data);
void read(NnByte *data);
Expand Down
Loading