
Commit a8aeaf8

Improve attention performance for qwen2.5 & deepseek
1 parent 1403b4f commit a8aeaf8

File tree

2 files changed: +146 -6 lines changed
src/main/java/org/beehive/gpullama3/tornadovm/Qwen2Kernels.java (new file)

Lines changed: 131 additions & 0 deletions

@@ -0,0 +1,131 @@
package org.beehive.gpullama3.tornadovm;

import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

public class Qwen2Kernels {

    public static void processHeadsFlashAttention(KernelContext context, FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul,
            IntArray positionHolder, int layer, int contextLength) {

        // Thread and workgroup information
        int globalTid = context.globalIdx;
        int localTid = context.localIdx;
        int localSize = context.localGroupSizeX;
        int workgroupId = context.groupIdx;

        // Calculate which head this workgroup processes
        int h = workgroupId;

        // Early exit if beyond head count
        if (h >= nHeads) {
            return;
        }

        int pos = positionHolder.get(0);
        int loff = layer * contextLength * kvDim;
        int kvHeadIdx = h / kvMul;
        int BLOCK_SIZE_C = 8;

        // Allocate shared memory for tiled computation
        float[] q_shared = context.allocateFloatLocalArray(headSize);
        float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
        float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);
        float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C);
        float[] shared_max = context.allocateFloatLocalArray(1);

        // Per-thread output accumulation
        float[] output = new float[headSize];
        for (int i = 0; i < headSize; i++) {
            output[i] = 0.0f;
        }

        // Thread-local accumulators for online softmax
        float maxScore = Float.NEGATIVE_INFINITY;
        float sumExp = 0.0f;

        // Cooperatively load query vector into shared memory
        for (int i = localTid; i < headSize; i += localSize) {
            q_shared[i] = q.get(h * headSize + i);
        }
        context.localBarrier();

        // Process sequence in tiles
        for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) {
            int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos);

            // Cooperatively load key and value vectors for this tile
            for (int tIdxInSeq = tileC + localTid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
                int k_v_idx_in_tile = tIdxInSeq - tileC;
                int tileMemOffset = k_v_idx_in_tile * headSize;

                for (int d = 0; d < headSize; d++) {
                    int kvCacheAbsolutePos = tIdxInSeq;
                    int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + d;
                    k_tile[tileMemOffset + d] = key_cache.get(kvOffset);
                    v_tile[tileMemOffset + d] = value_cache.get(kvOffset);
                }
            }
            context.localBarrier();

            // Cooperatively compute attention scores for this tile
            for (int tIdxInSeq = tileC + localTid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
                int score_idx_in_tile = tIdxInSeq - tileC;

                float score = 0.0f;
                for (int d = 0; d < headSize; d++) {
                    score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d];
                }
                score /= TornadoMath.sqrt(headSize);
                s_tile[score_idx_in_tile] = score;
            }
            context.localBarrier();

            // Find max score in this tile using reduction
            float tileLocalMax = Float.NEGATIVE_INFINITY;
            for (int i = 0; i <= tileEnd - tileC; i++) {
                if (s_tile[i] > tileLocalMax) {
                    tileLocalMax = s_tile[i];
                }
            }

            // Thread 0 broadcasts the max
            if (localTid == 0) {
                shared_max[0] = tileLocalMax;
            }
            context.localBarrier();
            float currentTileMax = shared_max[0];

            // Update global max and rescale if needed
            float newMax = Math.max(maxScore, currentTileMax);
            if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
                float scale = TornadoMath.exp(maxScore - newMax);
                sumExp *= scale;
                for (int d = 0; d < headSize; d++) {
                    output[d] *= scale;
                }
            }
            maxScore = newMax;

            // Process each key-value pair in the tile
            for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; t_idx_in_s_tile++) {
                float expScore = TornadoMath.exp(s_tile[t_idx_in_s_tile] - maxScore);
                sumExp += expScore;

                // Accumulate weighted values
                for (int d = 0; d < headSize; d++) {
                    output[d] += expScore * v_tile[t_idx_in_s_tile * headSize + d];
                }
            }
            context.localBarrier();
        }

        // Normalize and cooperatively write final results
        float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f;
        for (int d = localTid; d < headSize; d += localSize) {
            xb.set(h * headSize + d, output[d] * normFactor);
        }
    }
}
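
For reference, the kernel above implements per-head flash attention with an online softmax: each workgroup handles one head, keys and values are loaded in tiles of BLOCK_SIZE_C positions, a running maximum and exp-sum are rescaled whenever a new tile raises the maximum, and the accumulated value sum is normalized once at the end. A minimal single-threaded CPU sketch of the same math, useful only as a reference when validating the kernel (hypothetical helper, plain float[] arrays in place of TornadoVM's FloatArray, same loff/kvDim/kvMul cache layout):

// Hypothetical CPU reference for one head; mirrors the online-softmax accumulation
// in Qwen2Kernels.processHeadsFlashAttention. Not part of this commit.
static void referenceHeadAttention(float[] q, float[] keyCache, float[] valueCache, float[] xb,
        int h, int pos, int headSize, int kvDim, int kvMul, int loff) {
    int kvHeadIdx = h / kvMul;
    float maxScore = Float.NEGATIVE_INFINITY;
    float sumExp = 0.0f;
    float[] output = new float[headSize];

    for (int t = 0; t <= pos; t++) {
        // Attention score: dot(q_h, k_t) / sqrt(headSize)
        int kvOffset = loff + t * kvDim + kvHeadIdx * headSize;
        float score = 0.0f;
        for (int d = 0; d < headSize; d++) {
            score += q[h * headSize + d] * keyCache[kvOffset + d];
        }
        score /= (float) Math.sqrt(headSize);

        // Online softmax: when the running max grows, rescale what was accumulated so far
        float newMax = Math.max(maxScore, score);
        if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
            float scale = (float) Math.exp(maxScore - newMax);
            sumExp *= scale;
            for (int d = 0; d < headSize; d++) {
                output[d] *= scale;
            }
        }
        maxScore = newMax;

        // Accumulate exp-weighted value vector
        float expScore = (float) Math.exp(score - maxScore);
        sumExp += expScore;
        for (int d = 0; d < headSize; d++) {
            output[d] += expScore * valueCache[kvOffset + d];
        }
    }

    // Final normalization, matching the kernel's single division at the end
    float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f;
    for (int d = 0; d < headSize; d++) {
        xb[h * headSize + d] = output[d] * normFactor;
    }
}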

src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java

Lines changed: 15 additions & 6 deletions
@@ -83,7 +83,7 @@ public Tuple2<List<ImmutableTaskGraph>, GridScheduler> setupTornadoForwardPlanLa
                 config.headSize())
             .task("copyToCaches", TransformerComputeKernelsLayered::copyToCache,
                 state.wrapKeyCache, state.wrapK, state.wrapValueCache, state.wrapV, state.positionHolder, config.kvDim(), layerIndex, config.contextLength())
-            .task("parallel-attention", TransformerComputeKernelsLayered::processHeadsFlashAttention, context,
+            .task("parallel-attention", Qwen2Kernels::processHeadsFlashAttention, context,
                 state.wrapQ, state.wrapKeyCache, state.wrapValueCache, state.wrapXb,
                 config.numberOfHeads(), config.headSize(), config.kvDim(), config.kvMul(),
                 state.positionHolder, layerIndex, config.contextLength())
@@ -190,12 +190,21 @@ private GridScheduler setupQwen2GridSchedulersLayeredNonNvidia() {
         rmsNormWorker.setLocalWork(32, 1, 1); // Set local work size to 32

         // Parallel attention worker configuration
-        // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.numberOfHeads,1,1], localWorkSize=[4,1,1])
-        // CUDA equivalent: kernel<<<dim3((config.numberOfHeads+3)/4,1,1), dim3(4,1,1)>>>
+        // Calculate optimal local work size based on head dimension
+        int optimalLocalSize = Math.min(config.headSize(), 64); // Start with 64 threads per head
+        if (config.headSize() % optimalLocalSize != 0) {
+            // Find largest divisor of headSize <= 64
+            for (int size = 64; size >= 1; size--) {
+                if (config.headSize() % size == 0) {
+                    optimalLocalSize = size;
+                    break;
+                }
+            }
+        }
+
         WorkerGrid parallelAttentionWorker = new WorkerGrid1D(config.numberOfHeads());
-        // the global group work size is numberOfHeads * localWorkGroupSize, where the localWorkGroupSize is currently 4
-        parallelAttentionWorker.setGlobalWork(config.numberOfHeads(), 1, 1);
-        parallelAttentionWorker.setLocalWork(1, 1, 1); // Set local work size to 4 (for parallel attention)
+        parallelAttentionWorker.setGlobalWork(config.numberOfHeads() * optimalLocalSize, 1, 1);
+        parallelAttentionWorker.setLocalWork(optimalLocalSize, 1, 1);

         // Copy to caches worker configuration
         // OpenCL equivalent: clEnqueueNDRangeKernel(globalWorkSize=[config.dim,1,1], localWorkSize=[128,1,1])
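
The scheduler change above launches one workgroup per attention head: the local work size becomes the largest divisor of headSize that does not exceed 64 (so the kernel's cooperative loops over headSize divide evenly), and the global work size is numberOfHeads * optimalLocalSize. A small self-contained sketch of that selection (hypothetical helper name, not part of the commit):

// Pick the largest local work size <= 64 that divides headSize evenly,
// mirroring the divisor search in setupQwen2GridSchedulersLayeredNonNvidia.
static int pickLocalWorkSize(int headSize) {
    int optimalLocalSize = Math.min(headSize, 64);
    if (headSize % optimalLocalSize != 0) {
        for (int size = 64; size >= 1; size--) {
            if (headSize % size == 0) {
                optimalLocalSize = size;
                break;
            }
        }
    }
    return optimalLocalSize;
}

// Examples: headSize = 64 or 128 gives 64; headSize = 80 gives 40 (its largest divisor <= 64).
// With localWork = optimalLocalSize and globalWork = numberOfHeads * optimalLocalSize,
// exactly one workgroup is dispatched per head.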
