Skip to content

Commit 89d5aa3

Browse files
mikepapadim and orionpapadakis
authored and committed
Refactor TornadoVM integration and extend Mistral configuration.
Introduced `kvDim` and `kvMul` methods in `Configuration` and `MistralConfiguration` to enhance model configuration flexibility. Refactored TornadoVM classes to generalize handling of different models by replacing `Llama`-specific types with `Model` interface. Streamlined token generation logic to support conditional GPU execution with TornadoVM.
1 parent 9b68bf7 commit 89d5aa3

File tree

5 files changed

+55
-29
lines changed

5 files changed

+55
-29
lines changed

src/main/java/com/example/inference/engine/impl/Configuration.java

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -32,4 +32,7 @@ public interface Configuration {
3232
/** Size of each attention head (derived from dim / numberOfHeads) */
3333
int headSize();
3434

35+
int kvDim();
36+
37+
int kvMul();
3538
}

src/main/java/com/example/inference/engine/impl/mistral/Mistral.java

Lines changed: 33 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -20,14 +20,13 @@
2020
import java.util.Set;
2121
import java.util.function.IntConsumer;
2222

23+
import static com.example.LlamaApp.USE_TORNADOVM;
24+
2325
/**
2426
* Llama class in mistral.java
2527
*/
2628
public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model {
2729

28-
/* For explicit use */
29-
private MistralTokenizer getAsMistralTokenizer() { return (MistralTokenizer) tokenizer; }
30-
3130
static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) {
3231
// calculate sum of squares
3332
float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi);
@@ -163,15 +162,20 @@ static FloatTensor forward(Mistral model, State state, int token, int position)
163162
return state.logits;
164163
}
165164

165+
/* For explicit use */
166+
private MistralTokenizer getAsMistralTokenizer() {
167+
return (MistralTokenizer) tokenizer;
168+
}
169+
166170
@Override
167-
public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens,
168-
int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
171+
public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
172+
IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
169173
throw new UnsupportedOperationException("Mistral.generateTokensGPU is not implemented yet");
170174
}
171175

172176
@Override
173-
public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens,
174-
int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
177+
public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
178+
IntConsumer onTokenGenerated) {
175179
long startNanos = System.nanoTime();
176180
if (maxTokens < 0 || configuration.contextLength() < maxTokens) {
177181
maxTokens = configuration.contextLength();
@@ -248,14 +252,15 @@ public void runInteractive(Sampler sampler, Options options) {
248252
}
249253
conversationTokens.addAll(chatFormat.encodeMessage(userText, true, true));
250254
Set<Integer> stopTokens = chatFormat.getStopTokens();
251-
List<Integer> responseTokens = generateTokens(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
252-
if (options.stream()) {
253-
int tokenType = mistralTokenizer.getTokenType(token);
254-
if (tokenType == 1 || tokenType == 6) {
255-
System.out.print(mistralTokenizer.decode(List.of(token)));
256-
}
257-
}
258-
});
255+
List<Integer> responseTokens = generateTokens(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
256+
options.echo(), token -> {
257+
if (options.stream()) {
258+
int tokenType = mistralTokenizer.getTokenType(token);
259+
if (tokenType == 1 || tokenType == 6) {
260+
System.out.print(mistralTokenizer.decode(List.of(token)));
261+
}
262+
}
263+
});
259264
// Include stop token in the prompt history, but not in the response displayed to the user.
260265
conversationTokens.addAll(responseTokens);
261266
startPosition = conversationTokens.size();
@@ -288,15 +293,26 @@ public void runInstructOnce(Sampler sampler, Options options) {
288293
promptTokens.addAll(chatFormat.encodeMessage(options.prompt(), true, true));
289294
}
290295

296+
List<Integer> responseTokens;
291297
Set<Integer> stopTokens = chatFormat.getStopTokens();
292-
List<Integer> responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
298+
IntConsumer tokenConsumer = token -> {
293299
if (options.stream()) {
294300
int tokenType = mistralTokenizer.getTokenType(token);
295301
if (tokenType == 1 || tokenType == 6) {
296302
System.out.print(mistralTokenizer.decode(List.of(token)));
297303
}
298304
}
299-
});
305+
};
306+
307+
TornadoVMMasterPlan tornadoVMPlan = null;
308+
if (USE_TORNADOVM) {
309+
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
310+
// Call generateTokensGPU without the token consumer parameter
311+
responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
312+
} else {
313+
responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
314+
}
315+
300316
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
301317
responseTokens.removeLast();
302318
}

src/main/java/com/example/inference/engine/impl/mistral/MistralConfiguration.java

Lines changed: 10 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -2,11 +2,16 @@
22

33
import com.example.inference.engine.impl.Configuration;
44

5-
public record MistralConfiguration(
6-
int dim, int hiddenDim, int numberOfLayers, int numberOfHeads,
7-
int numberOfKeyValueHeads, int vocabularySize, int contextLength,
8-
boolean sharedWeights, float rmsNormEps, float ropeTheta
9-
) implements Configuration {
5+
public record MistralConfiguration(int dim, int hiddenDim, int numberOfLayers, int numberOfHeads, int numberOfKeyValueHeads, int vocabularySize, int contextLength, boolean sharedWeights,
6+
float rmsNormEps, float ropeTheta) implements Configuration {
7+
8+
public int kvDim() {
9+
return dim * numberOfKeyValueHeads / numberOfHeads;
10+
}
11+
12+
public int kvMul() {
13+
return numberOfHeads / numberOfKeyValueHeads;
14+
}
1015

1116
public int headSize() {
1217
return dim / numberOfHeads;

src/main/java/com/example/tornadovm/TornadoVMLayerPlanner.java

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
package com.example.tornadovm;
22

33
import com.example.auxiliary.Tuple2;
4+
import com.example.inference.engine.impl.Configuration;
5+
import com.example.inference.engine.impl.Model;
46
import com.example.inference.engine.impl.llama.LlamaConfiguration;
57
import com.example.inference.engine.impl.llama.Llama;
68
import com.example.loader.weights.State;
@@ -49,7 +51,7 @@ public class TornadoVMLayerPlanner {
4951
private static final int THREAD_SCALE_FOR_LOGITS = 8;
5052

5153
private final State state;
52-
private final LlamaConfiguration config;
54+
private final Configuration config;
5355
private final Weights weights;
5456
private final KernelContext context;
5557

@@ -61,7 +63,7 @@ public class TornadoVMLayerPlanner {
6163
* @param model
6264
* The Llama model instance containing configuration and weights
6365
*/
64-
public TornadoVMLayerPlanner(State state, Llama model) {
66+
public TornadoVMLayerPlanner(State state, Model model) {
6567
this.state = state;
6668
this.config = model.configuration();
6769
this.weights = model.weights();

src/main/java/com/example/tornadovm/TornadoVMMasterPlan.java

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,8 +1,8 @@
11
package com.example.tornadovm;
22

33
import com.example.auxiliary.Tuple2;
4-
import com.example.inference.engine.impl.llama.LlamaConfiguration;
5-
import com.example.inference.engine.impl.llama.Llama;
4+
import com.example.inference.engine.impl.Configuration;
5+
import com.example.inference.engine.impl.Model;
66
import com.example.loader.weights.State;
77
import uk.ac.manchester.tornado.api.GridScheduler;
88
import uk.ac.manchester.tornado.api.ImmutableTaskGraph;
@@ -17,12 +17,12 @@ public class TornadoVMMasterPlan {
1717
private static final boolean ENABLE_TORNADOVM_INIT_TIME = Boolean.parseBoolean(System.getProperty("llama.EnableTimingForTornadoVMInit", "False"));
1818

1919
private final State state;
20-
private final LlamaConfiguration config;
20+
private final Configuration config;
2121
public GridScheduler scheduler;
2222
public TornadoExecutionPlan executionPlan;
2323
List<ImmutableTaskGraph> taskGraphs;
2424

25-
public TornadoVMMasterPlan(State state, Llama model, boolean isNvidia) {
25+
public TornadoVMMasterPlan(State state, Model model, boolean isNvidia) {
2626
TornadoVMLayerPlanner tornadoVMLayerPlanner = new TornadoVMLayerPlanner(state, model);
2727
Tuple2<List<ImmutableTaskGraph>, GridScheduler> tornadoVMPlan = isNvidia ? tornadoVMLayerPlanner.setupTornadoForwardPlanLayered() : tornadoVMLayerPlanner.setupTornadoForwardPlanLayeredNonNvidia();
2828
this.taskGraphs = tornadoVMPlan.getFirst();
@@ -43,7 +43,7 @@ public TornadoVMMasterPlan(State state, Llama model, boolean isNvidia) {
4343
* @param model The Llama model instance
4444
* @return The initialized TornadoVMMasterPlan ready for inference
4545
*/
46-
public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Llama model) {
46+
public static TornadoVMMasterPlan initializeTornadoVMPlan(State state, Model model) {
4747
// Initialize timing variables outside conditional blocks to avoid scope issues
4848
long startTime = System.nanoTime();
4949
long planCreationTime = 0;

0 commit comments

Comments (0)