
Commit 6224ae0

Cleanup Qwen3Kernels
1 parent e1a4632 commit 6224ae0

File tree

1 file changed: +40 -181 lines


src/main/java/com/example/tornadovm/Qwen3Kernels.java

Lines changed: 40 additions & 181 deletions
@@ -9,24 +9,32 @@
 
 public class Qwen3Kernels {
 
-    //public static void dbgCopy(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {
-    public static void dbgCopy(FloatArray srcBuffer, FloatArray dstBuffer, IntArray positioNlayer, int layer) {
-        //int position = positioNlayer.get(0);
-        //if (position == 1) {
+    /**
+     * Explicit copy-out kernel, useful for debugging.
+     * With this kernel we can store the values of an array in a temporary buffer at the point of interest.
+     * At the end of the task graph we copy out the temporary buffer to inspect the array values at that point.
+     * @param srcBuffer the array we want to inspect.
+     * @param dstBuffer the temporary buffer.
+     */
+    public static void dbgCopy(FloatArray srcBuffer, FloatArray dstBuffer) {
         for (@Parallel int i = 0; i < srcBuffer.getSize(); i++) {
             dstBuffer.set(i, srcBuffer.get(i));
         }
-        //}
     }
 
-    public static void rmsnormReductionWithOffset(
+    /**
+     * RmsNorm with parallel offset:
+     * the following three kernels implement rmsnorm over an offset range in parallel, for the qCur and kCur rmsnorm calculations.
+     *
+     * Step 1: reduction.
+     * This kernel computes the per-workgroup sum-of-squares partial reduction.
+     */
+    public static void rmsnormReductionWithParallelOffset(
             KernelContext context,
             FloatArray output,
             FloatArray x,
             int localMemSize) {
 
-        // global size: 0 - (config.numberOfHeads() * nEmbdHead)
-        // local size : 0 - nEmbdHead
         int gid = context.globalIdx;
         int lid = context.localIdx;
         int groupId = context.groupIdx;
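
The new Javadoc above describes the intended use of dbgCopy: register it as an extra task so a device buffer can be snapshotted mid-graph and copied back to the host at the end. A minimal sketch of that wiring, assuming the standard TornadoVM TaskGraph API; the graph name, task id, buffer sizes, and transfer modes below are illustrative and not taken from this commit:

    import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
    import uk.ac.manchester.tornado.api.TaskGraph;
    import uk.ac.manchester.tornado.api.TornadoExecutionPlan;
    import uk.ac.manchester.tornado.api.enums.DataTransferMode;
    import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

    public class DbgCopyExample {
        public static void main(String[] args) {
            FloatArray src = new FloatArray(1024); // buffer we want to inspect mid-graph
            FloatArray dbg = new FloatArray(1024); // temporary buffer that gets copied out

            TaskGraph graph = new TaskGraph("s0")
                    .transferToDevice(DataTransferMode.FIRST_EXECUTION, src)
                    // ... regular tasks that produce or modify src would be registered here ...
                    .task("dbgCopy", Qwen3Kernels::dbgCopy, src, dbg) // snapshot src at this point
                    .transferToHost(DataTransferMode.EVERY_EXECUTION, dbg);

            ImmutableTaskGraph itg = graph.snapshot();
            new TornadoExecutionPlan(itg).execute();

            System.out.println("dbg[0] = " + dbg.get(0)); // inspect the snapshot on the host
        }
    }
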
@@ -36,13 +44,8 @@ public static void rmsnormReductionWithOffset(
         float[] localX = context.allocateFloatLocalArray(localMemSize);
 
         // Load input value and compute square
-        //int globalReadIndex = gid + offset;
-        //if (gid < size && globalReadIndex < x.getSize()) {
-        localX[lid] = x.get(gid);
-        localX[lid] = localX[lid] * localX[lid];
-        //} else {
-        //    localX[lid] = 0.0f;
-        //}
+        localX[lid] = x.get(gid);
+        localX[lid] = localX[lid] * localX[lid];
 
         // Perform parallel reduction within the work group
         for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
@@ -59,7 +62,11 @@ public static void rmsnormReductionWithOffset(
         }
     }
 
-    // Second kernel - Combines partial sums and computes final normalization
+    /**
+     * RmsNorm with parallel offset:
+     *
+     * Step 2: combines the partial reduction outputs and computes the final normalization factor.
+     */
     public static void rmsnormFinalNormalizationWithParallelOffset(
             KernelContext context,
             FloatArray output, // size should be related to offsetIndex
@@ -72,12 +79,7 @@ public static void rmsnormFinalNormalizationWithParallelOffset(
         // Only the index threads need to perform this calculation
         if (gid < offsetIndex) {
             // Combine partial sums from all workgroups
-            float ss = 0.0f;
-            //for (int i = 1; i < output.getSize(); i++) { // Fixed bounds to avoid out of bounds
-            //    for (int i = 1; i < output.getSize(); i++) { // Fixed bounds to avoid out of bounds
-            //        ss += output.get(i);
-            //    }
-            ss = output.get(gid);
+            float ss = output.get(gid);
 
             ss /= size;
             ss += ermsNorm;
@@ -87,36 +89,28 @@ public static void rmsnormFinalNormalizationWithParallelOffset(
         }
     }
 
+    /**
+     * RmsNorm with parallel offset:
+     *
+     * Step 3: performs the mapIndex (normalize-and-scale) operation.
+     */
     public static void rmsnormMapIndexInPlaceWithParallelOffset(
             KernelContext context,
-            FloatArray out, // Q
+            FloatArray out,
             FloatArray weights,
             int size,
-            FloatArray ss // tempQcur1
-    ) {
+            FloatArray ss) {
 
-        int gid = context.globalIdx; // 0 - size
-        //int index = offset + gid;
+        int gid = context.globalIdx;
         int groupId = context.groupIdx;
 
         float finalss = ss.get(groupId);
-        //out.set(index, weights.get(index % size) * (finalss * x.get(index)));
-        //out.set(index, weights.get(index) * (finalss * x.get(index)));
-        //if (index < offset + size) {
+
         if (gid < out.getSize()) { // TODO: check if redundant
             float a = weights.get(gid % size);
             float b = finalss * out.get(gid);
             out.set(gid, a * b);
         }
-
-        //old gid, index:
-        //        int gid = context.globalIdx; // 0 - size
-        //        int index = offset + gid;
-        //        context.globalBarrier();
-        //        // reset ss
-        //        if (gid < ss.getSize()) {
-        //            ss.set(gid, 0.0f);
-        //        }
     }
 
     /**
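
Taken together, the three kernels above decompose one rmsnorm over an offset range: step 1 accumulates per-workgroup sums of squares, step 2 turns a group's sum into the normalization factor 1 / sqrt(ss / size + ermsNorm), and step 3 rescales the values in place by that factor and the weights. A single-threaded plain-Java reference of the same computation over one contiguous slice, useful for checking the kernels; the names and the epsilon parameter (standing in for ermsNorm) are illustrative:

    // Reference rmsnorm over x[offset .. offset + size), scaled by weights, in place.
    static void rmsnormReference(float[] x, float[] weights, int offset, int size, float epsilon) {
        // Step 1: sum of squares (what the per-workgroup reductions accumulate).
        float ss = 0.0f;
        for (int i = 0; i < size; i++) {
            float v = x[offset + i];
            ss += v * v;
        }
        // Step 2: final normalization factor.
        ss = ss / size + epsilon;
        float scale = 1.0f / (float) Math.sqrt(ss);
        // Step 3: normalize and scale by the weights (the mapIndex step).
        for (int i = 0; i < size; i++) {
            x[offset + i] = weights[i] * (scale * x[offset + i]);
        }
    }
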
@@ -162,92 +156,12 @@ public static void rmsnormWithParallelOffset(
         }
     }
 
-    public static void reductionOneBlockWithLayerWithOffset(
-            KernelContext context,
-            FloatArray output,
-            FloatArray x,
-            int offset,
-            int size,
-            float ermsNorm,
-            int localMemSize) {
-
-        int gid = context.globalIdx; // 0 - nEmbHead = 128
-        int lid = context.localIdx; // 0 - state.localsize [
-        int groupId = context.groupIdx;
-        int groupSize = context.localGroupSizeX;
-
-        // Allocate local memory with the provided size
-        float[] localX = context.allocateFloatLocalArray(localMemSize);
-
-        // Load input value and compute square
-        int globalReadIndex = gid + offset;
-        if (gid < size && globalReadIndex < x.getSize()) {
-            localX[lid] = x.get(globalReadIndex);
-            localX[lid] = localX[lid] * localX[lid];
-        } else {
-            localX[lid] = 0.0f;
-        }
-
-        // Perform parallel reduction within the work group
-        for (int stride = (groupSize / 2); stride > 0; stride /= 2) {
-            context.localBarrier();
-            if (lid < stride) {
-                localX[lid] += localX[lid + stride];
-            }
-        }
-
-        // Each workgroup stores its partial sum in a different location
-        if (lid == 0) {
-            // Store the partial sum from each workgroup
-            output.set(groupId + 1, localX[0]);
-        }
-
-        // // Only the first thread in the first workgroup computes the final normalization factor
-        // if (gid == 0) {
-        //     // Combine partial sums from all workgroups
-        //     float ss = 0.0f;
-        //     for (int i = 1; i <= (size / localMemSize); i++) { // Assuming 8 workgroups
-        //         ss += output.get(i);
-        //     }
-        //
-        //     ss /= size;
-        //     ss += ermsNorm;
-        //     ss = 1.0f / TornadoMath.sqrt(ss);
-        //     output.set(0, ss); // Store the final scale factor
-        // }
-    }
-
-    /**
-     * Normalize and scale (in-place) of rmsnorm operation.
-     */
-    public static void mapIndexInPlace(KernelContext context, FloatArray out, /*FloatArray x,*/ FloatArray weights, int offset, int size, FloatArray ss) {
-        int gid = context.globalIdx; // 0 - size
-        int index = offset + gid;
-
-        float finalss = ss.get(0);
-        //out.set(index, weights.get(index % size) * (finalss * x.get(index)));
-        //out.set(index, weights.get(index) * (finalss * x.get(index)));
-        //if (index < offset + size) {
-        if (index < out.getSize()) { // TODO: check if redundant
-            float a = weights.get(index % size);
-            float b = finalss * out.get(index);
-            out.set(index, a * b);
-        }
-
-        context.globalBarrier();
-        // reset ss
-        if (gid < ss.getSize()) {
-            ss.set(gid, 0.0f);
-        }
-    }
-
     public static void ropeRotation(KernelContext context,
             IntArray position,
             FloatArray q,
             FloatArray k,
             int numberOfKeyValueHeads,
             int nEmbdHead) {
-        //System.out.println("ropeRotationSplit");
         int h = context.globalIdx;
         int ic = context.globalIdy;
 
@@ -256,7 +170,6 @@ public static void ropeRotation(KernelContext context,
         int nComplEmbdHead = nEmbdHead / 2;
 
         // Compute RoPE frequencies for Qwen3
-        //float freq = 1.0f / TornadoMath.pow(10000.0f, (2.0f * ic) / (float) nEmbdHead);
         float theta = 1000000.0f;
         int i = ic * 2; // match i in precompute (see RoPE.precomputeFreqsCis)
         float freq = 1.0f / TornadoMath.pow(theta, (float) i / (float) nEmbdHead);
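
With theta = 1,000,000 and i = 2 * ic, the frequency above is the usual RoPE schedule theta^(-2*ic/nEmbdHead), so each complex pair of a head rotates by the angle pos * freq. The rotation body itself lies outside this hunk; the sketch below shows the standard 2D rotation applied per (head, ic) work item, without assuming how the pair is laid out inside q and k (adjacent elements vs. split halves):

    // Standard RoPE rotation of one (x0, x1) pair at token position `pos`,
    // pair index `ic`, head dimension nEmbdHead. The pair layout is left open.
    static float[] ropeRotatePair(float x0, float x1, int pos, int ic, int nEmbdHead) {
        float theta = 1_000_000.0f; // base used by the kernel above
        float freq = 1.0f / (float) Math.pow(theta, (2.0f * ic) / nEmbdHead);
        float val = pos * freq;
        float fcr = (float) Math.cos(val);
        float fci = (float) Math.sin(val);
        return new float[] {
                x0 * fcr - x1 * fci, // rotated "real" component
                x0 * fci + x1 * fcr  // rotated "imaginary" component
        };
    }
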
@@ -290,13 +203,11 @@ public static void processHeadsParallel(
             int nEmbdHeadV, /* = config.numberOfHeadsValue(), replace headSize in lines: 266, 268, 273 */
             int nEmbdGqa, /* kvDim */
             int gqa, /* kvMul */
-            int seqLen,
             IntArray positionHolder,
             FloatArray wrapAtt,
             int layer, int contextLength) {
 
         int pos = positionHolder.get(0);
-        //int loff = layer * contextLength * kvDim;
         int loff = layer * contextLength * nEmbdGqa;
 
         // Parallelize computation across attention heads
@@ -332,22 +243,16 @@ private static void processHeadTornado(
 
         // Base index for this head's attention weights
         int headOffset = h * (pos + 1);
-        //int headOffset = h * contextLength;
 
         // STEP 1: Calculate attention scores for all timesteps
         for (int t = 0; t <= pos; t++) {
-            //int kvHeadIdx = h / kvMul;
             int kvHeadIdx = h / gqa;
-            //int keyOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
             int keyOffset = (int) (loff + t * nEmbdGqa + kvHeadIdx * nEmbdHeadK); // line 255
 
             float score = 0.0f;
-            //for (int i = 0; i < headSize; i++) {
             for (int i = 0; i < nEmbdHeadK; i++) {
-                //score += allQ.get(h * headSize + i) * key_cache.get(keyOffset + i);
                 score += allQ.get(h * nEmbdHeadK + i) * key_cache.get(keyOffset + i); // line 255
             }
-            //score = score / TornadoMath.sqrt(headSize);
             score = score / TornadoMath.sqrt(nEmbdHead); // line 257
 
             // Store in attention buffer
@@ -380,28 +285,24 @@ private static void processHeadTornado(
         }
 
         // STEP 5: Compute weighted sum of values for each dimension
-        //for (int i = 0; i < headSize; i++) {
         for (int i = 0; i < nEmbdHeadV; i++) {
             float weightedSum = 0.0f;
             for (int t = 0; t <= pos; t++) {
-                //int kvHeadIdx = h / kvMul;
                 int kvHeadIdx = h / gqa;
-                //int valueOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
                 int valueOffset = (int) (loff + t * nEmbdGqa + kvHeadIdx * nEmbdHeadV); // line 273
                 weightedSum += wrapAtt.get(headOffset + t) * value_cache.get(valueOffset + i);
             }
-            //allXb.set(h * headSize + i, weightedSum);
             allXb.set(h * nEmbdHeadV + i, weightedSum); // offset from line 266
         }
     }
 
     public static void matrixVectorGenericWithResidual(
             KernelContext context,
-            FloatArray v, // vector = [2048]
-            FloatArray out, // out = [1024]
-            HalfFloatArray m, // matrix = [2048, 1024]
-            int dim1, // dim1 = 2048, vectorSize
-            int dim0, // dim0 = 1024, outputSize
+            FloatArray v,
+            FloatArray out,
+            HalfFloatArray m,
+            int dim1,
+            int dim0,
             int localWorkGroupSize) {
 
         // One row per workgroup (not per thread)
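
Per head, processHeadTornado implements scaled dot-product attention over the key/value cache, with GQA mapping query head h to key/value head h / gqa; steps 1 and 5 are visible in the hunks above, while the softmax (steps 2 to 4) falls outside them. A single-threaded reference of one head, with the cache already sliced into per-timestep rows; the array names here are illustrative:

    // One attention head: scores over t = 0..pos, softmax, weighted sum of values.
    static float[] attendOneHead(float[] q, float[][] keys, float[][] values, int pos, int headDim) {
        float[] att = new float[pos + 1];
        // STEP 1: scaled dot-product scores against each cached key.
        for (int t = 0; t <= pos; t++) {
            float score = 0.0f;
            for (int i = 0; i < headDim; i++) {
                score += q[i] * keys[t][i];
            }
            att[t] = score / (float) Math.sqrt(headDim);
        }
        // STEPS 2-4: numerically stable softmax over the scores.
        float max = Float.NEGATIVE_INFINITY;
        for (int t = 0; t <= pos; t++) {
            max = Math.max(max, att[t]);
        }
        float sum = 0.0f;
        for (int t = 0; t <= pos; t++) {
            att[t] = (float) Math.exp(att[t] - max);
            sum += att[t];
        }
        for (int t = 0; t <= pos; t++) {
            att[t] /= sum;
        }
        // STEP 5: weighted sum of the cached values.
        float[] out = new float[headDim];
        for (int i = 0; i < headDim; i++) {
            float weightedSum = 0.0f;
            for (int t = 0; t <= pos; t++) {
                weightedSum += att[t] * values[t][i];
            }
            out[i] = weightedSum;
        }
        return out;
    }
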
@@ -431,8 +332,8 @@ public static float matrixVectorRowMajorOptimized(
             int dim1,
             int dim0
     ) {
-        int rowId = context.groupIdx; // 0-dim
-        int localId = context.localIdx; // 0-32
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
 
         // Allocate local memory for reduction
         float[] localSum = context.allocateFloatLocalArray(localSize);
@@ -444,48 +345,6 @@ public static float matrixVectorRowMajorOptimized(
         for (int j = localId; j < dim1; j += localSize) {
             int matrixIdx = rowOffset + j;
             partialSum += m.get(matrixIdx).getFloat32() * v.get(j);
-            //partialSum += w.get(rowOffset + j).getFloat32() * x.get(j);
-        }
-
-        // Store partial sum in local memory
-        localSum[localId] = partialSum;
-        context.localBarrier();
-
-        // Parallel reduction within workgroup
-        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
-            if (localId < stride) {
-                localSum[localId] += localSum[localId + stride];
-            }
-            context.localBarrier();
-        }
-
-        return localSum[0];
-    }
-
-    public static float matrixVectorRowMajorOptimized2(
-            KernelContext context,
-            int localSize,
-            FloatArray v, // input vector [2048]
-            HalfFloatArray m, // matrix [2048, 1024]
-            int vectorSize, // 2048
-            int outputSize,
-            int rowId // which output row we're computing (0-1023)
-    ) {
-        int localId = context.localIdx; // 0 to localSize-1
-
-        // Allocate local memory for reduction
-        float[] localSum = context.allocateFloatLocalArray(localSize);
-
-        // For matrix [2048, 1024], if we want row 'rowId' of the OUTPUT,
-        // we need to compute dot product of INPUT vector with COLUMN 'rowId' of the matrix
-        // Matrix element [i][j] is at index i * outputSize + j
-        // We want column 'rowId', so elements are at: 0*outputSize + rowId, 1*outputSize + rowId, etc.
-
-        // Each thread calculates partial dot product
-        float partialSum = 0.0f;
-        for (int i = localId; i < vectorSize; i += localSize) {
-            int matrixIdx = i * outputSize + rowId; // Column-wise access for row rowId
-            partialSum += m.get(matrixIdx).getFloat32() * v.get(i);
         }
 
         // Store partial sum in local memory
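
The body that survives in matrixVectorRowMajorOptimized gives each work item a strided slice of one matrix row, accumulates a partial dot product, and then (in the part of the function below this hunk) reduces the partial sums through local memory. A plain-Java sketch of the same row-major dot product with the per-lane partial sums and the tree reduction made explicit; it assumes localSize is a power of two, as the kernel's halving loop does:

    // Row-major matrix-vector product for one output row, mirroring the kernel's
    // strided partial sums followed by a tree reduction of localSum.
    static float matVecRowReference(float[] m, float[] v, int row, int dim1, int localSize) {
        float[] localSum = new float[localSize];
        int rowOffset = row * dim1;
        // Each simulated lane accumulates a strided partial dot product.
        for (int lane = 0; lane < localSize; lane++) {
            float partialSum = 0.0f;
            for (int j = lane; j < dim1; j += localSize) {
                partialSum += m[rowOffset + j] * v[j];
            }
            localSum[lane] = partialSum;
        }
        // Tree reduction (the kernel separates strides with localBarrier()).
        for (int stride = localSize / 2; stride > 0; stride >>= 1) {
            for (int lane = 0; lane < stride; lane++) {
                localSum[lane] += localSum[lane + stride];
            }
        }
        return localSum[0];
    }
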
