Skip to content

Commit dabbdfb

Browse files
Remove duplicative kernel
1 parent 5224381 commit dabbdfb

File tree

2 files changed

+1
-67
lines changed

2 files changed

+1
-67
lines changed

src/main/java/com/example/tornadovm/Qwen3Kernels.java

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -296,70 +296,4 @@ private static void processHeadTornado(
296 296
}
297 297
}
298 298

299-
public static void matrixVectorGenericWithResidual(
300-
KernelContext context,
301-
FloatArray v,
302-
FloatArray out,
303-
HalfFloatArray m,
304-
int dim1,
305-
int dim0,
306-
int localWorkGroupSize) {
307-
308-
// One row per workgroup (not per thread)
309-
int rowId = context.groupIdx;
310-
int localId = context.localIdx;
311-
int localSize = localWorkGroupSize;
312-
313-
// Early exit if this workgroup is beyond our output dimension
314-
if (rowId >= dim0) {
315-
return;
316-
}
317-
318-
float sum = matrixVectorRowMajorOptimized(context, localSize, v, m, dim1, dim0);
319-
320-
// Thread 0 in each workgroup writes the final result
321-
if (localId == 0) {
322-
float result = out.get(rowId) + sum;
323-
out.set(rowId, result);
324-
}
325-
}
326-
327-
public static float matrixVectorRowMajorOptimized(
328-
KernelContext context,
329-
int localSize,
330-
FloatArray v,
331-
HalfFloatArray m,
332-
int dim1,
333-
int dim0
334-
) {
335-
int rowId = context.groupIdx;
336-
int localId = context.localIdx;
337-
338-
// Allocate local memory for reduction
339-
float[] localSum = context.allocateFloatLocalArray(localSize);
340-
341-
int rowOffset = rowId * dim1;
342-
343-
// Each thread calculates partial dot product
344-
float partialSum = 0.0f;
345-
for (int j = localId; j < dim1; j += localSize) {
346-
int matrixIdx = rowOffset + j;
347-
partialSum += m.get(matrixIdx).getFloat32() * v.get(j);
348-
}
349-
350-
// Store partial sum in local memory
351-
localSum[localId] = partialSum;
352-
context.localBarrier();
353-
354-
// Parallel reduction within workgroup
355-
for (int stride = localSize / 2; stride > 0; stride >>= 1) {
356-
if (localId < stride) {
357-
localSum[localId] += localSum[localId + stride];
358-
}
359-
context.localBarrier();
360-
}
361-
362-
return localSum[0];
363-
}
364-
365 299
}

src/main/java/com/example/tornadovm/Qwen3TornadoVMLayerPlanner.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
221 221
layerIndex,
222 222
config.contextLength());
223 223

224-
unifiedLayer.task("matmul1", Qwen3Kernels::matrixVectorGenericWithResidual,
224+
unifiedLayer.task("matmul1", TransformerComputeKernelsLayered::matrixVectorGenericWithResidual,
225 225
context,
226 226
state.wrapXb, // vector
227 227
state.wrapX, // out, should be [1024]

0 commit comments

Comments
 (0)