Skip to content

Commit 33189ef

Browse files
authored
fix: rope gpu config memory alignment. (#245)
1 parent 8fbe03f commit 33189ef

File tree

5 files changed

+10
-9
lines changed

5 files changed

+10
-9
lines changed

src/llm.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -313,15 +313,15 @@ LlmNet buildLlmNet(LlmHeader *h, NnUint nNodes, NnUint nBatches) {
313313
pointerBatchConfig(SRC_BUFFER, qBufferIndex),
314314
pointerBatchConfig(SRC_BUFFER, qBufferIndex),
315315
size0(),
316-
NnRopeOpConfig{n.header->ropeType, true, n.positionPipeIndex, ropeCacheBufferIndex,
316+
NnRopeOpConfig{n.header->ropeType, 1, n.positionPipeIndex, ropeCacheBufferIndex,
317317
h->ropeScalingFactor, h->ropeScalingLowFreqFactor, h->ropeScalingHighFreqFactory, h->ropeScalingOrigMaxSeqLen,
318318
ropeSlice});
319319
att.addOp(
320320
OP_ROPE, "block_rope_k", layerIndex,
321321
pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
322322
pointerBatchConfig(SRC_BUFFER, kTempBufferIndex),
323323
size0(),
324-
NnRopeOpConfig{n.header->ropeType, false, n.positionPipeIndex, ropeCacheBufferIndex,
324+
NnRopeOpConfig{n.header->ropeType, 0, n.positionPipeIndex, ropeCacheBufferIndex,
325325
h->ropeScalingFactor, h->ropeScalingLowFreqFactor, h->ropeScalingHighFreqFactory, h->ropeScalingOrigMaxSeqLen,
326326
ropeSlice});
327327
att.addOp(

src/nn/nn-core.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ typedef struct {
205205

206206
typedef struct {
207207
NnRopeType type;
208-
bool isQ;
208+
NnUint isQ; // Cannot use `bool` here due to GPU memory alignment
209209
NnUint positionPipeIndex;
210210
NnUint ropeCacheBufferIndex;
211211
float ropeScalingFactor;

src/nn/nn-cpu-ops.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,14 +1156,15 @@ static void ropeForward_F32_F32(NnUint nThreads, NnUint threadIndex, NnUint batc
11561156
const NnRopeSlice *slice = &config->slice;
11571157
const float *positions = (float *)context->pipes[config->positionPipeIndex];
11581158
const float *cache = (float *)context->buffers[config->ropeCacheBufferIndex];
1159+
const bool isQ = config->isQ == 1;
11591160

11601161
for (NnUint batchIndex = 0; batchIndex < batchSize; batchIndex++) {
11611162
float *x = (float *)context->input[batchIndex];
11621163
const NnUint pos = (NnUint)positions[batchIndex];
11631164
if (config->type == ROPE_LLAMA || config->type == ROPE_LLAMA3_1)
1164-
ropeLlama_F32(x, cache, config->isQ, pos, slice, nThreads, threadIndex);
1165+
ropeLlama_F32(x, cache, isQ, pos, slice, nThreads, threadIndex);
11651166
else if (config->type == ROPE_FALCON)
1166-
ropeFalcon_F32(x, cache, config->isQ, pos, slice, nThreads, threadIndex);
1167+
ropeFalcon_F32(x, cache, isQ, pos, slice, nThreads, threadIndex);
11671168
else
11681169
throw std::runtime_error("Unsupported rope type");
11691170
}

src/nn/nn-vulkan-test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,7 +440,7 @@ void testRope_F32_F32() {
440440
NnUint xPipeIndex = netBuilder->addPipe("X", size2D(F_32, N_BATCHES, ROPE_DIM));
441441
NnUint posPipeIndex = netBuilder->addPipe("POS", size2D(F_32, N_BATCHES, 1));
442442
NnUint ropeCacheBufferIndex = nodeBuilder->addBuffer("ropeCache", slice.cacheSize);
443-
bool isQ = true;
443+
NnUint isQ = 1;
444444

445445
segmentBuilder->addOp(
446446
OP_ROPE, "rope_llama", 0,

src/nn/vulkan/rope-forward-f32-f32.comp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ layout(binding = 1) writeonly buffer outputBuffer { float y[]; };
3232
layout(binding = 2) readonly buffer batchInfosBuffer { BatchInfo infos[]; };
3333
layout(binding = 3) readonly uniform configBuffer {
3434
uint ropeType;
35-
bool isQ;
35+
uint isQ;
3636
uint positionPipeIndex;
3737
uint ropeCacheBufferIndex;
3838
float ropeScalingFactor;
@@ -56,7 +56,7 @@ void main() {
5656

5757
if (ropeType == 0 || ropeType == 2 /* Llama */) {
5858
sharedOffset = position * slice.sliceDim;
59-
if (isQ) {
59+
if (isQ == 1) {
6060
sharedOffset += slice.qShift;
6161
}
6262
} else if (ropeType == 1 /* Falcon */) {
@@ -70,7 +70,7 @@ void main() {
7070

7171
const uint xOffset = sharedInfo.inputOffset;
7272
const uint yOffset = sharedInfo.outputOffset;
73-
const uint dim0 = isQ ? slice.qDim0 : slice.kvDim0;
73+
const uint dim0 = isQ == 1 ? slice.qDim0 : slice.kvDim0;
7474

7575
if (ropeType == 0 || ropeType == 2 /* Llama */) {
7676
const uint dim0Half = dim0 / 2;

0 commit comments

Comments (0)