Commit c4562ad

Add forward method for Qwen2

1 parent f96ddc4
File tree

1 file changed: +143 -0 lines

src/main/java/org/beehive/gpullama3/inference/InferenceCore.java

Lines changed: 143 additions & 0 deletions
@@ -4,15 +4,18 @@
 import org.beehive.gpullama3.core.model.tensor.FloatTensor;
 import org.beehive.gpullama3.inference.state.Phi3State;
 import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
 import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
 import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights;
 import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
 import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
 import org.beehive.gpullama3.model.phi3.Phi3Configuration;
 import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
 import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
 import java.lang.foreign.MemorySegment;
@@ -176,6 +179,146 @@ public static FloatTensor forwardJava(Model model, State state, int token, int position)
         return state.logits;
     }
 
+    public static FloatTensor forwardJavaQwen2(Model model, State state, int token, int position) {
+        final Qwen2Configuration config = (Qwen2Configuration) model.configuration();
+        final Qwen2StandardWeights weights = (Qwen2StandardWeights) model.weights();
+        int dim = config.dim();
+        int headSize = config.headSize();
+        int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
+        int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
+        float sqrtHeadSize = (float) Math.sqrt(headSize);
+
+        weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);
+
+        // forward all the layers
+        for (int l = 0; l < config.numberOfLayers(); l++) {
+            // attention rmsnorm
+            final int curLayer = l;
+            rmsnorm(state.xb, state.x, weights.rms_att_weight[curLayer], 0, dim, config.rmsNormEps());
+
+            // qkv matmuls for this position
+            weights.wq[l].matmul(state.xb, state.q, dim, dim);
+            weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
+            weights.wv[l].matmul(state.xb, state.v, kvDim, dim);
+
+            if ((weights.q_bias != null && weights.q_bias[curLayer] != null)
+                    || (weights.k_bias != null && weights.k_bias[curLayer] != null)
+                    || (weights.v_bias != null && weights.v_bias[curLayer] != null)) {
+                if (weights.q_bias != null && weights.q_bias[curLayer] != null) {
+                    state.q.addInPlace(weights.q_bias[curLayer]);
+                }
+                if (weights.k_bias != null && weights.k_bias[curLayer] != null) {
+                    state.k.addInPlace(weights.k_bias[curLayer]);
+                }
+                if (weights.v_bias != null && weights.v_bias[curLayer] != null) {
+                    state.v.addInPlace(weights.v_bias[curLayer]);
+                }
+            }
+
+            // RoPE relative positional encoding: complex-valued rotate q and k in each head
+            // GPT-NeoX style RoPE, real/imaginary components are stored with a headSize/2 offset per head, instead of consecutive.
+            for (int h = 0; h < config.numberOfHeads(); ++h) {
+                int rotn = h < config.numberOfKeyValueHeads() ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
+                int poffset = h * headSize;
+                for (int i0 = 0; i0 < headSize; i0 += 2) {
+                    int ic = i0 / 2;
+                    float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + ic);
+                    float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + ic);
+                    for (int vi = 0; vi < rotn; vi++) {
+                        FloatTensor vec = (vi == 0) ? state.q : state.k; // the vector to rotate (query or key)
+                        float v0 = vec.getFloat(poffset + ic);
+                        float v1 = vec.getFloat(poffset + ic + headSize / 2);
+                        vec.setFloat(poffset + ic, v0 * fcr - v1 * fci);
+                        vec.setFloat(poffset + ic + headSize / 2, v0 * fci + v1 * fcr);
+                    }
+                }
+            }
+
+            // save key,value at this time step (position) to our kv cache
+            //int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
+            state.k.copyTo(0, state.keyCache[curLayer], position * kvDim, kvDim);
+            state.v.copyTo(0, state.valueCache[curLayer], position * kvDim, kvDim);
+
+            // multihead attention. iterate over all heads
+            Parallel.parallelFor(0, config.numberOfHeads(), h -> {
+                // get the query vector for this head
+                // float* q = s.q + h * headSize;
+                int qOffset = h * headSize;
+
+                // attention scores for this head
+                // float* att = s.att + h * config.seq_len;
+                int attOffset = h * config.contextLength();
+
+                // iterate over all timesteps, including the current one
+                for (int t = 0; t <= position; t++) {
+                    // get the key vector for this head and at this timestep
+                    // float* k = s.key_cache + loff + t * dim + h * headSize;
+                    int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
+                    // calculate the attention score as the dot product of q and k
+                    float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
+                    score /= sqrtHeadSize;
+                    // save the score to the attention buffer
+                    state.att.setFloat(attOffset + t, score);
+                }
+
+                // softmax the scores to get attention weights, from 0..position inclusively
+                state.att.softmaxInPlace(attOffset, position + 1);
+
+                // weighted sum of the values, store back into xb
+                // float* xb = s.xb + h * headSize;
+                int xbOffset = h * headSize;
+                // memset(xb, 0, headSize * sizeof(float));
+                state.xb.fillInPlace(xbOffset, headSize, 0f);
+
+                for (int t = 0; t <= position; t++) {
+                    // get the value vector for this head and at this timestep
+                    // float* v = s.value_cache + loff + t * dim + h * headSize;
+                    int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
+                    // get the attention weight for this timestep
+                    float a = state.att.getFloat(attOffset + t);
+                    // accumulate the weighted value into xb
+                    state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
+                }
+            });
+
+            // final matmul to get the output of the attention
+            weights.wo[l].matmul(state.xb, state.xb2, dim, dim);
+
+            // residual connection back into x
+            state.x.addInPlace(state.xb2);
+
+            // ffn rmsnorm
+            rmsnorm(state.xb, state.x, weights.rms_ffn_weight[curLayer], 0, dim, config.rmsNormEps());
+
+            // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
+            // first calculate self.w1(x) and self.w3(x)
+            weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
+            weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);
+
+            // SwiGLU non-linearity
+            // silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
+            state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
+
+            // elementwise multiply with w3(x)
+            state.hb.multiplyInPlace(state.hb2);
+
+            // final matmul to get the output of the ffn
+            weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());
+
+            // residual connection
+            state.x.addInPlace(state.xb);
+
+        }
+
+        // final rmsnorm
+        rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());
+
+        // classifier into logits
+        weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);
+
+        return state.logits;
+    }
+
     public static FloatTensor forwardJavaQwen3(Model model, State state, int token, int position) {
         // a few convenience variables
         final Qwen3Configuration config = (Qwen3Configuration) model.configuration();
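Note on reaching the new path: this commit only adds forwardJavaQwen2 next to the existing forwardJava and forwardJavaQwen3, so a caller presumably selects it by the concrete Configuration type, mirroring the casts each method performs internally. A minimal dispatch sketch under that assumption; the wrapper class and the method name `forward` are hypothetical, only the three forwardJava* methods are confirmed by this file:

// Hypothetical dispatch sketch -- not part of this commit.
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.inference.InferenceCore;
import org.beehive.gpullama3.inference.state.State;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;

public final class ForwardDispatch {
    public static FloatTensor forward(Model model, State state, int token, int position) {
        // Route on the concrete Configuration subtype, the same type each
        // forwardJava* method casts to internally.
        if (model.configuration() instanceof Qwen2Configuration) {
            return InferenceCore.forwardJavaQwen2(model, state, token, position);
        }
        if (model.configuration() instanceof Qwen3Configuration) {
            return InferenceCore.forwardJavaQwen3(model, state, token, position);
        }
        return InferenceCore.forwardJava(model, state, token, position);
    }
}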

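The GPT-NeoX-style RoPE loop in the diff rotates element i against element i + headSize/2 rather than against its immediate neighbour. A standalone sketch of that rotation on a plain float[], recomputing cos/sin from a base frequency; the base value is an assumption here, since the committed code reads precomputed freq_cis_real/freq_cis_imag tables instead:

// Self-contained illustration of the half-offset RoPE layout; all names local.
public final class RopeNeoxDemo {
    static void rotateHead(float[] head, int position, double base) {
        int half = head.length / 2;
        for (int ic = 0; ic < half; ic++) {
            // theta_i = position * base^(-2i / headSize), standard RoPE frequency
            double theta = position * Math.pow(base, -2.0 * ic / head.length);
            float fcr = (float) Math.cos(theta);
            float fci = (float) Math.sin(theta);
            float v0 = head[ic];         // "real" component
            float v1 = head[ic + half];  // "imaginary" component, headSize/2 away
            head[ic] = v0 * fcr - v1 * fci;
            head[ic + half] = v0 * fci + v1 * fcr;
        }
    }

    public static void main(String[] args) {
        float[] head = {1f, 0f, 0f, 0f}; // headSize = 4, so pairs are (0,2) and (1,3)
        rotateHead(head, 1, 10_000.0);
        // prints roughly [0.5403, 0.0, 0.8415, 0.0]: (1,0) rotated by 1 radian
        System.out.println(java.util.Arrays.toString(head));
    }
}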

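The `(h / kvMul)` term in the key/value cache offsets implements grouped-query attention: each block of kvMul consecutive query heads shares one key/value head, which is why the cache rows have size kvDim rather than dim. A tiny worked example with illustrative head counts (not taken from any particular Qwen2 checkpoint):

// Worked example of the query-head -> kv-head mapping used in the offsets above.
public final class GqaMappingDemo {
    public static void main(String[] args) {
        int numberOfHeads = 8;
        int numberOfKeyValueHeads = 2;
        int kvMul = numberOfHeads / numberOfKeyValueHeads; // 4 query heads per kv head
        for (int h = 0; h < numberOfHeads; h++) {
            System.out.printf("query head %d -> kv head %d%n", h, h / kvMul);
        }
        // prints: heads 0-3 -> kv head 0, heads 4-7 -> kv head 1
    }
}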