Commit dd63309

Merge remote-tracking branch 'origin/main' into feat/api

2 parents: 47f281b + efbe261

18 files changed: +1127 -91 lines

README.md

Lines changed: 16 additions & 0 deletions
@@ -244,6 +244,13 @@ Download `FP16` quantized `Qwen3` .gguf files from:
 - https://huggingface.co/ggml-org/Qwen3-4B-GGUF
 - https://huggingface.co/ggml-org/Qwen3-8B-GGUF

+Download `FP16` quantized `Qwen2.5` .gguf files from:
+- https://huggingface.co/bartowski/Qwen2.5-0.5B-Instruct-GGUF
+- https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct-GGUF
+
+Download `FP16` quantized `DeepSeek-R1-Distill-Qwen` .gguf files from:
+- https://huggingface.co/hdnh2006/DeepSeek-R1-Distill-Qwen-1.5B-GGUF
+
 Please be gentle with [huggingface.co](https://huggingface.co) servers:

 **Note** FP16 models are first-class citizens for the current version.
@@ -274,6 +281,15 @@ wget https://huggingface.co/ggml-org/Qwen3-0.6B-GGUF/resolve/main/Qwen3-8B-f16.gguf

 # Phi-3-mini-4k - FP16
 wget https://huggingface.co/microsoft/Phi-3-mini-4k-instruct-gguf/resolve/main/Phi-3-mini-4k-instruct-fp16.gguf
+
+# Qwen2.5 (0.5B)
+wget https://huggingface.co/bartowski/Qwen2.5-0.5B-Instruct-GGUF/resolve/main/Qwen2.5-0.5B-Instruct-f16.gguf
+
+# Qwen2.5 (1.5B)
+wget https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct-GGUF/resolve/main/qwen2.5-1.5b-instruct-fp16.gguf
+
+# DeepSeek-R1-Distill-Qwen (1.5B)
+wget https://huggingface.co/hdnh2006/DeepSeek-R1-Distill-Qwen-1.5B-GGUF/resolve/main/DeepSeek-R1-Distill-Qwen-1.5B-F16.gguf
 ```

 **[Experimental]** You can also download the Q8 and Q4 models used in the original implementation of Llama3.java, but for now they are dequantized to FP16 for TornadoVM support:
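
For context on that dequantization step: a GGML Q8_0 tensor stores each block of 32 weights as 32 signed bytes plus one FP16 scale, and every value is recovered as scale * quant. A minimal sketch of the block math (hypothetical standalone helper, not this repo's loader code):

// Dequantize one GGML Q8_0 block: 32 int8 quants, one FP16 scale.
// `scale` is assumed to be the block's FP16 delta already widened to float.
static float[] dequantizeQ8Block(byte[] quants, float scale) {
    float[] out = new float[32];
    for (int i = 0; i < 32; i++) {
        out[i] = scale * quants[i]; // q8_0: value = d * q
    }
    return out;
}

Q4_0 is analogous, with two 4-bit quants packed per byte and an offset of 8 subtracted before scaling.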

src/main/java/org/beehive/gpullama3/inference/InferenceCore.java

Lines changed: 134 additions & 0 deletions
@@ -4,15 +4,18 @@
 import org.beehive.gpullama3.core.model.tensor.FloatTensor;
 import org.beehive.gpullama3.inference.state.Phi3State;
 import org.beehive.gpullama3.inference.state.State;
+import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
 import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
 import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights;
 import org.beehive.gpullama3.inference.weights.standard.StandardWeights;
 import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
 import org.beehive.gpullama3.model.phi3.Phi3Configuration;
 import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
 import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

 import java.lang.foreign.MemorySegment;
@@ -176,6 +179,137 @@ public static FloatTensor forwardJava(Model model, State state, int token, int position)
        return state.logits;
    }

    public static FloatTensor forwardJavaQwen2(Model model, State state, int token, int position) {
        final Qwen2Configuration config = (Qwen2Configuration) model.configuration();
        final Qwen2StandardWeights weights = (Qwen2StandardWeights) model.weights();
        int dim = config.dim();
        int headSize = config.headSize();
        int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
        int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
        float sqrtHeadSize = (float) Math.sqrt(headSize);
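        // Worked example, assuming Qwen2.5-0.5B's published config (dim = 896,
        // 14 query heads, 2 KV heads): headSize = 896 / 14 = 64,
        // kvDim = 896 * 2 / 14 = 128, kvMul = 14 / 2 = 7; each KV head is
        // shared by 7 query heads (grouped-query attention).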

        weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);

        // forward all the layers
        for (int l = 0; l < config.numberOfLayers(); l++) {
            // attention rmsnorm
            final int curLayer = l;
            rmsnorm(state.xb, state.x, weights.rms_att_weight[curLayer], 0, dim, config.rmsNormEps());

            // qkv matmuls for this position
            weights.wq[l].matmul(state.xb, state.q, dim, dim);
            weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
            weights.wv[l].matmul(state.xb, state.v, kvDim, dim);

            // qkv additions with qkv bias
            state.q.addInPlace(weights.q_bias[curLayer]);
            state.k.addInPlace(weights.k_bias[curLayer]);
            state.v.addInPlace(weights.v_bias[curLayer]);

            // RoPE relative positional encoding: complex-valued rotate q and k in each head.
            // GPT-NeoX style RoPE: real/imaginary components are stored with a headSize/2 offset per head, instead of consecutively.
            for (int h = 0; h < config.numberOfHeads(); ++h) {
                int rotn = h < config.numberOfKeyValueHeads() ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
                int poffset = h * headSize;
                for (int i0 = 0; i0 < headSize; i0 += 2) {
                    int ic = i0 / 2;
                    float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + ic);
                    float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + ic);
                    for (int vi = 0; vi < rotn; vi++) {
                        FloatTensor vec = (vi == 0) ? state.q : state.k; // the vector to rotate (query or key)
                        float v0 = vec.getFloat(poffset + ic);
                        float v1 = vec.getFloat(poffset + ic + headSize / 2);
                        vec.setFloat(poffset + ic, v0 * fcr - v1 * fci);
                        vec.setFloat(poffset + ic + headSize / 2, v0 * fci + v1 * fcr);
                    }
                }
            }
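            // Layout recap: element ic of each head rotates with element
            // ic + headSize / 2 (the NeoX half-offset pairing), not with its
            // neighbor as in interleaved Llama-style RoPE; freq_cis_real/imag
            // cache cos/sin per (position, pair), shared across all heads.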

            // save key,value at this time step (position) to our kv cache
            //int loff = l * config.seq_len * kvDim; // kv cache layer offset for convenience
            state.k.copyTo(0, state.keyCache[curLayer], position * kvDim, kvDim);
            state.v.copyTo(0, state.valueCache[curLayer], position * kvDim, kvDim);

            // multihead attention: iterate over all heads
            Parallel.parallelFor(0, config.numberOfHeads(), h -> {
                // get the query vector for this head
                // float* q = s.q + h * headSize;
                int qOffset = h * headSize;

                // attention scores for this head
                // float* att = s.att + h * config.seq_len;
                int attOffset = h * config.contextLength();

                // iterate over all timesteps, including the current one
                for (int t = 0; t <= position; t++) {
                    // get the key vector for this head and at this timestep
                    // float* k = s.key_cache + loff + t * dim + h * headSize;
                    int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
                    // calculate the attention score as the dot product of q and k
                    float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
                    score /= sqrtHeadSize;
                    // save the score to the attention buffer
                    state.att.setFloat(attOffset + t, score);
                }

                // softmax the scores to get attention weights, over 0..position inclusive
                state.att.softmaxInPlace(attOffset, position + 1);

                // weighted sum of the values, store back into xb
                // float* xb = s.xb + h * headSize;
                int xbOffset = h * headSize;
                // memset(xb, 0, headSize * sizeof(float));
                state.xb.fillInPlace(xbOffset, headSize, 0f);

                for (int t = 0; t <= position; t++) {
                    // get the value vector for this head and at this timestep
                    // float* v = s.value_cache + loff + t * dim + h * headSize;
                    int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
                    // get the attention weight for this timestep
                    float a = state.att.getFloat(attOffset + t);
                    // accumulate the weighted value into xb
                    state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
                }
            });
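            // Note: causal masking is implicit; each head only scans cache
            // entries t = 0..position, and softmaxInPlace normalizes exactly
            // those position + 1 scores, so no explicit mask is materialized.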

            // final matmul to get the output of the attention
            weights.wo[l].matmul(state.xb, state.xb2, dim, dim);

            // residual connection back into x
            state.x.addInPlace(state.xb2);

            // ffn rmsnorm
            rmsnorm(state.xb, state.x, weights.rms_ffn_weight[curLayer], 0, dim, config.rmsNormEps());

            // Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
            // first calculate self.w1(x) and self.w3(x)
            weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
            weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);

            // SwiGLU non-linearity
            // silu(x) = x * σ(x), where σ(x) is the logistic sigmoid
            state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
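            // e.g. silu(1.0f) = 1.0f / (1 + e^(-1)) ≈ 0.731f; hb now holds
            // silu(w1·x), which is gated elementwise by w3·x in the next step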

            // elementwise multiply with w3(x)
            state.hb.multiplyInPlace(state.hb2);

            // final matmul to get the output of the ffn
            weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());

            // residual connection
            state.x.addInPlace(state.xb);
        }

        // final rmsnorm
        rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());

        // classifier into logits
        weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);

        return state.logits;
    }

    public static FloatTensor forwardJavaQwen3(Model model, State state, int token, int position) {
        // a few convenience variables
        final Qwen3Configuration config = (Qwen3Configuration) model.configuration();
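
One detail of the forward pass above that is easy to misread is the RoPE pairing. A self-contained reference of the GPT-NeoX-style rotation (a standalone sketch, not the project's tensor API; the base `theta` is an assumed parameter, e.g. Qwen2.5 configs ship rope_theta = 1000000):

// Standalone sketch: GPT-NeoX-style RoPE applied in place to one head.
// Element ic pairs with element ic + headSize/2, matching the loop above.
static void ropeNeoXInPlace(float[] head, int position, double theta) {
    int half = head.length / 2;
    for (int ic = 0; ic < half; ic++) {
        double freq = Math.pow(theta, -2.0 * ic / head.length);
        float fcr = (float) Math.cos(position * freq);
        float fci = (float) Math.sin(position * freq);
        float v0 = head[ic];          // "real" component
        float v1 = head[ic + half];   // "imaginary" component, half-offset
        head[ic] = v0 * fcr - v1 * fci;
        head[ic + half] = v0 * fci + v1 * fcr;
    }
}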
src/main/java/org/beehive/gpullama3/inference/state/Qwen2State.java

Lines changed: 69 additions & 0 deletions

@@ -0,0 +1,69 @@
package org.beehive.gpullama3.inference.state;

import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.IntArray;

import java.util.stream.Stream;

public class Qwen2State extends State {

    public Qwen2State(Configuration config, int batchsize) {
        super(config, batchsize);
        this.localSize = 32;
    }

    @Override
    protected StateFields createStateFields(Configuration configuration) {
        StateFields fields = new StateFields();

        Qwen2Configuration config = (Qwen2Configuration) configuration;

        int nEmbdGqa = config.kvDim();

        // with Qwen2-specific sizes
        fields.x = ArrayFloatTensor.allocate(config.dim());
        fields.xb = ArrayFloatTensor.allocate(config.dim());
        fields.xb2 = ArrayFloatTensor.allocate(config.dim());
        fields.hb = ArrayFloatTensor.allocate(config.hiddenDim());
        fields.hb2 = ArrayFloatTensor.allocate(config.hiddenDim());
        fields.q = ArrayFloatTensor.allocate(config.dim());
        fields.k = ArrayFloatTensor.allocate(config.kvDim());
        fields.v = ArrayFloatTensor.allocate(config.kvDim());
        fields.att = ArrayFloatTensor.allocate(config.numberOfHeads(), config.contextLength());
        fields.logits = ArrayFloatTensor.allocate(config.vocabularySize());

        // Key-value cache with Qwen2 dimensions
        fields.keyCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);
        fields.valueCache = Stream.generate(() -> ArrayFloatTensor.allocate(config.contextLength(), nEmbdGqa)).limit(config.numberOfLayers()).toArray(FloatTensor[]::new);

        // TornadoVM wrappers with Qwen2 dimensions
        fields.wrapX = new FloatArray(config.dim());
        fields.wrapXb = new FloatArray(config.dim());
        fields.wrapXb2 = new FloatArray(config.dim());
        fields.wrapHb = new FloatArray(config.hiddenDim());
        fields.wrapHb2 = new FloatArray(config.hiddenDim());

        fields.wrapLogits = new FloatArray(config.vocabularySize());
        fields.wrapQ = new FloatArray(config.dim());
        fields.wrapK = new FloatArray(config.kvDim());
        fields.wrapV = new FloatArray(config.kvDim());

        fields.wrapKeyCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
        fields.wrapValueCache = new FloatArray(config.contextLength() * nEmbdGqa * config.numberOfLayers());
        fields.wrapValueCache.init(0.f);
        fields.wrapKeyCache.init(0.f);
        fields.wrapAtt = new FloatArray(config.numberOfHeads() * config.contextLength());
        fields.positionHolder = new IntArray(1);

        // Temporary arrays
        fields.temp = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
        fields.tempFFN = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));
        fields.tempLogits = new FloatArray(1 + ((config.dim() + localSize - 1) / localSize));

        return fields;
    }
}
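
A sizing note on the three temp arrays: 1 + ceil(dim / localSize) reads as one partial result per work-group of localSize = 32 threads plus a slot for the combined value, which would serve TornadoVM's two-phase reductions (an assumption; the file itself does not say what consumes them). The arithmetic for Qwen2.5-0.5B:

// Assuming dim = 896 (Qwen2.5-0.5B) and the localSize = 32 set above:
int dim = 896, localSize = 32;
int workGroups = (dim + localSize - 1) / localSize; // ceil(896 / 32) = 28
int tempSlots = 1 + workGroups;                     // 29 floats per temp array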
src/main/java/org/beehive/gpullama3/inference/weights/standard/Qwen2StandardWeights.java

Lines changed: 57 additions & 0 deletions

@@ -0,0 +1,57 @@
package org.beehive.gpullama3.inference.weights.standard;

import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.inference.weights.Weights;

public class Qwen2StandardWeights extends StandardWeights {
    // Qwen2-specific weights
    public final FloatTensor[] q_bias, k_bias, v_bias;

    public Qwen2StandardWeights(
            FloatTensor token_embedding_table,
            FloatTensor[] rms_att_weight,
            FloatTensor[] wq,
            FloatTensor[] wk,
            FloatTensor[] wv,
            FloatTensor[] q_bias,
            FloatTensor[] k_bias,
            FloatTensor[] v_bias,
            FloatTensor[] wo,
            FloatTensor[] rms_ffn_weight,
            FloatTensor[] w1,
            FloatTensor[] w2,
            FloatTensor[] w3,
            FloatTensor rms_final_weight,
            ArrayFloatTensor freq_cis_real,
            ArrayFloatTensor freq_cis_imag,
            FloatTensor wcls,
            GGMLType weightType) {
        // call to StandardWeights constructor
        super(token_embedding_table,
                rms_att_weight,
                wq,
                wk,
                wv,
                wo,
                rms_ffn_weight,
                w1,
                w2,
                w3,
                rms_final_weight,
                freq_cis_real,
                freq_cis_imag,
                wcls,
                weightType);
        // init Qwen2-specific fields
        this.q_bias = q_bias;
        this.k_bias = k_bias;
        this.v_bias = v_bias;
    }

    @Override
    public GGMLType getWeightType() {
        return weightType;
    }
}
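
For context: the Qwen2 attention design adds a learned bias to the query/key/value projections only, while the output and FFN projections stay bias-free, so q_bias/k_bias/v_bias are the only additions needed over the base StandardWeights. In forwardJavaQwen2 above, each projection is correspondingly a matmul followed by a bias add:

// Excerpted pattern from forwardJavaQwen2 (per layer): q = Wq·xb + q_bias
weights.wq[l].matmul(state.xb, state.q, dim, dim);
state.q.addInPlace(weights.q_bias[curLayer]);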
src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java

Lines changed: 42 additions & 0 deletions

@@ -0,0 +1,42 @@
package org.beehive.gpullama3.inference.weights.tornado;

import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

public class Qwen2TornadoWeights extends TornadoWeights {

    // Qwen2-specific tornado weights
    public FloatArray[] q_biasLayered;
    public FloatArray[] k_biasLayered;
    public FloatArray[] v_biasLayered;

    public Qwen2TornadoWeights(FloatArray tokenEmbeddingTable, FloatArray[] rms_att_weightLayered, HalfFloatArray[] wqLayered, HalfFloatArray[] wkLayered, HalfFloatArray[] wvLayered,
            FloatArray[] wqBiasLayered,
            FloatArray[] wkBiasLayered,
            FloatArray[] wvBiasLayered,
            HalfFloatArray[] woLayered, FloatArray[] rms_ffn_weightLayered, HalfFloatArray[] w1Layered,
            HalfFloatArray[] w2Layered, HalfFloatArray[] w3Layered, FloatArray rms_final_weight_as_floatArray, FloatArray freq_cis_realFlat, FloatArray freq_cis_imagFlat, HalfFloatArray wclsByteArray,
            GGMLType weightType) {
        // call to TornadoWeights constructor
        super(tokenEmbeddingTable,
                rms_att_weightLayered,
                wqLayered,
                wkLayered,
                wvLayered,
                woLayered,
                rms_ffn_weightLayered,
                w1Layered,
                w2Layered,
                w3Layered,
                rms_final_weight_as_floatArray,
                freq_cis_realFlat,
                freq_cis_imagFlat,
                wclsByteArray,
                weightType);
        // init Qwen2-specific fields
        this.q_biasLayered = wqBiasLayered;
        this.k_biasLayered = wkBiasLayered;
        this.v_biasLayered = wvBiasLayered;
    }
}
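
These layered arrays hold the same per-layer bias values as Qwen2StandardWeights, but in TornadoVM's off-heap FloatArray form so the runtime can transfer them to the device. A hedged sketch of what such a conversion could look like (toFloatArray is a hypothetical helper, not the project's actual loader):

import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

final class BiasArrays {
    // Copy one layer's bias values into a device-transferable FloatArray.
    static FloatArray toFloatArray(float[] bias) {
        FloatArray out = new FloatArray(bias.length);
        for (int i = 0; i < bias.length; i++) {
            out.set(i, bias[i]);
        }
        return out;
    }
}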
