Skip to content

Commit 9b68bf7

Browse files
Initial commit of Mistral port
1 parent 90a719f commit 9b68bf7

20 files changed

+1472
-710
lines changed

src/main/java/com/example/LlamaApp.java

Lines changed: 10 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,18 @@
11
package com.example;
22

33
import com.example.aot.AOT;
4-
import com.example.auxiliary.ChatFormat;
54
import com.example.core.model.tensor.FloatTensor;
65
import com.example.inference.CategoricalSampler;
76
import com.example.inference.Sampler;
87
import com.example.inference.ToppSampler;
9-
import com.example.inference.engine.impl.Llama;
8+
import com.example.inference.engine.impl.Model;
109
import com.example.inference.engine.impl.Options;
1110
import com.example.loader.weights.ModelLoader;
1211
import com.example.loader.weights.State;
1312
import com.example.tornadovm.FloatArrayUtils;
14-
import com.example.tornadovm.TornadoVMMasterPlan;
1513
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
1614

1715
import java.io.IOException;
18-
import java.util.ArrayList;
19-
import java.util.List;
20-
import java.util.Scanner;
21-
import java.util.Set;
22-
import java.util.function.IntConsumer;
2316
import java.util.random.RandomGenerator;
2417
import java.util.random.RandomGeneratorFactory;
2518

@@ -115,156 +108,26 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
115108
return sampler;
116109
}
117110

118-
static void runInteractive(Llama model, Sampler sampler, Options options) {
119-
State state = null;
120-
List<Integer> conversationTokens = new ArrayList<>();
121-
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
122-
conversationTokens.add(chatFormat.beginOfText);
123-
if (options.systemPrompt() != null) {
124-
conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
125-
}
126-
int startPosition = 0;
127-
Scanner in = new Scanner(System.in);
128-
129-
// Initialize TornadoVM plan once at the beginning if GPU path is enabled
130-
TornadoVMMasterPlan tornadoVMPlan = null;
131-
132-
try {
133-
while (true) {
134-
System.out.print("> ");
135-
System.out.flush();
136-
String userText = in.nextLine();
137-
if (List.of("quit", "exit").contains(userText)) {
138-
break;
139-
}
140-
if (state == null) {
141-
state = model.createNewState();
142-
}
143-
144-
if (USE_TORNADOVM && tornadoVMPlan == null) {
145-
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
146-
}
147-
148-
conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
149-
conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
150-
Set<Integer> stopTokens = chatFormat.getStopTokens();
151-
152-
List<Integer> responseTokens;
153-
IntConsumer tokenConsumer = token -> {
154-
if (options.stream()) {
155-
if (!model.tokenizer().isSpecialToken(token)) {
156-
System.out.print(model.tokenizer().decode(List.of(token)));
157-
}
158-
}
159-
};
160-
161-
// Choose between GPU and CPU path based on configuration
162-
if (USE_TORNADOVM) {
163-
// GPU path using TornadoVM
164-
responseTokens = Llama.generateTokensGPU(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
165-
sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
166-
} else {
167-
// CPU path
168-
responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
169-
options.echo(), tokenConsumer);
170-
}
171-
172-
// Include stop token in the prompt history, but not in the response displayed to the user.
173-
conversationTokens.addAll(responseTokens);
174-
startPosition = conversationTokens.size();
175-
Integer stopToken = null;
176-
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
177-
stopToken = responseTokens.getLast();
178-
responseTokens.removeLast();
179-
}
180-
if (!options.stream()) {
181-
String responseText = model.tokenizer().decode(responseTokens);
182-
System.out.println(responseText);
183-
}
184-
if (stopToken == null) {
185-
System.err.println("\n Ran out of context length...\n Increase context length with by passing to llama-tornado --max-tokens XXX");
186-
break;
187-
}
188-
System.out.print("\n");
189-
190-
// Optionally print performance metrics after each response
191-
if (SHOW_PERF_INTERACTIVE) {
192-
Llama.LastRunMetrics.printMetrics();
193-
}
194-
}
195-
} finally {
196-
// Clean up TornadoVM resources when exiting the chat loop
197-
if (USE_TORNADOVM && tornadoVMPlan != null) {
198-
try {
199-
tornadoVMPlan.freeTornadoExecutionPlan();
200-
} catch (Exception e) {
201-
System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
202-
}
203-
}
204-
}
205-
}
206-
207-
static void runInstructOnce(Llama model, Sampler sampler, Options options) {
208-
State state = model.createNewState();
209-
ChatFormat chatFormat = new ChatFormat(model.tokenizer());
210-
TornadoVMMasterPlan tornadoVMPlan = null;
211-
212-
List<Integer> promptTokens = new ArrayList<>();
213-
promptTokens.add(chatFormat.beginOfText);
214-
if (options.systemPrompt() != null) {
215-
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
216-
}
217-
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
218-
promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
219-
List<Integer> responseTokens;
220-
221-
// Define the token consumer
222-
IntConsumer tokenConsumer = token -> {
223-
if (options.stream()) {
224-
if (!model.tokenizer().isSpecialToken(token)) {
225-
System.out.print(model.tokenizer().decode(List.of(token)));
226-
}
227-
}
228-
};
111+
// moved to model and became non-static
112+
//static void runInteractive(Model model, Sampler sampler, Options options)
229113

230-
Set<Integer> stopTokens = chatFormat.getStopTokens();
231-
if (USE_TORNADOVM) {
232-
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
233-
// Call generateTokensGPU without the token consumer parameter
234-
responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
235-
} else {
236-
// CPU path still uses the token consumer
237-
responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
238-
}
239-
240-
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
241-
responseTokens.removeLast();
242-
}
243-
if (!options.stream()) {
244-
String responseText = model.tokenizer().decode(responseTokens);
245-
System.out.println(responseText);
246-
}
247-
248-
Llama.LastRunMetrics.printMetrics();
249-
250-
if (tornadoVMPlan != null) {
251-
tornadoVMPlan.freeTornadoExecutionPlan();
252-
}
253-
}
114+
// moved to model and became non-static
115+
//static void runInstructOnce(Model model, Sampler sampler, Options options)
254116

255117
public static void main(String[] args) throws IOException {
256118
Options options = Options.parseOptions(args);
257-
Llama model;
119+
Model model;
258120
if (USE_AOT) {
259121
model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
260122
} else {
261123
model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
262124
}
263-
Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), options.seed());
125+
assert model != null;
126+
Sampler sampler = selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
264127
if (options.interactive()) {
265-
runInteractive(model, sampler, options);
128+
model.runInteractive(sampler, options);
266129
} else {
267-
runInstructOnce(model, sampler, options);
130+
model.runInstructOnce(sampler, options);
268131
}
269132
}
270133
}

src/main/java/com/example/aot/AOT.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
import com.example.auxiliary.Timer;
44
import com.example.core.model.GGUF;
55
import com.example.core.model.tensor.GGMLTensorEntry;
6-
import com.example.inference.engine.impl.Llama;
6+
import com.example.inference.engine.impl.Model;
77
import com.example.inference.engine.impl.Options;
8+
import com.example.inference.engine.impl.llama.Llama;
89
import com.example.loader.weights.ModelLoader;
910
import com.example.loader.weights.Weights;
1011

@@ -45,7 +46,7 @@ private static PartialModel preLoadGGUF(String modelPath) {
4546
try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
4647
return new PartialModel(
4748
path.getFileName().toString(),
48-
ModelLoader.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false),
49+
ModelLoader.loadLlamaModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false),
4950
gguf.getTensorDataOffset(),
5051
gguf.getTensorInfos()
5152
);
@@ -60,7 +61,7 @@ private static PartialModel preLoadGGUF(String modelPath) {
6061
* The file name (base name) must match with the preloaded file name.
6162
* No checksum/hash is checked for performance reasons.
6263
*/
63-
public static com.example.inference.engine.impl.Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
64+
public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
6465
AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
6566
if (preLoaded == null) {
6667
return null; // no pre-loaded model stored
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package com.example.auxiliary.format;
2+
3+
import com.example.tokenizer.impl.Tokenizer;
4+
5+
public interface ChatFormat {
6+
7+
}

src/main/java/com/example/auxiliary/ChatFormat.java renamed to src/main/java/com/example/auxiliary/format/LlamaChatFormat.java

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
1-
package com.example.auxiliary;
1+
package com.example.auxiliary.format;
22

3+
import com.example.tokenizer.impl.LlamaTokenizer;
34
import com.example.tokenizer.impl.Tokenizer;
45

56
import java.util.ArrayList;
67
import java.util.List;
78
import java.util.Map;
89
import java.util.Set;
910

10-
public class ChatFormat {
11+
public class LlamaChatFormat implements ChatFormat {
1112

12-
final Tokenizer tokenizer;
13+
final LlamaTokenizer tokenizer;
1314
public final int beginOfText;
1415
final int endHeader;
1516
final int startHeader;
@@ -18,7 +19,7 @@ public class ChatFormat {
1819
final int endOfMessage;
1920
final Set<Integer> stopTokens;
2021

21-
public ChatFormat(Tokenizer tokenizer) {
22+
public LlamaChatFormat(LlamaTokenizer tokenizer) {
2223
this.tokenizer = tokenizer;
2324
Map<String, Integer> specialTokens = this.tokenizer.getSpecialTokens();
2425
this.beginOfText = specialTokens.get("<|begin_of_text|>");
@@ -38,42 +39,43 @@ public Set<Integer> getStopTokens() {
3839
return stopTokens;
3940
}
4041

41-
public List<Integer> encodeHeader(ChatFormat.Message message) {
42+
public List<Integer> encodeHeader(LlamaChatFormat.Message message) {
4243
List<Integer> tokens = new ArrayList<>();
44+
LlamaTokenizer llamaTokenizer = (LlamaTokenizer) this.tokenizer;
4345
tokens.add(startHeader);
44-
tokens.addAll(this.tokenizer.encodeAsList(message.role().name()));
46+
tokens.addAll(llamaTokenizer.encodeAsList(message.role().name()));
4547
tokens.add(endHeader);
46-
tokens.addAll(this.tokenizer.encodeAsList("\n"));
48+
tokens.addAll(llamaTokenizer.encodeAsList("\n"));
4749
return tokens;
4850
}
4951

50-
public List<Integer> encodeMessage(ChatFormat.Message message) {
52+
public List<Integer> encodeMessage(LlamaChatFormat.Message message) {
5153
List<Integer> tokens = this.encodeHeader(message);
5254
tokens.addAll(this.tokenizer.encodeAsList(message.content().strip()));
5355
tokens.add(endOfTurn);
5456
return tokens;
5557
}
5658

57-
public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<ChatFormat.Message> dialog) {
59+
public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<LlamaChatFormat.Message> dialog) {
5860
List<Integer> tokens = new ArrayList<>();
5961
tokens.add(beginOfText);
60-
for (ChatFormat.Message message : dialog) {
62+
for (LlamaChatFormat.Message message : dialog) {
6163
tokens.addAll(this.encodeMessage(message));
6264
}
6365
if (appendAssistantTurn) {
6466
// Add the start of an assistant message for the model to complete.
65-
tokens.addAll(this.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
67+
tokens.addAll(this.encodeHeader(new LlamaChatFormat.Message(LlamaChatFormat.Role.ASSISTANT, "")));
6668
}
6769
return tokens;
6870
}
6971

70-
public record Message(ChatFormat.Role role, String content) {
72+
public record Message(LlamaChatFormat.Role role, String content) {
7173
}
7274

7375
public record Role(String name) {
74-
public static ChatFormat.Role SYSTEM = new ChatFormat.Role("system");
75-
public static ChatFormat.Role USER = new ChatFormat.Role("user");
76-
public static ChatFormat.Role ASSISTANT = new ChatFormat.Role("assistant");
76+
public static LlamaChatFormat.Role SYSTEM = new LlamaChatFormat.Role("system");
77+
public static LlamaChatFormat.Role USER = new LlamaChatFormat.Role("user");
78+
public static LlamaChatFormat.Role ASSISTANT = new LlamaChatFormat.Role("assistant");
7779

7880
@Override
7981
public String toString() {
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package com.example.auxiliary.format;
2+
3+
import com.example.tokenizer.impl.MistralTokenizer;
4+
5+
import java.util.*;
6+
7+
public class MistralChatFormat implements ChatFormat {
8+
9+
protected final MistralTokenizer tokenizer;
10+
protected final int unknownToken;
11+
protected final int beginOfText;
12+
protected final int endOfText;
13+
protected final int beginOfInstruction;
14+
protected final int endOfInstruction;
15+
protected final int toolCalls;
16+
protected final int beginOfAvailableTools;
17+
protected final int endOfAvailableTools;
18+
protected final int beginOfToolResults;
19+
protected final int endOfToolResults;
20+
protected final int prefix;
21+
protected final int middle;
22+
protected final int suffix;
23+
24+
public MistralChatFormat(MistralTokenizer tokenizer) {
25+
this.tokenizer = tokenizer;
26+
Map<String, Integer> specialTokens = this.tokenizer.getSpecialTokens();
27+
this.unknownToken = specialTokens.get("<unk>");
28+
this.beginOfText = specialTokens.get("<s>");
29+
this.endOfText = specialTokens.get("</s>");
30+
this.beginOfInstruction = specialTokens.get("[INST]");
31+
this.endOfInstruction = specialTokens.get("[/INST]");
32+
this.toolCalls = specialTokens.get("[TOOL_CALLS]");
33+
this.beginOfAvailableTools = specialTokens.get("[AVAILABLE_TOOLS]");
34+
this.endOfAvailableTools = specialTokens.get("[/AVAILABLE_TOOLS]");
35+
this.beginOfToolResults = specialTokens.get("[TOOL_RESULTS]");
36+
this.endOfToolResults = specialTokens.get("[/TOOL_RESULTS]");
37+
// Only Codestral supports FIM tokens.
38+
this.prefix = specialTokens.getOrDefault("[PREFIX]", unknownToken);
39+
this.suffix = specialTokens.getOrDefault("[SUFFIX]", unknownToken);
40+
this.middle = specialTokens.getOrDefault("[MIDDLE]", unknownToken);
41+
}
42+
43+
public int getBeginOfText() { return beginOfText; }
44+
45+
public Set<Integer> getStopTokens() {
46+
return Set.of(endOfText);
47+
}
48+
49+
public List<Integer> encodeMessage(String userMessage, boolean addHeader, boolean addFooter) {
50+
List<Integer> tokens = new ArrayList<>();
51+
if (addHeader) {
52+
tokens.add(this.beginOfInstruction);
53+
}
54+
if (userMessage != null) {
55+
tokens.addAll(this.tokenizer.encodeAsList(userMessage.strip()));
56+
}
57+
if (addFooter) {
58+
tokens.add(endOfInstruction);
59+
}
60+
return tokens;
61+
}
62+
63+
public List<Integer> encodeFillInTheMiddle(String prefix, String suffix) {
64+
List<Integer> tokens = new ArrayList<>();
65+
// dummy - empty string set to comply with encode method signature.
66+
final Set<String> EMPTY_STRING_SET = Collections.emptySet();
67+
tokens.add(this.suffix);
68+
tokens.addAll(tokenizer.encode(suffix, EMPTY_STRING_SET));
69+
tokens.add(this.prefix);
70+
tokens.addAll(tokenizer.encode(prefix, EMPTY_STRING_SET));
71+
return tokens;
72+
}
73+
}

0 commit comments

Comments
 (0)