|
4 | 4 | import org.beehive.gpullama3.core.model.tensor.FloatTensor;
|
5 | 5 | import org.beehive.gpullama3.inference.state.Phi3State;
|
6 | 6 | import org.beehive.gpullama3.inference.state.State;
|
| 7 | +import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights; |
7 | 8 | import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
|
8 | 9 | import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights;
|
9 | 10 | import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
|
10 | 11 | import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
|
11 | 12 | import org.beehive.gpullama3.model.Configuration;
|
12 | 13 | import org.beehive.gpullama3.model.Model;
|
| 14 | +import org.beehive.gpullama3.model.qwen2.Qwen2Configuration; |
13 | 15 | import org.beehive.gpullama3.model.phi3.Phi3Configuration;
|
14 | 16 | import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
|
15 | 17 | import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
|
| 18 | + |
16 | 19 | import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
|
17 | 20 |
|
18 | 21 | import java.lang.foreign.MemorySegment;
|
@@ -176,6 +179,137 @@ public static FloatTensor forwardJava(Model model, State state, int token, int p
|
176 | 179 | return state.logits;
|
177 | 180 | }
|
178 | 181 |
|
/**
 * Pure-Java (CPU) forward pass for a Qwen2-family transformer model.
 *
 * <p>Runs one decoding step: embeds {@code token}, applies every transformer layer
 * (RMSNorm → biased QKV projection → rotary position encoding → grouped-query
 * attention over the KV cache → output projection → SwiGLU FFN, each with a
 * residual connection), then a final RMSNorm and the classifier matmul.
 *
 * <p>Qwen2 specifics relative to the generic Llama-style forward: the Q/K/V
 * projections carry additive biases ({@code q_bias}/{@code k_bias}/{@code v_bias}),
 * and RoPE is GPT-NeoX style (rotated pairs are {@code headSize/2} apart within a
 * head, not adjacent).
 *
 * @param model    model whose configuration/weights must be Qwen2 variants
 *                 ({@link Qwen2Configuration} / {@link Qwen2StandardWeights})
 * @param state    mutable inference state; activation buffers and the per-layer
 *                 KV caches are overwritten, logits are written last
 * @param token    token id whose embedding seeds this step
 * @param position 0-based position of the token in the sequence; attention spans
 *                 cache entries 0..position inclusive
 * @return {@code state.logits}, the unnormalized vocabulary scores for the next token
 */
public static FloatTensor forwardJavaQwen2(Model model, State state, int token, int position) {
    final Qwen2Configuration config = (Qwen2Configuration) model.configuration();
    final Qwen2StandardWeights weights = (Qwen2StandardWeights) model.weights();
    int dim = config.dim();
    int headSize = config.headSize();
    // Grouped-query attention: kvDim < dim when numberOfKeyValueHeads < numberOfHeads.
    int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
    int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
    float sqrtHeadSize = (float) Math.sqrt(headSize);

    // Copy the token embedding row into the activation buffer x.
    weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);

    // forward all the layers
    for (int l = 0; l < config.numberOfLayers(); l++) {
        // attention rmsnorm
        final int curLayer = l;
        rmsnorm(state.xb, state.x, weights.rms_att_weight[curLayer], 0, dim, config.rmsNormEps());

        // qkv matmuls for this position
        weights.wq[l].matmul(state.xb, state.q, dim, dim);
        weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
        weights.wv[l].matmul(state.xb, state.v, kvDim, dim);

        // qkv additions with qkv bias (Qwen2 uses biased Q/K/V projections)
        state.q.addInPlace(weights.q_bias[curLayer]);
        state.k.addInPlace(weights.k_bias[curLayer]);
        state.v.addInPlace(weights.v_bias[curLayer]);

        // RoPE relative positional encoding: complex-valued rotate q and k in each head.
        // GPT-NeoX style RoPE: real/imaginary components are stored with a headSize/2
        // offset per head, instead of consecutive pairs.
        for (int h = 0; h < config.numberOfHeads(); ++h) {
            // Only the first numberOfKeyValueHeads heads have a k vector to rotate;
            // state.k holds kvDim = numberOfKeyValueHeads * headSize floats.
            int rotn = h < config.numberOfKeyValueHeads() ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
            int poffset = h * headSize;
            for (int i0 = 0; i0 < headSize; i0 += 2) {
                int ic = i0 / 2;
                // Precomputed cos/sin tables, indexed by (position, frequency index).
                float fcr = weights.freq_cis_real.getFloat((position) * (headSize / 2) + ic);
                float fci = weights.freq_cis_imag.getFloat((position) * (headSize / 2) + ic);
                for (int vi = 0; vi < rotn; vi++) {
                    FloatTensor vec = (vi == 0) ? state.q : state.k; // the vector to rotate (query or key)
                    float v0 = vec.getFloat(poffset + ic);
                    float v1 = vec.getFloat(poffset + ic + headSize/2);
                    // 2-D rotation by angle theta(position, ic): (v0, v1) -> (v0*cos - v1*sin, v0*sin + v1*cos)
                    vec.setFloat(poffset + ic, v0 * fcr - v1 * fci);
                    vec.setFloat(poffset + ic + headSize/2, v0 * fci + v1 * fcr);
                }
            }
        }

        // save key,value at this time step (position) to our kv cache
        //int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
        state.k.copyTo(0, state.keyCache[curLayer], position * kvDim, kvDim);
        state.v.copyTo(0, state.valueCache[curLayer], position * kvDim, kvDim);

        // multihead attention. iterate over all heads
        Parallel.parallelFor(0, config.numberOfHeads(), h -> {
            // get the query vector for this head
            // float* q = s.q + h * headSize;
            int qOffset = h * headSize;

            // attention scores for this head
            // float* att = s.att + h * config.seq_len;
            int attOffset = h * config.contextLength();

            // iterate over all timesteps, including the current one
            for (int t = 0; t <= position; t++) {
                // get the key vector for this head and at this timestep;
                // (h / kvMul) maps a query head onto its shared kv head (GQA)
                // float* k = s.key_cache + loff + t * dim + h * headSize;
                int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
                // calculate the attention score as the dot product of q and k
                float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
                score /= sqrtHeadSize;
                // save the score to the attention buffer
                state.att.setFloat(attOffset + t, score);
            }

            // softmax the scores to get attention weights, from 0..position inclusively
            state.att.softmaxInPlace(attOffset, position + 1);

            // weighted sum of the values, store back into xb
            // float* xb = s.xb + h * headSize;
            int xbOffset = h * headSize;
            // memset(xb, 0, headSize * sizeof(float));
            state.xb.fillInPlace(xbOffset, headSize, 0f);

            for (int t = 0; t <= position; t++) {
                // get the value vector for this head and at this timestep
                // float* v = s.value_cache + loff + t * dim + h * headSize;
                int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
                // get the attention weight for this timestep
                float a = state.att.getFloat(attOffset + t);
                // accumulate the weighted value into xb
                state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
            }
        });

        // final matmul to get the output of the attention
        weights.wo[l].matmul(state.xb, state.xb2, dim, dim);

        // residual connection back into x
        state.x.addInPlace(state.xb2);

        // ffn rmsnorm
        rmsnorm(state.xb, state.x, weights.rms_ffn_weight[curLayer], 0, dim, config.rmsNormEps());

        // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
        // first calculate self.w1(x) and self.w3(x)
        weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
        weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);

        // SwiGLU non-linearity
        // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
        state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));

        // elementwise multiply with w3(x)
        state.hb.multiplyInPlace(state.hb2);

        // final matmul to get the output of the ffn
        weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());

        // residual connection
        state.x.addInPlace(state.xb);

    }

    // final rmsnorm
    rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());

    // classifier into logits
    weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);

    return state.logits;
}
| 312 | + |
179 | 313 | public static FloatTensor forwardJavaQwen3(Model model, State state, int token, int position) {
|
180 | 314 | // a few convenience variables
|
181 | 315 | final Qwen3Configuration config = (Qwen3Configuration) model.configuration();
|
|
0 commit comments