Commit b6b693f
Fully integrate TornadoVM for Mistral
1 parent 1a18ad4

2 files changed: +91 -43 lines


src/main/java/com/example/model/llama/Llama.java (10 additions, 4 deletions)
@@ -51,11 +51,14 @@ public State createNewState(int batchsize) {
     public void runInteractive(Sampler sampler, Options options) {
         State state = null;
         List<Integer> conversationTokens = new ArrayList<>();
+
         LlamaChatFormat chatFormat = new LlamaChatFormat(getAsLlamaTokenizer());
         conversationTokens.add(chatFormat.getBeginOfText());
+
         if (options.systemPrompt() != null) {
             conversationTokens.addAll(chatFormat.encodeMessage(new LlamaChatFormat.Message(LlamaChatFormat.Role.SYSTEM, options.systemPrompt())));
         }
+
         int startPosition = 0;
         Scanner in = new Scanner(System.in);

@@ -71,6 +74,8 @@ public void runInteractive(Sampler sampler, Options options) {
                 break;
             }
             if (state == null) {
+                // State allocation can take some time for large context sizes,
+                // allocate the model state only after printing the user '>' prompt.
                 state = createNewState();
             }

@@ -85,8 +90,8 @@ public void runInteractive(Sampler sampler, Options options) {
             List<Integer> responseTokens;
             IntConsumer tokenConsumer = token -> {
                 if (options.stream()) {
-                    if (!tokenizer().isSpecialToken(token)) {
-                        System.out.print(tokenizer().decode(List.of(token)));
+                    if (!tokenizer.isSpecialToken(token)) {
+                        System.out.print(tokenizer.decode(List.of(token)));
                     }
                 }
             };
@@ -113,7 +118,7 @@ public void runInteractive(Sampler sampler, Options options) {
                 responseTokens.removeLast();
             }
             if (!options.stream()) {
-                String responseText = tokenizer().decode(responseTokens);
+                String responseText = tokenizer.decode(responseTokens);
                 System.out.println(responseText);
             }
             if (stopToken == null) {
@@ -143,10 +148,11 @@ public void runInteractive(Sampler sampler, Options options) {
     public void runInstructOnce(Sampler sampler, Options options) {
         State state = createNewState();
         LlamaChatFormat chatFormat = new LlamaChatFormat(getAsLlamaTokenizer());
-        TornadoVMMasterPlan tornadoVMPlan =null;
+        TornadoVMMasterPlan tornadoVMPlan = null;

         List<Integer> promptTokens = new ArrayList<>();
         promptTokens.add(chatFormat.getBeginOfText());
+
         if (options.systemPrompt() != null) {
             promptTokens.addAll(chatFormat.encodeMessage(new LlamaChatFormat.Message(LlamaChatFormat.Role.SYSTEM, options.systemPrompt())));
         }
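
For context on the tokenizer() -> tokenizer change in the hunks above: the diff suggests the shared model base class now exposes the tokenizer as a protected field that subclasses read directly, instead of going through an accessor. A minimal sketch of that assumed shape (class and member names here are illustrative, not taken from this commit):

// Hypothetical sketch only; the real base class in this repo may differ.
public abstract class AbstractModel {
    // Exposed to subclasses, so call sites can write
    // tokenizer.decode(...) instead of tokenizer().decode(...).
    protected final Tokenizer tokenizer;

    protected AbstractModel(Tokenizer tokenizer) {
        this.tokenizer = tokenizer;
    }

    // An accessor can remain for external callers.
    public Tokenizer tokenizer() {
        return tokenizer;
    }
}

Either spelling compiles against such a class; reading the field directly simply drops one method call per streamed token.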

src/main/java/com/example/model/mistral/Mistral.java (81 additions, 39 deletions)
@@ -48,51 +48,87 @@ public State createNewState(int batchsize) {

     @Override
     public void runInteractive(Sampler sampler, Options options) {
-        MistralTokenizer mistralTokenizer = getAsMistralTokenizer();
         State state = null;
-        MistralChatFormat chatFormat = new MistralChatFormat(getAsMistralTokenizer());
         List<Integer> conversationTokens = new ArrayList<>();
+
+        MistralChatFormat chatFormat = new MistralChatFormat(getAsMistralTokenizer());
         conversationTokens.add(chatFormat.getBeginOfText());
+
         int startPosition = 0;
         Scanner in = new Scanner(System.in);
-        while (true) {
-            System.out.print("> ");
-            System.out.flush();
-            if (state == null) {
-                // State allocation can take some time for large context sizes,
-                // allocate the model state only after printing the user '>' prompt.
-                state = createNewState();
-            }
-            String userText = in.nextLine();
-            if (List.of("quit", "exit").contains(userText)) {
-                break;
-            }
-            conversationTokens.addAll(chatFormat.encodeMessage(userText, true, true));
-            Set<Integer> stopTokens = chatFormat.getStopTokens();
-            List<Integer> responseTokens = generateTokens(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
-                    options.echo(), token -> {
-                        if (options.stream()) {
-                            int tokenType = mistralTokenizer.getTokenType(token);
-                            if (tokenType == 1 || tokenType == 6) {
-                                System.out.print(mistralTokenizer.decode(List.of(token)));
-                            }
+
+        // Initialize TornadoVM plan once at the beginning if GPU path is enabled
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
+        try {
+            while (true) {
+                System.out.print("> ");
+                System.out.flush();
+                String userText = in.nextLine();
+                if (List.of("quit", "exit").contains(userText)) {
+                    break;
+                }
+                if (state == null) {
+                    // State allocation can take some time for large context sizes,
+                    // allocate the model state only after printing the user '>' prompt.
+                    state = createNewState();
+                }
+
+                if (USE_TORNADOVM && tornadoVMPlan == null) {
+                    tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+                }
+
+                conversationTokens.addAll(chatFormat.encodeMessage(userText, true, true));
+                Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+                List<Integer> responseTokens;
+                IntConsumer tokenConsumer = token -> {
+                    if (options.stream()) {
+                        if (!tokenizer.isSpecialToken(token)) {
+                            System.out.print(tokenizer.decode(List.of(token)));
                         }
-                    });
-            // Include stop token in the prompt history, but not in the response displayed to the user.
-            conversationTokens.addAll(responseTokens);
-            startPosition = conversationTokens.size();
-            Integer stopToken = null;
-            if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-                stopToken = responseTokens.getLast();
-                responseTokens.removeLast();
-            }
-            if (!options.stream()) {
-                String responseText = mistralTokenizer.decode(responseTokens);
-                System.out.println(responseText);
+                    }
+                };
+
+                // Choose between GPU and CPU path based on configuration
+                if (USE_TORNADOVM) {
+                    // GPU path using TornadoVM
+                    responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+                            options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+                } else {
+                    // CPU path
+                    responseTokens = InferenceEngine.generateTokens(this, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+                            options.maxTokens(), sampler, options.echo(), tokenConsumer);
+                }
+
+                // Include stop token in the prompt history, but not in the response displayed to the user.
+                conversationTokens.addAll(responseTokens);
+                startPosition = conversationTokens.size();
+                Integer stopToken = null;
+                if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+                    stopToken = responseTokens.getLast();
+                    responseTokens.removeLast();
+                }
+                if (!options.stream()) {
+                    String responseText = tokenizer.decode(responseTokens);
+                    System.out.println(responseText);
+                }
+                if (stopToken == null) {
+                    System.err.println("Ran out of context length...\nIncrease the context length by passing --max-tokens XXX to llama-tornado");
+                    break;
+                }
+                System.out.print("\n");
             }
-            if (stopToken == null) {
-                System.err.println("Ran out of context length...");
-                break;
+        } finally {
+            // Clean up TornadoVM resources when exiting the chat loop
+            if (USE_TORNADOVM && tornadoVMPlan != null) {
+                try {
+                    tornadoVMPlan.freeTornadoExecutionPlan();
+                } catch (Exception e) {
+                    System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
+                }
             }
         }
     }
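
The hunk above rewrites the whole interactive loop, so the resource lifecycle is easy to miss in the noise. Distilled to its skeleton, the pattern is: allocate the TornadoVM plan lazily, once per session and only after the State it wraps exists, reuse it on every conversation turn, and release it exactly once on the way out. A condensed sketch of just that skeleton, with the unrelated chat logic elided:

// Skeleton of the TornadoVM plan lifecycle in the new runInteractive,
// distilled from the hunk above; elided parts are marked with comments.
TornadoVMMasterPlan tornadoVMPlan = null;
try {
    while (true) {
        // ... print the '>' prompt, read user input, lazily allocate State ...
        if (USE_TORNADOVM && tornadoVMPlan == null) {
            // Created once, after the State exists, so GPU kernels and
            // buffers are not rebuilt on every conversation turn.
            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
        }
        // ... InferenceEngine.generateTokensGPU(..., tornadoVMPlan) on the GPU path,
        //     InferenceEngine.generateTokens(...) on the CPU path ...
    }
} finally {
    if (tornadoVMPlan != null) {
        // Runs on "quit"/"exit", on running out of context, and on exceptions.
        tornadoVMPlan.freeTornadoExecutionPlan();
    }
}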
@@ -101,8 +137,11 @@ public void runInteractive(Sampler sampler, Options options) {
     public void runInstructOnce(Sampler sampler, Options options) {
         State state = createNewState();
         MistralChatFormat chatFormat = new MistralChatFormat(getAsMistralTokenizer());
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
         List<Integer> promptTokens = new ArrayList<>();
         promptTokens.add(chatFormat.getBeginOfText());
+
         if (options.suffix() != null) {
             promptTokens.addAll(chatFormat.encodeFillInTheMiddle(options.prompt(), options.suffix()));
         } else {
@@ -120,7 +159,6 @@ public void runInstructOnce(Sampler sampler, Options options) {
             }
         };

-        TornadoVMMasterPlan tornadoVMPlan = null;
         if (USE_TORNADOVM) {
             tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
             // Call generateTokensGPU without the token consumer parameter
@@ -140,5 +178,9 @@ public void runInstructOnce(Sampler sampler, Options options) {
         }

         LastRunMetrics.printMetrics();
+
+        if (tornadoVMPlan != null) {
+            tornadoVMPlan.freeTornadoExecutionPlan();
+        }
     }
 }
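
One asymmetry worth noting between the two methods: runInstructOnce frees the plan after LastRunMetrics.printMetrics() without a finally block, so an exception thrown during generation would skip the cleanup that runInteractive now guarantees. A sketch of how the same guarantee could be added, offered as a possible follow-up rather than code from this commit:

// Hypothetical hardening of runInstructOnce, mirroring runInteractive's
// finally-based cleanup; not part of this commit.
TornadoVMMasterPlan tornadoVMPlan = null;
try {
    if (USE_TORNADOVM) {
        tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
    }
    // ... token generation (GPU or CPU path), response printing ...
    LastRunMetrics.printMetrics();
} finally {
    if (tornadoVMPlan != null) {
        tornadoVMPlan.freeTornadoExecutionPlan();
    }
}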
