9
9
10
10
public class Qwen3Kernels {
11
11
12
- //public static void dbgCopy(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {
13
- public static void dbgCopy (FloatArray srcBuffer , FloatArray dstBuffer , IntArray positioNlayer , int layer ) {
14
- //int position = positioNlayer.get(0);
15
- //if (position == 1) {
12
+ /**
13
+ * For explicit copy out useful in debugging.
14
+ * With this kernel we can store the values of an array to a tmp buffer at a timing of interest.
15
+ * In the end of the taskgraph we copy out the tmp buffer to inspect the array values at the timing of interest.
16
+ * @param srcBuffer the array we want to inspect.
17
+ * @param dstBuffer the tmp buffer.
18
+ */
19
+ public static void dbgCopy (FloatArray srcBuffer , FloatArray dstBuffer ) {
16
20
for (@ Parallel int i = 0 ; i < srcBuffer .getSize (); i ++) {
17
21
dstBuffer .set (i , srcBuffer .get (i ));
18
22
}
19
- //}
20
23
}
21
24
22
- public static void rmsnormReductionWithOffset (
25
+ /**
26
+ * RmsNorm with parallel offset:
27
+ * The following 3 kernels implement rmsnorm in offset range in parallel for qCur and Kcur rmsnorm calculations.
28
+ *
29
+ * Step 1: Reduction.
30
+ * This kernel implements rmsnorm in offset range in parallel for qCur and Kcur rmsnorm calculations.
31
+ */
32
+ public static void rmsnormReductionWithParallelOffset (
23
33
KernelContext context ,
24
34
FloatArray output ,
25
35
FloatArray x ,
26
36
int localMemSize ) {
27
37
28
- // global size: 0 - (config.numberOfHeads() * nEmbdHead)
29
- // local size : 0 - nEmbdHead
30
38
int gid = context .globalIdx ;
31
39
int lid = context .localIdx ;
32
40
int groupId = context .groupIdx ;
@@ -36,13 +44,8 @@ public static void rmsnormReductionWithOffset(
36
44
float [] localX = context .allocateFloatLocalArray (localMemSize );
37
45
38
46
// Load input value and compute square
39
- //int globalReadIndex = gid + offset;
40
- //if (gid < size && globalReadIndex < x.getSize()) {
41
- localX [lid ] = x .get (gid );
42
- localX [lid ] = localX [lid ] * localX [lid ];
43
- //} else {
44
- // localX[lid] = 0.0f;
45
- //}
47
+ localX [lid ] = x .get (gid );
48
+ localX [lid ] = localX [lid ] * localX [lid ];
46
49
47
50
// Perform parallel reduction within the work group
48
51
for (int stride = (groupSize / 2 ); stride > 0 ; stride /= 2 ) {
@@ -59,7 +62,11 @@ public static void rmsnormReductionWithOffset(
59
62
}
60
63
}
61
64
62
- // Second kernel - Combines partial sums and computes final normalization
65
+ /**
66
+ * RmsNorm with parallel offset:
67
+ *
68
+ * Step 2: Combines partial reduction outputs and computes final normalization.
69
+ */
63
70
public static void rmsnormFinalNormalizationWithParallelOffset (
64
71
KernelContext context ,
65
72
FloatArray output , // size should be related to offsetIndex
@@ -72,12 +79,7 @@ public static void rmsnormFinalNormalizationWithParallelOffset(
72
79
// Only the index threads need to perform this calculation
73
80
if (gid < offsetIndex ) {
74
81
// Combine partial sums from all workgroups
75
- float ss = 0.0f ;
76
- //for (int i = 1; i < output.getSize(); i++) { // Fixed bounds to avoid out of bounds
77
- // for (int i = 1; i < output.getSize(); i++) { // Fixed bounds to avoid out of bounds
78
- // ss += output.get(i);
79
- // }
80
- ss = output .get (gid );
82
+ float ss = output .get (gid );
81
83
82
84
ss /= size ;
83
85
ss += ermsNorm ;
@@ -87,36 +89,28 @@ public static void rmsnormFinalNormalizationWithParallelOffset(
87
89
}
88
90
}
89
91
92
+ /**
93
+ * RmsNorm with parallel offset:
94
+ *
95
+ * Step 3: perform mapIndex operation.
96
+ */
90
97
public static void rmsnormMapIndexInPlaceWithParallelOffset (
91
98
KernelContext context ,
92
- FloatArray out , // Q
99
+ FloatArray out ,
93
100
FloatArray weights ,
94
101
int size ,
95
- FloatArray ss // tempQcur1
96
- ) {
102
+ FloatArray ss ) {
97
103
98
- int gid = context .globalIdx ; // 0 - size
99
- //int index = offset + gid;
104
+ int gid = context .globalIdx ;
100
105
int groupId = context .groupIdx ;
101
106
102
107
float finalss = ss .get (groupId );
103
- //out.set(index, weights.get(index % size) * (finalss * x.get(index)));
104
- //out.set(index, weights.get(index) * (finalss * x.get(index)));
105
- //if (index < offset + size) {
108
+
106
109
if (gid < out .getSize ()) { // TODO: check if redundant
107
110
float a = weights .get (gid % size );
108
111
float b = finalss * out .get (gid );
109
112
out .set (gid , a * b );
110
113
}
111
-
112
- //old gid, index:
113
- // int gid = context.globalIdx; // 0 - size
114
- // int index = offset + gid;
115
- // context.globalBarrier();
116
- // // reset ss
117
- // if (gid < ss.getSize()) {
118
- // ss.set(gid, 0.0f);
119
- // }
120
114
}
121
115
122
116
/**
@@ -162,92 +156,12 @@ public static void rmsnormWithParallelOffset(
162
156
}
163
157
}
164
158
165
- public static void reductionOneBlockWithLayerWithOffset (
166
- KernelContext context ,
167
- FloatArray output ,
168
- FloatArray x ,
169
- int offset ,
170
- int size ,
171
- float ermsNorm ,
172
- int localMemSize ) {
173
-
174
- int gid = context .globalIdx ; // 0 - nEmbHead = 128
175
- int lid = context .localIdx ; // 0 - state.localsize [
176
- int groupId = context .groupIdx ;
177
- int groupSize = context .localGroupSizeX ;
178
-
179
- // Allocate local memory with the provided size
180
- float [] localX = context .allocateFloatLocalArray (localMemSize );
181
-
182
- // Load input value and compute square
183
- int globalReadIndex = gid + offset ;
184
- if (gid < size && globalReadIndex < x .getSize ()) {
185
- localX [lid ] = x .get (globalReadIndex );
186
- localX [lid ] = localX [lid ] * localX [lid ];
187
- } else {
188
- localX [lid ] = 0.0f ;
189
- }
190
-
191
- // Perform parallel reduction within the work group
192
- for (int stride = (groupSize / 2 ); stride > 0 ; stride /= 2 ) {
193
- context .localBarrier ();
194
- if (lid < stride ) {
195
- localX [lid ] += localX [lid + stride ];
196
- }
197
- }
198
-
199
- // Each workgroup stores its partial sum in a different location
200
- if (lid == 0 ) {
201
- // Store the partial sum from each workgroup
202
- output .set (groupId + 1 , localX [0 ]);
203
- }
204
-
205
- // // Only the first thread in the first workgroup computes the final normalization factor
206
- // if (gid == 0) {
207
- // // Combine partial sums from all workgroups
208
- // float ss = 0.0f;
209
- // for (int i = 1; i <= (size / localMemSize); i++) { // Assuming 8 workgroups
210
- // ss += output.get(i);
211
- // }
212
- //
213
- // ss /= size;
214
- // ss += ermsNorm;
215
- // ss = 1.0f / TornadoMath.sqrt(ss);
216
- // output.set(0, ss); // Store the final scale factor
217
- // }
218
- }
219
-
220
- /**
221
- * Normalize and scale (in-place) of rmsnorm operation.
222
- */
223
- public static void mapIndexInPlace (KernelContext context , FloatArray out , /*FloatArray x,*/ FloatArray weights , int offset , int size , FloatArray ss ) {
224
- int gid = context .globalIdx ; // 0 - size
225
- int index = offset + gid ;
226
-
227
- float finalss = ss .get (0 );
228
- //out.set(index, weights.get(index % size) * (finalss * x.get(index)));
229
- //out.set(index, weights.get(index) * (finalss * x.get(index)));
230
- //if (index < offset + size) {
231
- if (index < out .getSize ()) { // TODO: check if redundant
232
- float a = weights .get (index % size );
233
- float b = finalss * out .get (index );
234
- out .set (index , a * b );
235
- }
236
-
237
- context .globalBarrier ();
238
- // reset ss
239
- if (gid < ss .getSize ()) {
240
- ss .set (gid , 0.0f );
241
- }
242
- }
243
-
244
159
public static void ropeRotation (KernelContext context ,
245
160
IntArray position ,
246
161
FloatArray q ,
247
162
FloatArray k ,
248
163
int numberOfKeyValueHeads ,
249
164
int nEmbdHead ) {
250
- //System.out.println("ropeRotationSplit");
251
165
int h = context .globalIdx ;
252
166
int ic = context .globalIdy ;
253
167
@@ -256,7 +170,6 @@ public static void ropeRotation(KernelContext context,
256
170
int nComplEmbdHead = nEmbdHead / 2 ;
257
171
258
172
// Compute RoPE frequencies for Qwen3
259
- //float freq = 1.0f / TornadoMath.pow(10000.0f, (2.0f * ic) / (float) nEmbdHead);
260
173
float theta = 1000000.0f ;
261
174
int i = ic * 2 ; // match i in precompute (see RoPE.precomputeFreqsCis)
262
175
float freq = 1.0f / TornadoMath .pow (theta , (float )i / (float )nEmbdHead );
@@ -290,13 +203,11 @@ public static void processHeadsParallel(
290
203
int nEmbdHeadV , /* = config.numberOfHeadsValue(), replace headSize in lines: 266, 268, 273 */
291
204
int nEmbdGqa , /* kvDim */
292
205
int gqa , /* kvMul */
293
- int seqLen ,
294
206
IntArray positionHolder ,
295
207
FloatArray wrapAtt ,
296
208
int layer , int contextLength ) {
297
209
298
210
int pos = positionHolder .get (0 );
299
- //int loff = layer * contextLength * kvDim;
300
211
int loff = layer * contextLength * nEmbdGqa ;
301
212
302
213
// Parallelize computation across attention heads
@@ -332,22 +243,16 @@ private static void processHeadTornado(
332
243
333
244
// Base index for this head's attention weights
334
245
int headOffset = h * (pos + 1 );
335
- //int headOffset = h * contextLength;
336
246
337
247
// STEP 1: Calculate attention scores for all timesteps
338
248
for (int t = 0 ; t <= pos ; t ++) {
339
- //int kvHeadIdx = h / kvMul;
340
249
int kvHeadIdx = h / gqa ;
341
- //int keyOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
342
250
int keyOffset = (int ) (loff + t * nEmbdGqa + kvHeadIdx * nEmbdHeadK ); // line 255
343
251
344
252
float score = 0.0f ;
345
- //for (int i = 0; i < headSize; i++) {
346
253
for (int i = 0 ; i < nEmbdHeadK ; i ++) {
347
- //score += allQ.get(h * headSize + i) * key_cache.get(keyOffset + i);
348
254
score += allQ .get (h * nEmbdHeadK + i ) * key_cache .get (keyOffset + i ); // line 255
349
255
}
350
- //score = score / TornadoMath.sqrt(headSize);
351
256
score = score / TornadoMath .sqrt (nEmbdHead ); // line 257
352
257
353
258
// Store in attention buffer
@@ -380,28 +285,24 @@ private static void processHeadTornado(
380
285
}
381
286
382
287
// STEP 5: Compute weighted sum of values for each dimension
383
- //for (int i = 0; i < headSize; i++) {
384
288
for (int i = 0 ; i < nEmbdHeadV ; i ++) {
385
289
float weightedSum = 0.0f ;
386
290
for (int t = 0 ; t <= pos ; t ++) {
387
- //int kvHeadIdx = h / kvMul;
388
291
int kvHeadIdx = h / gqa ;
389
- //int valueOffset = (int) (loff + t * kvDim + kvHeadIdx * headSize);
390
292
int valueOffset = (int ) (loff + t * nEmbdGqa + kvHeadIdx * nEmbdHeadV ); //line 273
391
293
weightedSum += wrapAtt .get (headOffset + t ) * value_cache .get (valueOffset + i );
392
294
}
393
- //allXb.set(h * headSize + i, weightedSum);
394
295
allXb .set (h * nEmbdHeadV + i , weightedSum ); // offset from line 266
395
296
}
396
297
}
397
298
398
299
public static void matrixVectorGenericWithResidual (
399
300
KernelContext context ,
400
- FloatArray v , // vector = [2048]
401
- FloatArray out , // out = [1024]
402
- HalfFloatArray m , // matrix = [2048, 1024]
403
- int dim1 , // dim1 = 2048, vectorSize
404
- int dim0 , // dim0 = 1024, outputSize
301
+ FloatArray v ,
302
+ FloatArray out ,
303
+ HalfFloatArray m ,
304
+ int dim1 ,
305
+ int dim0 ,
405
306
int localWorkGroupSize ) {
406
307
407
308
// One row per workgroup (not per thread)
@@ -431,8 +332,8 @@ public static float matrixVectorRowMajorOptimized(
431
332
int dim1 ,
432
333
int dim0
433
334
) {
434
- int rowId = context .groupIdx ; // 0-dim
435
- int localId = context .localIdx ; // 0-32
335
+ int rowId = context .groupIdx ;
336
+ int localId = context .localIdx ;
436
337
437
338
// Allocate local memory for reduction
438
339
float [] localSum = context .allocateFloatLocalArray (localSize );
@@ -444,48 +345,6 @@ public static float matrixVectorRowMajorOptimized(
444
345
for (int j = localId ; j < dim1 ; j += localSize ) {
445
346
int matrixIdx = rowOffset + j ;
446
347
partialSum += m .get (matrixIdx ).getFloat32 () * v .get (j );
447
- //partialSum += w.get(rowOffset + j).getFloat32() * x.get(j);
448
- }
449
-
450
- // Store partial sum in local memory
451
- localSum [localId ] = partialSum ;
452
- context .localBarrier ();
453
-
454
- // Parallel reduction within workgroup
455
- for (int stride = localSize / 2 ; stride > 0 ; stride >>= 1 ) {
456
- if (localId < stride ) {
457
- localSum [localId ] += localSum [localId + stride ];
458
- }
459
- context .localBarrier ();
460
- }
461
-
462
- return localSum [0 ];
463
- }
464
-
465
- public static float matrixVectorRowMajorOptimized2 (
466
- KernelContext context ,
467
- int localSize ,
468
- FloatArray v , // input vector [2048]
469
- HalfFloatArray m , // matrix [2048, 1024]
470
- int vectorSize , // 2048
471
- int outputSize ,
472
- int rowId // which output row we're computing (0-1023)
473
- ) {
474
- int localId = context .localIdx ; // 0 to localSize-1
475
-
476
- // Allocate local memory for reduction
477
- float [] localSum = context .allocateFloatLocalArray (localSize );
478
-
479
- // For matrix [2048, 1024], if we want row 'rowId' of the OUTPUT,
480
- // we need to compute dot product of INPUT vector with COLUMN 'rowId' of the matrix
481
- // Matrix element [i][j] is at index i * outputSize + j
482
- // We want column 'rowId', so elements are at: 0*outputSize + rowId, 1*outputSize + rowId, etc.
483
-
484
- // Each thread calculates partial dot product
485
- float partialSum = 0.0f ;
486
- for (int i = localId ; i < vectorSize ; i += localSize ) {
487
- int matrixIdx = i * outputSize + rowId ; // Column-wise access for row rowId
488
- partialSum += m .get (matrixIdx ).getFloat32 () * v .get (i );
489
348
}
490
349
491
350
// Store partial sum in local memory
0 commit comments