Skip to content

Commit 1a18ad4

Browse files
Decouple inference implementation from Model
1 parent 41a1733 commit 1a18ad4

File tree

5 files changed

+443
-582
lines changed

5 files changed

+443
-582
lines changed
Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,199 @@
1+
package com.example.inference;
2+
3+
import com.example.aux.Parallel;
4+
import com.example.core.model.tensor.FloatTensor;
5+
import com.example.loader.weights.State;
6+
import com.example.loader.weights.Weights;
7+
import com.example.model.Configuration;
8+
import com.example.model.Model;
9+
import com.example.tornadovm.TornadoVMMasterPlan;
10+
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
11+
12+
import java.lang.foreign.MemorySegment;
13+
import java.nio.FloatBuffer;
14+
15+
/**
16+
* Low-level operations for model inference.
17+
*
18+
* <p>
19+
* Provides core computational operations: RMS normalization and forward passes
20+
* through model layers. Supports both CPU and GPU implementations.
21+
*/
22+
23+
public final class InferenceCore {
24+
25+
private InferenceCore() {
26+
// prevent instantiation
27+
}
28+
29+
public static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) {
30+
// calculate sum of squares
31+
float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi);
32+
ss /= size;
33+
ss += rmsNormEps;
34+
ss = (float) (1.0 / Math.sqrt(ss));
35+
// normalize and scale
36+
final float finalss = ss; // for the lambda
37+
out.mapWithIndexInPlace(0, size, (value, index) -> weight.get(index) * (finalss * x.getFloat(index)));
38+
}
39+
40+
public static FloatTensor forwardJava(Model model, State state, int token, int position) {
41+
// a few convenience variables
42+
final Configuration config = model.configuration();
43+
final Weights weights = model.weights();
44+
int dim = config.dim();
45+
int headSize = config.headSize();
46+
int kvDim = (config.dim() * config.numberOfKeyValueHeads()) / config.numberOfHeads();
47+
int kvMul = config.numberOfHeads() / config.numberOfKeyValueHeads(); // integer multiplier of the kv sharing in multiquery
48+
float sqrtHeadSize = (float) Math.sqrt(headSize);
49+
50+
// copy the token embedding into x
51+
weights.token_embedding_table.copyTo(token * dim, state.x, 0, dim);
52+
53+
// forward all the layers
54+
for (int l = 0; l < config.numberOfLayers(); l++) {
55+
// attention rmsnorm
56+
rmsnorm(state.xb, state.x, weights.rms_att_weight[l], dim, config.rmsNormEps());
57+
58+
// qkv matmuls for this position
59+
60+
weights.wq[l].matmul(state.xb, state.q, dim, dim);
61+
weights.wk[l].matmul(state.xb, state.k, kvDim, dim);
62+
weights.wv[l].matmul(state.xb, state.v, kvDim, dim);
63+
64+
// RoPE relative positional encoding: complex-valued rotate q and k in each head
65+
for (int i = 0; i < dim; i += 2) {
66+
int head_dim = i % headSize;
67+
float fcr = weights.freq_cis_real.get(position * (headSize / 2) + (head_dim / 2));
68+
float fci = weights.freq_cis_imag.get(position * (headSize / 2) + (head_dim / 2));
69+
int rotn = i < kvDim ? 2 : 1; // how many vectors? 2 = q & k, 1 = q only
70+
for (int v = 0; v < rotn; v++) {
71+
FloatTensor vec = v == 0 ? state.q : state.k; // the vector to rotate (query or key)
72+
float v0 = vec.getFloat(i);
73+
float v1 = vec.getFloat(i + 1);
74+
vec.setFloat(i, v0 * fcr - v1 * fci);
75+
vec.setFloat(i + 1, v0 * fci + v1 * fcr);
76+
}
77+
}
78+
79+
// save key,value at this time step (position) to our kv cache
80+
//int loff = l * config.seq_len * kvDim;
81+
// kv cache layer offset for convenience
82+
state.k.copyTo(0, state.keyCache[l], position * kvDim, kvDim);
83+
state.v.copyTo(0, state.valueCache[l], position * kvDim, kvDim);
84+
85+
int curLayer = l;
86+
87+
// multihead attention. iterate over all heads
88+
Parallel.parallelFor(0, config.numberOfHeads(), h -> {
89+
// get the query vector for this head
90+
// float* q = s.q + h * headSize;
91+
int qOffset = h * headSize;
92+
93+
// attention scores for this head
94+
// float* att = s.att + h * config.seq_len;
95+
int attOffset = h * config.contextLength();
96+
97+
// iterate over all timesteps, including the current one
98+
for (int t = 0; t <= position; t++) {
99+
// get the key vector for this head and at this timestep
100+
// float* k = s.key_cache + loff + t * dim + h * headSize;
101+
int keyCacheOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
102+
// calculate the attention score as the dot product of q and k
103+
float score = state.q.dot(qOffset, state.keyCache[curLayer], keyCacheOffset, headSize);
104+
score /= sqrtHeadSize;
105+
// save the score to the attention buffer
106+
state.att.setFloat(attOffset + t, score);
107+
}
108+
109+
// softmax the scores to get attention weights, from 0..position inclusively
110+
state.att.softmaxInPlace(attOffset, position + 1);
111+
112+
// weighted sum of the values, store back into xb
113+
// float* xb = s.xb + h * headSize;
114+
int xbOffset = h * headSize;
115+
// memset(xb, 0, headSize * sizeof(float));
116+
state.xb.fillInPlace(xbOffset, headSize, 0f);
117+
118+
for (int t = 0; t <= position; t++) {
119+
// get the value vector for this head and at this timestep
120+
// float* v = s.value_cache + loff + t * dim + h * headSize;
121+
int vOffset = /* loff + */ t * kvDim + (h / kvMul) * headSize;
122+
// get the attention weight for this timestep
123+
float a = state.att.getFloat(attOffset + t);
124+
// accumulate the weighted value into xb
125+
state.xb.saxpyInPlace(xbOffset, state.valueCache[curLayer], vOffset, headSize, a);
126+
}
127+
});
128+
129+
// final matmul to get the output of the attention
130+
weights.wo[l].matmul(state.xb, state.xb2, dim, dim);
131+
132+
// residual connection back into x
133+
state.x.addInPlace(state.xb2);
134+
135+
// ffn rmsnorm
136+
rmsnorm(state.xb, state.x, weights.rms_ffn_weight[l], dim, config.rmsNormEps());
137+
138+
// Now for FFN in PyTorch we have: self.w2(F.silu(self.w1(x)) * self.w3(x))
139+
// first calculate self.w1(x) and self.w3(x)
140+
weights.w1[l].matmul(state.xb, state.hb, config.hiddenDim(), dim);
141+
weights.w3[l].matmul(state.xb, state.hb2, config.hiddenDim(), dim);
142+
143+
// SwiGLU non-linearity
144+
// silu(x)=x*σ(x), where σ(x) is the logistic sigmoid
145+
state.hb.mapInPlace(value -> value / (float) (1.0 + Math.exp(-value)));
146+
147+
// elementwise multiply with w3(x)
148+
state.hb.multiplyInPlace(state.hb2);
149+
150+
// final matmul to get the output of the ffn
151+
weights.w2[l].matmul(state.hb, state.xb, dim, config.hiddenDim());
152+
153+
// residual connection
154+
state.x.addInPlace(state.xb);
155+
}
156+
157+
rmsnorm(state.x, state.x, weights.rms_final_weight, dim, config.rmsNormEps());
158+
159+
weights.wcls.matmul(state.x, state.logits, config.vocabularySize(), dim);
160+
161+
return state.logits;
162+
}
163+
164+
/**
165+
* Performs the initial embedding lookup and triggers the TornadoVM accelerated forward pass for an LLM token.
166+
*
167+
* <p>This method handles the first phase of processing a token through the transformer model:
168+
* <ol>
169+
* <li>Copies the token embedding from the model's embedding table to the state's buffer</li>
170+
* <li>Delegates the transformer layer processing to TornadoVM through the master plan</li>
171+
* </ol>
172+
*
173+
* <p>The token embedding lookup happens on the CPU using {@link MemorySegment} operations,
174+
* while the subsequent transformer layers processing is offloaded to the accelerator through
175+
* TornadoVM for improved performance.
176+
*
177+
* @param model
178+
* The Llama model containing weights and configuration parameters
179+
* @param state
180+
* The current execution state holding input/output tensors and temporary buffers
181+
* @param token
182+
* The input token ID to process
183+
* @param position
184+
* The position of this token in the sequence context window
185+
* @param tornadoVMMasterPlan
186+
* The execution plan for TornadoVM acceleration
187+
* @return FloatTensor containing the output logits for token prediction
188+
*/
189+
public static FloatArray forwardTornadoVM(Model model, State state, int token, int position, TornadoVMMasterPlan tornadoVMMasterPlan) {
190+
final Configuration configuration = model.configuration();
191+
final Weights weights = model.weights();
192+
193+
MemorySegment.copy(weights.tokenEmbeddingTable.getSegment(), token * configuration.dim() * Float.BYTES, state.wrapX.getSegment(), 0, configuration.dim() * Float.BYTES);
194+
195+
return tornadoVMMasterPlan.tornadoVMForwardExecuteLayered(position);
196+
}
197+
198+
199+
}

0 commit comments

Comments
 (0)