@@ -34,10 +34,10 @@ public TransformerComputeKernelsLayered() {
      * @param localMemSize Size of local memory allocation (must match work group size)
      */
     public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
-        int gid = context.globalIdx;       // 0-1024
-        int lid = context.localIdx;        // 0-256
-        int groupId = context.groupIdx;    // 0-4
-        int groupSize = context.localGroupSizeX; // 256
+        int gid = context.globalIdx;
+        int lid = context.localIdx;
+        int groupId = context.groupIdx;
+        int groupSize = context.localGroupSizeX;
 
         // Allocate local memory with the provided size
         float[] localX = context.allocateFloatLocalArray(localMemSize);
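Note: the hunk above only touches the index setup and local-memory allocation. For context, this is the local-memory tree reduction such a kernel typically performs next; a minimal sketch assuming TornadoVM's localBarrier() and a power-of-two work group, illustrative rather than the file's exact body:

    localX[lid] = (gid < size) ? x.get(gid) * x.get(gid) : 0.0f; // per-thread partial sum of squares
    context.localBarrier();
    for (int stride = groupSize / 2; stride > 0; stride >>= 1) { // halve the active threads each step
        if (lid < stride) {
            localX[lid] += localX[lid + stride];
        }
        context.localBarrier();
    }
    if (lid == 0) {
        output.set(groupId, localX[0]); // one partial result per work group
    }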
@@ -115,8 +115,7 @@ public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray
      * @param layer Current transformer layer index
      * @param contextLength Maximum sequence length
      */
-    public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue,
-            IntArray positioNlayer, int kvDim, int layer, int contextLength) {
+    public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {
 
         int position = positioNlayer.get(0);
         int loff = layer * contextLength * kvDim;
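The loff arithmetic flattens the KV cache as [layer][position][kvDim]. A hedged sketch of the copy loop that would follow from these offsets (illustrative, not the file's exact body):

    int destOffset = loff + position * kvDim; // start of this layer/position slot
    for (int i = 0; i < kvDim; i++) {
        destKeyCache.set(destOffset + i, srcKey.get(i));
        destValueCache.set(destOffset + i, srcValue.get(i));
    }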
@@ -195,14 +194,8 @@ public static void ropeRotation(KernelContext context, IntArray positionHolder,
      * @param layer Current transformer layer
      * @param contextLength Maximum context length
      */
-    public static void processHeadsParallel(
-            FloatArray q,
-            FloatArray key_cache,
-            FloatArray value_cache,
-            FloatArray xb,
-            int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
-            IntArray positionHolder,
-            FloatArray wrapAtt, int layer, int contextLength) {
+    public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
+            IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {
 
         int pos = positionHolder.get(0);
         int loff = layer * contextLength * kvDim;
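For orientation, a sketch of how a single head h would address the caches given pos and loff, assuming grouped-query attention where kvMul query heads share one key/value head. The head index, score loop, and wrapAtt layout here are assumptions for illustration, not the file's code:

    int h = context.globalIdx;                 // illustrative: one work-item per head
    int qOffset = h * headSize;                // this head's slice of q
    int kvHeadOffset = (h / kvMul) * headSize; // shared KV head for this query head
    for (int t = 0; t <= pos; t++) {
        int kOffset = loff + t * kvDim + kvHeadOffset;
        float score = 0.0f;
        for (int i = 0; i < headSize; i++) {
            score += q.get(qOffset + i) * key_cache.get(kOffset + i);
        }
        wrapAtt.set(h * seqLen + t, score / (float) Math.sqrt(headSize)); // assumed per-head layout
    }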
@@ -663,8 +656,7 @@ public static void matrixVectorGeneric(
      * @param d Output dimension
      * @param localWorkGroupSize Work group size
      */
-    public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w,
-            int n, int d, int localWorkGroupSize) {
+    public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) {
         // One row per workgroup (not per thread)
         int rowId = context.groupIdx;
         int localId = context.localIdx;
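The "one row per workgroup" comment means each work group computes one output row cooperatively. A hedged sketch of that pattern with the residual add, assuming HalfFloat.getFloat32() for the FP16 weights and that the residual is the value already stored in hb (both assumptions; illustrative only):

    float[] localSum = context.allocateFloatLocalArray(localWorkGroupSize);
    float partial = 0.0f;
    for (int col = localId; col < n; col += localWorkGroupSize) { // threads stride across the row
        partial += w.get(rowId * n + col).getFloat32() * x.get(col);
    }
    localSum[localId] = partial;
    context.localBarrier();
    for (int stride = localWorkGroupSize / 2; stride > 0; stride >>= 1) { // tree reduction in local memory
        if (localId < stride) {
            localSum[localId] += localSum[localId + stride];
        }
        context.localBarrier();
    }
    if (localId == 0) {
        hb.set(rowId, hb.get(rowId) + localSum[0]); // dot product plus residual
    }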
@@ -794,8 +786,8 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc
     }
 
     public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, HalfFloatArray w, int n) {
-        int rowId = context.groupIdx;   // 0-dim
-        int localId = context.localIdx; // 0-32
+        int rowId = context.groupIdx;
+        int localId = context.localIdx;
 
         // Allocate local memory for reduction
         float[] localSum = context.allocateFloatLocalArray(localSize);