@@ -296,70 +296,4 @@ private static void processHeadTornado(
296
296
}
297
297
}
298
298
299
- public static void matrixVectorGenericWithResidual (
300
- KernelContext context ,
301
- FloatArray v ,
302
- FloatArray out ,
303
- HalfFloatArray m ,
304
- int dim1 ,
305
- int dim0 ,
306
- int localWorkGroupSize ) {
307
-
308
- // One row per workgroup (not per thread)
309
- int rowId = context .groupIdx ;
310
- int localId = context .localIdx ;
311
- int localSize = localWorkGroupSize ;
312
-
313
- // Early exit if this workgroup is beyond our output dimension
314
- if (rowId >= dim0 ) {
315
- return ;
316
- }
317
-
318
- float sum = matrixVectorRowMajorOptimized (context , localSize , v , m , dim1 , dim0 );
319
-
320
- // Thread 0 in each workgroup writes the final result
321
- if (localId == 0 ) {
322
- float result = out .get (rowId ) + sum ;
323
- out .set (rowId , result );
324
- }
325
- }
326
-
327
- public static float matrixVectorRowMajorOptimized (
328
- KernelContext context ,
329
- int localSize ,
330
- FloatArray v ,
331
- HalfFloatArray m ,
332
- int dim1 ,
333
- int dim0
334
- ) {
335
- int rowId = context .groupIdx ;
336
- int localId = context .localIdx ;
337
-
338
- // Allocate local memory for reduction
339
- float [] localSum = context .allocateFloatLocalArray (localSize );
340
-
341
- int rowOffset = rowId * dim1 ;
342
-
343
- // Each thread calculates partial dot product
344
- float partialSum = 0.0f ;
345
- for (int j = localId ; j < dim1 ; j += localSize ) {
346
- int matrixIdx = rowOffset + j ;
347
- partialSum += m .get (matrixIdx ).getFloat32 () * v .get (j );
348
- }
349
-
350
- // Store partial sum in local memory
351
- localSum [localId ] = partialSum ;
352
- context .localBarrier ();
353
-
354
- // Parallel reduction within workgroup
355
- for (int stride = localSize / 2 ; stride > 0 ; stride >>= 1 ) {
356
- if (localId < stride ) {
357
- localSum [localId ] += localSum [localId + stride ];
358
- }
359
- context .localBarrier ();
360
- }
361
-
362
- return localSum [0 ];
363
- }
364
-
365
299
}
0 commit comments