
Commit be93182

Merge pull request #38 from mikepapadim/phi-3
[models][phi-3] Support for Microsoft's Phi-3 models
2 parents d053e9c + 3cc7ee6 commit be93182

21 files changed: +1821 -319 lines changed

src/main/java/com/example/inference/InferenceCore.java

Lines changed: 115 additions & 2 deletions
@@ -2,12 +2,15 @@
 
 import com.example.auxiliary.Parallel;
 import com.example.core.model.tensor.FloatTensor;
+import com.example.inference.state.Phi3State;
 import com.example.inference.state.State;
+import com.example.inference.weights.standard.Phi3StandardWeights;
 import com.example.inference.weights.standard.Qwen3StandardWeights;
 import com.example.inference.weights.standard.StandardWeights;
 import com.example.inference.weights.tornado.TornadoWeights;
 import com.example.model.Configuration;
 import com.example.model.Model;
+import com.example.model.phi3.Phi3Configuration;
 import com.example.model.qwen3.Qwen3Configuration;
 import com.example.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
@@ -18,8 +21,7 @@
  * Low-level operations for model inference.
  *
  * <p>
- * This class provides core computational operations such as RMS normalization and
- * forward passes through model layers. It supports both CPU and GPU implementations.
+ * This class provides core computational operations such as RMS normalization and forward passes through model layers. It supports both CPU and GPU implementations.
  * </p>
  *
  * <p>
@@ -308,6 +310,117 @@ public static FloatTensor forwardJavaQwen3(Model model, State state, int token,
         return state.logits;
     }
 
+    public static FloatTensor forwardJavaPhi3(Model model, Phi3State state, int token, int position) {
+        Phi3Configuration config = (Phi3Configuration) model.configuration();
+        Phi3StandardWeights weights = (Phi3StandardWeights) model.weights();
+        int dim = config.dim();
+        int headSize = config.headSize();
+        int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
+        int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
+        float sqrtHeadSize = (float) Math.sqrt(headSize);
+
+        // copy the token embedding into x
+        weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);
+
+        // Phi3: op_size = num_heads * head_dim + 2 * (num_key_value_heads * head_dim)
+        final int opSize = dim + 2 * (config.numberOfKeyValueHeads() * headSize);
+
+        // forward all the layers
+        for (int l = 0; l < config.numberOfLayers(); l++) {
+            rmsnorm(state.xb, state.x, weights.rms_att_weight[l], 0, dim, config.rmsNormEps());
+
+            weights.wqkv[l].matmul(state.xb, state.qkv, opSize, dim);
+            state.qkv.copyTo(0, state.q, 0, dim);
+            // key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim]
+            state.qkv.copyTo(dim, state.k, 0, config.numberOfKeyValueHeads() * headSize);
+            // value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :]
+            state.qkv.copyTo(dim + config.numberOfKeyValueHeads() * headSize, state.v, 0, config.numberOfKeyValueHeads() * headSize);
+
+            int dimHalf = headSize / 2;
+            for (int i = 0; i < dim; i += 2) {
+                int head_dim = i % headSize;
+                int base = i - head_dim;
+                int ic = base + head_dim / 2;
+                float fcr = weights.freq_cis_real.getFloat(position * (headSize / 2) + (head_dim / 2));
+                float fci = weights.freq_cis_imag.getFloat(position * (headSize / 2) + (head_dim / 2));
+                int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
+                for (int v = 0; v < rotn; v++) {
+                    FloatTensor vec = v == 0 ? state.q : state.k; // the vector to rotate (query or key)
+                    float v0 = vec.getFloat(ic);
+                    float v1 = vec.getFloat(ic + dimHalf);
+                    vec.setFloat(ic, v0 * fcr - v1 * fci);
+                    vec.setFloat(ic + dimHalf, v0 * fci + v1 * fcr);
+                }
+            }
+
+            // save key,value at this time step (position) to our kv cache
+            state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim);
+            state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim);
+
+            int curLayer = l;
+
+            Parallel.parallelFor(0, config.numberOfHeads(), h -> {
+                int qOffset = h * headSize;
+
+                int attOffset = h * config.contextLength();
+
+                for (int t = 0; t <= position; t++) {
+                    int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
+                    float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
+                    score /= sqrtHeadSize;
+                    state.att.setFloat(attOffset + t, score);
+                }
+
+                state.att.softmaxInPlace(attOffset, position + 1);
+
+                int xbOffset = h * headSize;
+                state.xb.fillInPlace(xbOffset, headSize, 0f);
+
+                for (int t = 0; t <= position; t++) {
+                    int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
+                    float a = state.att.getFloat(attOffset + t);
+                    state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
+                }
+            });
+
+            // final matmul to get the output of the attention
+            weights.wo[l].matmul(state.xb, state.xb2, dim, dim);
+
+            // residual connection back into x
+            state.x.addInPlace(state.xb2);
+
+            rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], 0, dim, config.rmsNormEps());
+
+            weights.wGateUp[l].matmul(state.xb, state.hb, 2 * config.hiddenDim(), dim);
+            copyChunk(state.hb, state.hbG, 2 * config.hiddenDim(), config.hiddenDim(), 2, 0);
+            copyChunk(state.hb, state.hbU, 2 * config.hiddenDim(), config.hiddenDim(), 2, 1);
+
+            state.hbG.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
+
+            state.hbU.multiplyInPlace(state.hbG);
+
+            weights.wDown[l].matmul(state.hbU, state.xb, dim, config.hiddenDim());
+
+            state.x.addInPlace(state.xb);
+        }
+
+        // final rmsnorm
+        rmsnorm(state.x, state.x, weights.rms_final_weight, 0, dim, config.rmsNormEps());
+
+        // classifier into logits
+        weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);
+
+        return state.logits;
+    }
+
+    static void copyChunk(FloatTensor in, FloatTensor out, int dim1In, int dim1Out, int nChunks, int chunkNo) {
+        assert (dim1In == dim1Out * nChunks);
+        final int startOffsetInDim1 = chunkNo * dim1Out;
+        Parallel.parallelFor(0, dim1Out, i -> {
+            out.setFloat(i, in.getFloat(startOffsetInDim1 + i));
+        });
+    }
+
     /**
      * Performs the initial embedding lookup and triggers the TornadoVM accelerated forward pass for an LLM token.
      *
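Note on the FFN step above: Phi-3 uses a fused gate/up projection (wGateUp), which copyChunk splits into two halves before the SiLU activation and element-wise product. Below is a minimal, self-contained sketch of that split-and-activate step on plain float arrays; the class name, array names, and sizes are illustrative assumptions, not code from this commit.

import java.util.Arrays;

public final class SwigluSketch {

    // Mirrors copyChunk(in, out, 2 * hiddenDim, hiddenDim, 2, chunkNo): copies the
    // chunkNo-th hiddenDim-sized slice of the fused buffer into out.
    static void copyChunk(float[] in, float[] out, int dim1Out, int chunkNo) {
        System.arraycopy(in, chunkNo * dim1Out, out, 0, dim1Out);
    }

    public static void main(String[] args) {
        int hiddenDim = 4;                         // illustrative size
        float[] hb = new float[2 * hiddenDim];     // stands in for the fused gate|up output of wGateUp
        for (int i = 0; i < hb.length; i++) {
            hb[i] = i - 3.5f;                      // dummy activations
        }

        float[] gate = new float[hiddenDim];       // stands in for state.hbG
        float[] up = new float[hiddenDim];         // stands in for state.hbU
        copyChunk(hb, gate, hiddenDim, 0);         // first half  -> gate
        copyChunk(hb, up, hiddenDim, 1);           // second half -> up

        // SwiGLU: silu(gate) * up, matching hbG.mapInPlace(...) followed by hbU.multiplyInPlace(hbG)
        for (int i = 0; i < hiddenDim; i++) {
            float silu = gate[i] / (float) (1.0 + Math.exp(-gate[i]));
            up[i] = silu * up[i];
        }
        System.out.println(Arrays.toString(up));   // 'up' now holds the FFN activation fed to wDown
    }
}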

src/main/java/com/example/inference/InferenceEngine.java

Lines changed: 151 additions & 15 deletions
@@ -9,6 +9,7 @@
 import com.example.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
+import java.io.ByteArrayOutputStream;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
@@ -18,8 +19,7 @@
  * Main entry point for LLM token generation.
  *
  * <p>
- * Orchestrates the complete inference process: ingests prompt tokens, then generates
- * new tokens until a stop condition is met. Supports both CPU and GPU execution.
+ * Orchestrates the complete inference process: ingests prompt tokens, then generates new tokens until a stop condition is met. Supports both CPU and GPU execution.
  * </p>
  *
  * <p>
@@ -42,19 +42,26 @@ private InferenceEngine() {
      * LLM generation entry point, ingest prompt tokens and generates new tokens.
      *
      * <p>
-     * All prompt tokens are ingested first, then inference starts, until a stop token is found.
-     * The returned tokens only include generated/inferred tokens.
+     * All prompt tokens are ingested first, then inference starts, until a stop token is found. The returned tokens only include generated/inferred tokens.
      *
-     * @param model model to run inference (including weights, configuration, tokenizer ...)
-     * @param state state of the model e.g. key/value caches ... this is mutated by this call
-     * @param startPosition start prompt ingestion + inference at this position in the context e.g. useful if state was kept across calls (chained generation). 0 implies run with no previous context.
-     * @param promptTokens prompt tokens to ingest, all the prompt tokens will be ingested, given there's enough capacity left in the context
-     * @param stopTokens set of tokens that abort generation during inference, stop tokens do not affect prompt ingestion
-     * @param maxTokens maximum number of tokens (can go up to {@link Configuration#contextLength context length}
-     *            if this value is negative or greater than {@link Configuration#contextLength context length}
-     * @param sampler {@link Sampler strategy} used to select tokens
-     * @param echo debugging flag, prints ALL, prompt and inferred tokens, to {@link System#err stderr}
-     * @param onTokenGenerated callback, if non-null, it's called every time a token is inferred e.g. it's not called when ingesting prompt tokens
+     * @param model
+     *            model to run inference (including weights, configuration, tokenizer ...)
+     * @param state
+     *            state of the model e.g. key/value caches ... this is mutated by this call
+     * @param startPosition
+     *            start prompt ingestion + inference at this position in the context e.g. useful if state was kept across calls (chained generation). 0 implies run with no previous context.
+     * @param promptTokens
+     *            prompt tokens to ingest, all the prompt tokens will be ingested, given there's enough capacity left in the context
+     * @param stopTokens
+     *            set of tokens that abort generation during inference, stop tokens do not affect prompt ingestion
+     * @param maxTokens
+     *            maximum number of tokens (can go up to {@link Configuration#contextLength context length} if this value is negative or greater than {@link Configuration#contextLength context length}
+     * @param sampler
+     *            {@link Sampler strategy} used to select tokens
+     * @param echo
+     *            debugging flag, prints ALL, prompt and inferred tokens, to {@link System#err stderr}
+     * @param onTokenGenerated
+     *            callback, if non-null, it's called every time a token is inferred e.g. it's not called when ingesting prompt tokens
      * @return list of generated/inferred tokens, including the stop token, if any e.g. does not include any token from the prompt
      */
     public static List<Integer> generateTokensLlama(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
@@ -214,6 +221,60 @@ public static List<Integer> generateTokensQwen3(Model model, State state, int st
         return generatedTokens;
     }
 
+    public static List<Integer> generateTokensPhi3(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
+            IntConsumer onTokenGenerated) {
+
+        long startNanos = System.nanoTime();
+        if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) {
+            maxTokens = model.configuration().contextLength();
+        }
+        List<Integer> generatedTokens = new ArrayList<>(maxTokens);
+        int token = state.latestToken; // BOS?
+        int nextToken;
+        int promptIndex = 0;
+        ByteArrayOutputStream baos = new ByteArrayOutputStream(5);
+        for (int position = startPosition; position < maxTokens; ++position) {
+
+            model.forward(state, token, position);
+            if (promptIndex < promptTokens.size()) {
+                // Force-pick token from prompt.
+                nextToken = promptTokens.get(promptIndex++);
+                if (echo) {
+                    System.out.println("NextToken: " + nextToken);
+                    String decoded = model.tokenizer().decode(List.of(nextToken));
+                    System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
+                }
+            } else {
+                nextToken = sampler.sampleToken(state.logits);
+                if (echo) {
+                    // log inferred token
+                    System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
+                }
+                generatedTokens.add(nextToken);
+                if (onTokenGenerated != null) {
+                    onTokenGenerated.accept(nextToken);
+                }
+                if (stopTokens.contains(nextToken)) {
+                    break;
+                }
+            }
+            state.latestToken = token = nextToken;
+            if (position == 2000) {
+                break;
+            }
+        }
+
+        // Calculate and print performance metrics
+        long endNanos = System.nanoTime();
+        double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0;
+        int totalTokens = promptIndex + generatedTokens.size();
+
+        LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds);
+
+        return generatedTokens;
+
+    }
+
     public static List<Integer> generateTokensGPULlama(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
             IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
         // === Setup and Initialization ===
@@ -395,4 +456,79 @@ public static List<Integer> generateTokensGPUQwen3(Model model, State state, int
 
         return generatedTokens;
     }
-}
+
+    public static List<Integer> generateTokensGPUPhi3(Model model, State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
+            IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
+        // Start timing the whole process
+        long startNanos = System.nanoTime();
+        long inferenceStartNanos = 0;
+
+        // Validate and adjust maxTokens if necessary
+        if (maxTokens < 0 || model.configuration().contextLength() < maxTokens) {
+            maxTokens = model.configuration().contextLength();
+        }
+
+        // Storage for generated tokens
+        List<Integer> generatedTokens = new ArrayList<>();
+
+        // Initialize token variables
+        int currentToken = state.latestToken;
+        int nextToken;
+        int promptIndex = 0;
+        int pos = startPosition;
+
+        while (pos < maxTokens) {
+            // GPU Forward Pass
+            FloatArray logits = InferenceCore.forwardTornadoVM(model, state, currentToken, pos, tornadoVMPlan);
+
+            // Handle token processing
+            if (promptIndex < promptTokens.size()) {
+                // We're still processing the prompt tokens
+                nextToken = promptTokens.get(promptIndex++);
+                if (echo) {
+                    System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
+                }
+            } else {
+                // Mark the start of actual generation (after prompt processing)
+                if (inferenceStartNanos == 0) {
+                    inferenceStartNanos = System.nanoTime();
+                }
+
+                // Sample the next token
+                nextToken = sampler.sampleToken(logits);
+
+                // Output the token if echo is enabled
+                if (echo) {
+                    System.err.print(Tokenizer.replaceControlCharacters(model.tokenizer().decode(List.of(nextToken))));
+                }
+
+                // Track the generated token
+                generatedTokens.add(nextToken);
+
+                // Notify via callback if provided
+                if (onTokenGenerated != null) {
+                    onTokenGenerated.accept(nextToken);
+                }
+
+                // Check for stop condition
+                if (stopTokens.contains(nextToken)) {
+                    break;
+                }
+            }
+
+            // Update for next iteration
+            currentToken = nextToken;
+            state.latestToken = currentToken;
+            pos++;
+        }
+
+        // Calculate and print performance metrics
+        long endNanos = System.nanoTime();
+        double totalTimeSeconds = (endNanos - startNanos) / 1_000_000_000.0;
+        int totalTokens = promptIndex + generatedTokens.size();
+
+        LastRunMetrics.setMetrics(totalTokens, totalTimeSeconds);
+
+        return generatedTokens;
+    }
+}
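As a rough usage note, generateTokensPhi3 and generateTokensGPUPhi3 above follow the same shape as the other generators: force-feed the prompt tokens first, then keep sampled tokens until a stop token or the token budget is reached. A stripped-down, self-contained sketch of that control flow with a stubbed forward-and-sample step follows; the stub, the assumed BOS id, and the token values are illustrative assumptions, not APIs from this repository.

import java.util.ArrayList;
import java.util.List;
import java.util.Set;
import java.util.function.IntUnaryOperator;

public final class GenerationLoopSketch {

    // Mirrors the control flow of generateTokensPhi3: ingest prompt tokens first,
    // then collect sampled tokens until a stop token or the token budget is hit.
    static List<Integer> generate(List<Integer> promptTokens, Set<Integer> stopTokens,
                                  int maxTokens, IntUnaryOperator forwardAndSample) {
        List<Integer> generated = new ArrayList<>();
        int token = 1;                                        // assumed BOS id, illustrative only
        int promptIndex = 0;
        for (int position = 0; position < maxTokens; position++) {
            int sampled = forwardAndSample.applyAsInt(token); // stands in for model.forward + sampler.sampleToken
            int next;
            if (promptIndex < promptTokens.size()) {
                next = promptTokens.get(promptIndex++);       // force-pick token from prompt
            } else {
                next = sampled;
                generated.add(next);
                if (stopTokens.contains(next)) {
                    break;                                    // stop token ends generation
                }
            }
            token = next;                                     // feed the chosen token back in
        }
        return generated;                                     // only inferred tokens, as in the real method
    }

    public static void main(String[] args) {
        // Dummy "model": proposes 20 after seeing 12, then the (assumed) stop token 7, otherwise 42.
        IntUnaryOperator dummy = t -> t == 12 ? 20 : (t == 20 ? 7 : 42);
        System.out.println(generate(List.of(10, 11, 12), Set.of(7), 16, dummy)); // prints [20, 7]
    }
}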
