package org.beehive.gpullama3.tornadovm;

import uk.ac.manchester.tornado.api.KernelContext;
import uk.ac.manchester.tornado.api.math.TornadoMath;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

public class Qwen2Kernels {

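    /**
     * Computes scaled dot-product attention for every query head at the current position,
     * using a flash-attention style pass over the KV cache: keys and values are streamed
     * through workgroup-local memory in tiles of {@code BLOCK_SIZE_C} positions and folded
     * into an online (running) softmax, so the full attention matrix is never materialized.
     * Each workgroup handles one query head; the result for head {@code h} is written to
     * {@code xb[h * headSize .. (h + 1) * headSize - 1]}.
     */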
    public static void processHeadsFlashAttention(KernelContext context, FloatArray q, FloatArray key_cache, FloatArray value_cache, FloatArray xb, int nHeads, int headSize, int kvDim, int kvMul,
            IntArray positionHolder, int layer, int contextLength) {

        // Thread and workgroup information (globalTid is not used below; the kernel is
        // indexed by workgroup id and local thread id)
        int globalTid = context.globalIdx;
        int localTid = context.localIdx;
        int localSize = context.localGroupSizeX;
        int workgroupId = context.groupIdx;

        // Calculate which head this workgroup processes
        int h = workgroupId;

        // Early exit if beyond head count
        if (h >= nHeads) {
            return;
        }

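        // Grouped-query attention: kvMul query heads share one key/value head, so query
        // head h reads KV head h / kvMul. The cache layout implied by the index math below
        // is [layer][position][kvHead][headDim], with kvDim floats per position and loff
        // selecting this layer's slab. BLOCK_SIZE_C is the tile width, i.e. how many
        // cached positions are staged in local memory per iteration of the main loop.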
        int pos = positionHolder.get(0);
        int loff = layer * contextLength * kvDim;
        int kvHeadIdx = h / kvMul;
        int BLOCK_SIZE_C = 8;

        // Allocate workgroup-local (shared) memory for the tiled computation
        float[] q_shared = context.allocateFloatLocalArray(headSize);                  // query vector of this head
        float[] k_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);     // keys staged for the current tile
        float[] v_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C * headSize);     // values staged for the current tile
        float[] s_tile = context.allocateFloatLocalArray(BLOCK_SIZE_C);                // raw attention scores for the tile
        float[] shared_max = context.allocateFloatLocalArray(1);                       // broadcast slot for the tile maximum

        // Per-thread accumulator for the (unnormalized) weighted sum of value vectors
        float[] output = new float[headSize];
        for (int i = 0; i < headSize; i++) {
            output[i] = 0.0f;
        }

        // Thread-local accumulators for online softmax
        float maxScore = Float.NEGATIVE_INFINITY;
        float sumExp = 0.0f;
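        // Invariant maintained across tiles: maxScore is the largest score seen so far,
        // sumExp is the sum of exp(score - maxScore) over those scores, and output[] holds
        // the corresponding exp-weighted sum of value vectors. Whenever maxScore grows,
        // sumExp and output[] are rescaled so every term stays relative to the new maximum.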

        // Cooperatively load query vector into shared memory
        for (int i = localTid; i < headSize; i += localSize) {
            q_shared[i] = q.get(h * headSize + i);
        }
        context.localBarrier();

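        // Main flash-attention loop: walk the causal range [0, pos] in tiles of
        // BLOCK_SIZE_C positions. Each iteration (1) stages the tile's keys and values in
        // local memory, (2) computes the tile's attention scores, and (3) folds the tile
        // into the running softmax state above.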
        // Process sequence in tiles
        for (int tileC = 0; tileC <= pos; tileC += BLOCK_SIZE_C) {
            int tileEnd = Math.min(tileC + BLOCK_SIZE_C - 1, pos);

            // Cooperatively load key and value vectors for this tile
            for (int tIdxInSeq = tileC + localTid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
                int k_v_idx_in_tile = tIdxInSeq - tileC;
                int tileMemOffset = k_v_idx_in_tile * headSize;

                for (int d = 0; d < headSize; d++) {
                    int kvCacheAbsolutePos = tIdxInSeq;
                    int kvOffset = loff + kvCacheAbsolutePos * kvDim + kvHeadIdx * headSize + d;
                    k_tile[tileMemOffset + d] = key_cache.get(kvOffset);
                    v_tile[tileMemOffset + d] = value_cache.get(kvOffset);
                }
            }
            context.localBarrier();


            // Cooperatively compute attention scores for this tile (one position per thread per pass)
            for (int tIdxInSeq = tileC + localTid; tIdxInSeq <= tileEnd; tIdxInSeq += localSize) {
                int score_idx_in_tile = tIdxInSeq - tileC;

                float score = 0.0f;
                for (int d = 0; d < headSize; d++) {
                    score += q_shared[d] * k_tile[score_idx_in_tile * headSize + d];
                }
                score /= TornadoMath.sqrt(headSize); // scaled dot product
                s_tile[score_idx_in_tile] = score;
            }
            context.localBarrier();

            // Each thread scans the whole tile for the maximum score (the tile holds at
            // most BLOCK_SIZE_C scores, so a serial scan is cheap and no parallel
            // reduction is needed)
            float tileLocalMax = Float.NEGATIVE_INFINITY;
            for (int i = 0; i <= tileEnd - tileC; i++) {
                if (s_tile[i] > tileLocalMax) {
                    tileLocalMax = s_tile[i];
                }
            }

            // Thread 0 publishes the tile maximum in local memory so every thread works
            // from the same value
            if (localTid == 0) {
                shared_max[0] = tileLocalMax;
            }
            context.localBarrier();
            float currentTileMax = shared_max[0];

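            // Rescale step of the online softmax: accumulated terms were computed as
            // exp(score - maxScore). If the running maximum rises to newMax, each stored
            // term must become exp(score - newMax) = exp(score - maxScore) * exp(maxScore - newMax),
            // so sumExp and output[] are multiplied by exp(maxScore - newMax).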
            // Update global max and rescale if needed
            float newMax = Math.max(maxScore, currentTileMax);
            if (newMax != maxScore && maxScore != Float.NEGATIVE_INFINITY) {
                float scale = TornadoMath.exp(maxScore - newMax);
                sumExp *= scale;
                for (int d = 0; d < headSize; d++) {
                    output[d] *= scale;
                }
            }
            maxScore = newMax;

            // Fold every key-value pair in the tile into the running softmax. Every thread
            // processes the full tile, so all threads end up with identical accumulator state.
            for (int t_idx_in_s_tile = 0; t_idx_in_s_tile <= tileEnd - tileC; t_idx_in_s_tile++) {
                float expScore = TornadoMath.exp(s_tile[t_idx_in_s_tile] - maxScore);
                sumExp += expScore;

                // Accumulate weighted values
                for (int d = 0; d < headSize; d++) {
                    output[d] += expScore * v_tile[t_idx_in_s_tile * headSize + d];
                }
            }
            // Keep the tile buffers intact until every thread has finished reading them
            context.localBarrier();
        }

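        // Each thread holds the same complete accumulator, so the final write is simply
        // split across the workgroup: every thread normalizes and stores a strided subset
        // of the head dimensions. The sumExp > 0 check guards against division by zero.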
        // Normalize and cooperatively write final results
        float normFactor = (sumExp > 0.0f) ? (1.0f / sumExp) : 0.0f;
        for (int d = localTid; d < headSize; d += localSize) {
            xb.set(h * headSize + d, output[d] * normFactor);
        }
    }
}