Skip to content

Commit 7dc5056

Browse files
General cleanup
1 parent 5b5e9e3 commit 7dc5056

File tree

1 file changed

+10
-18
lines changed

1 file changed

+10
-18
lines changed

src/main/java/com/example/tornadovm/TransformerComputeKernelsLayered.java

Lines changed: 10 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -34,10 +34,10 @@ public TransformerComputeKernelsLayered() {
3434
* @param localMemSize Size of local memory allocation (must match work group size)
3535
*/
3636
public static void reductionOneBlockWithLayer(KernelContext context, FloatArray output, FloatArray x, int size, float ermsNorm, int localMemSize) {
37-
int gid = context.globalIdx; // 0-1024
38-
int lid = context.localIdx; // 0-256
39-
int groupId = context.groupIdx; // 0-4
40-
int groupSize = context.localGroupSizeX; // 256
37+
int gid = context.globalIdx;
38+
int lid = context.localIdx;
39+
int groupId = context.groupIdx;
40+
int groupSize = context.localGroupSizeX;
4141

4242
// Allocate local memory with the provided size
4343
float[] localX = context.allocateFloatLocalArray(localMemSize);
@@ -115,8 +115,7 @@ public static void reductionOneBlock2WithLayer(KernelContext context, FloatArray
115115
* @param layer Current transformer layer index
116116
* @param contextLength Maximum sequence length
117117
*/
118-
public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue,
119-
IntArray positioNlayer, int kvDim, int layer, int contextLength) {
118+
public static void copyToCache(FloatArray destKeyCache, FloatArray srcKey, FloatArray destValueCache, FloatArray srcValue, IntArray positioNlayer, int kvDim, int layer, int contextLength) {
120119

121120
int position = positioNlayer.get(0);
122121
int loff = layer * contextLength * kvDim;
@@ -195,14 +194,8 @@ public static void ropeRotation(KernelContext context, IntArray positionHolder,
195194
* @param layer Current transformer layer
196195
* @param contextLength Maximum context length
197196
*/
198-
public static void processHeadsParallel(
199-
FloatArray q,
200-
FloatArray key_cache,
201-
FloatArray value_cache,
202-
FloatArray xb,
203-
int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
204-
IntArray positionHolder,
205-
FloatArray wrapAtt, int layer, int contextLength) {
197+
public static void processHeadsParallel(FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul, int seqLen,
198+
IntArray positionHolder, FloatArray wrapAtt, int layer, int contextLength) {
206199

207200
int pos = positionHolder.get(0);
208201
int loff = layer * contextLength * kvDim;
@@ -663,8 +656,7 @@ public static void matrixVectorGeneric(
663656
* @param d Output dimension
664657
* @param localWorkGroupSize Work group size
665658
*/
666-
public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w,
667-
int n, int d, int localWorkGroupSize) {
659+
public static void matrixVectorGenericWithResidual(KernelContext context, FloatArray x, FloatArray hb, HalfFloatArray w, int n, int d, int localWorkGroupSize) {
668660
// One row per workgroup (not per thread)
669661
int rowId = context.groupIdx;
670662
int localId = context.localIdx;
@@ -794,8 +786,8 @@ public static float matrixVectorRowMajorOptimized(KernelContext context, int loc
794786
}
795787

796788
public static float matrixVectorRowMajorOptimized(KernelContext context, int localSize, FloatArray x, HalfFloatArray w, int n) {
797-
int rowId = context.groupIdx; // 0-dim
798-
int localId = context.localIdx; // 0-32
789+
int rowId = context.groupIdx;
790+
int localId = context.localIdx;
799791

800792
// Allocate local memory for reduction
801793
float[] localSum = context.allocateFloatLocalArray(localSize);

0 commit comments

Comments
 (0)