Commit 420a119

Generalize interactive mode implementation for Llama and Mistral
1 parent b6b693f commit 420a119

File tree

9 files changed, +216 -228 lines changed
src/main/java/com/example/auxiliary/format/ChatFormat.java

Lines changed: 53 additions & 1 deletion

@@ -1,9 +1,61 @@
 package com.example.auxiliary.format;

-import com.example.tokenizer.impl.Tokenizer;
+import com.example.tokenizer.impl.LlamaTokenizer;
+import com.example.tokenizer.impl.MistralTokenizer;
+
+import java.util.List;
+import java.util.Set;

 public interface ChatFormat {

+    static ChatFormat create(Object tokenizer) {
+        if (tokenizer instanceof LlamaTokenizer llamaTokenizer) {
+            return new LlamaChatFormat(llamaTokenizer);
+        } else if (tokenizer instanceof MistralTokenizer mistralTokenizer) {
+            return new MistralChatFormat(mistralTokenizer);
+        } else {
+            throw new IllegalArgumentException("Unsupported tokenizer type: " + tokenizer.getClass().getName());
+        }
+    }
+
+    List<Integer> encodeHeader(Message message);
+    List<Integer> encodeMessage(Message message);
     int getBeginOfText();
+    Set<Integer> getStopTokens();
+
+    /**
+     * Represents a single message in an LLM chat session.
+     *
+     * Each message is associated with a specific role (system, user, or assistant)
+     * and contains the textual content of that message.
+     *
+     * @param role    the participant who issued the message (SYSTEM, USER, or ASSISTANT)
+     * @param content the textual content of the message
+     */
+    record Message(Role role, String content) {
+    }
+
+    /**
+     * Represents the role of a participant in an LLM chat conversation.
+     *
+     * There are three standard roles:
+     * <ul>
+     *     <li><strong>SYSTEM</strong> - sets the behavior and context of the assistant at the start of the conversation.</li>
+     *     <li><strong>USER</strong> - represents input from the human user.</li>
+     *     <li><strong>ASSISTANT</strong> - represents output from the AI assistant.</li>
+     * </ul>
+     *
+     * @param name the string representation of the role
+     */
+    record Role(String name) {
+        public static Role SYSTEM = new Role("system");
+        public static Role USER = new Role("user");
+        public static Role ASSISTANT = new Role("assistant");
+
+        @Override
+        public String toString() {
+            return name;
+        }
+    }

 }

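For orientation, a minimal sketch of how the new factory is meant to be used at a call site. The `tokenizer` variable here is an assumption for illustration; in this commit the actual value comes from the loaded model (Model.java below calls `ChatFormat.create(tokenizer())`):

    // A LlamaTokenizer or MistralTokenizer obtained from the loaded model.
    ChatFormat chatFormat = ChatFormat.create(tokenizer);
    List<Integer> prompt = new ArrayList<>();
    prompt.add(chatFormat.getBeginOfText());
    prompt.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, "Hello!")));
    // Any other tokenizer type makes create(...) throw IllegalArgumentException.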
src/main/java/com/example/auxiliary/format/LlamaChatFormat.java

Lines changed: 16 additions & 36 deletions

@@ -21,7 +21,7 @@ public class LlamaChatFormat implements ChatFormat {

     public LlamaChatFormat(LlamaTokenizer tokenizer) {
         this.tokenizer = tokenizer;
-        Map<String, Integer> specialTokens = this.tokenizer.getSpecialTokens();
+        Map<String, Integer> specialTokens = tokenizer.getSpecialTokens();
         this.beginOfText = specialTokens.get("<|begin_of_text|>");
         this.startHeader = specialTokens.get("<|start_header_id|>");
         this.endHeader = specialTokens.get("<|end_header_id|>");
@@ -31,60 +31,40 @@ public LlamaChatFormat(LlamaTokenizer tokenizer) {
         this.stopTokens = Set.of(endOfText, endOfTurn);
     }

-    public Tokenizer getTokenizer() {
-        return tokenizer;
-    }
+    @Override
+    public int getBeginOfText() { return beginOfText; }

-    public Set<Integer> getStopTokens() {
-        return stopTokens;
-    }
+    @Override
+    public Set<Integer> getStopTokens() { return stopTokens; }

-    public List<Integer> encodeHeader(LlamaChatFormat.Message message) {
+    @Override
+    public List<Integer> encodeHeader(Message message) {
         List<Integer> tokens = new ArrayList<>();
-        LlamaTokenizer llamaTokenizer = (LlamaTokenizer) this.tokenizer;
         tokens.add(startHeader);
-        tokens.addAll(llamaTokenizer.encodeAsList(message.role().name()));
+        tokens.addAll(tokenizer.encodeAsList(message.role().name()));
         tokens.add(endHeader);
-        tokens.addAll(llamaTokenizer.encodeAsList("\n"));
+        tokens.addAll(tokenizer.encodeAsList("\n"));
         return tokens;
     }

-    public List<Integer> encodeMessage(LlamaChatFormat.Message message) {
-        List<Integer> tokens = this.encodeHeader(message);
-        tokens.addAll(this.tokenizer.encodeAsList(message.content().strip()));
+    @Override
+    public List<Integer> encodeMessage(Message message) {
+        List<Integer> tokens = encodeHeader(message);
+        tokens.addAll(tokenizer.encodeAsList(message.content().strip()));
         tokens.add(endOfTurn);
         return tokens;
     }

-    public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<LlamaChatFormat.Message> dialog) {
+    public List<Integer> encodeDialogPrompt(boolean appendAssistantTurn, List<Message> dialog) {
         List<Integer> tokens = new ArrayList<>();
         tokens.add(beginOfText);
         for (LlamaChatFormat.Message message : dialog) {
-            tokens.addAll(this.encodeMessage(message));
+            tokens.addAll(encodeMessage(message));
         }
         if (appendAssistantTurn) {
             // Add the start of an assistant message for the model to complete.
-            tokens.addAll(this.encodeHeader(new LlamaChatFormat.Message(LlamaChatFormat.Role.ASSISTANT, "")));
+            tokens.addAll(encodeHeader(new Message(ChatFormat.Role.ASSISTANT, "")));
         }
         return tokens;
     }
-
-    @Override
-    public int getBeginOfText() {
-        return beginOfText;
-    }
-
-    public record Message(LlamaChatFormat.Role role, String content) {
-    }
-
-    public record Role(String name) {
-        public static LlamaChatFormat.Role SYSTEM = new LlamaChatFormat.Role("system");
-        public static LlamaChatFormat.Role USER = new LlamaChatFormat.Role("user");
-        public static LlamaChatFormat.Role ASSISTANT = new LlamaChatFormat.Role("assistant");
-
-        @Override
-        public String toString() {
-            return name;
-        }
-    }
 }

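For reference, a hedged sketch of the prompt these methods assemble, assuming the special tokens decode to their usual Llama 3 spellings and that `<|eot_id|>` is the `endOfTurn` token fetched in the constructor (not visible in this hunk):

    List<Integer> prompt = llamaChatFormat.encodeDialogPrompt(true,
            List.of(new ChatFormat.Message(ChatFormat.Role.USER, "Hi")));
    // Decoded, the tokens should read approximately:
    //   <|begin_of_text|><|start_header_id|>user<|end_header_id|>
    //   Hi<|eot_id|><|start_header_id|>assistant<|end_header_id|>
    // (each header is followed by an encoded "\n", so the assistant's
    //  completion starts on a fresh line)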
src/main/java/com/example/auxiliary/format/MistralChatFormat.java

Lines changed: 20 additions & 3 deletions

@@ -23,7 +23,7 @@ public class MistralChatFormat implements ChatFormat {

     public MistralChatFormat(MistralTokenizer tokenizer) {
         this.tokenizer = tokenizer;
-        Map<String, Integer> specialTokens = this.tokenizer.getSpecialTokens();
+        Map<String, Integer> specialTokens = tokenizer.getSpecialTokens();
         this.unknownToken = specialTokens.get("<unk>");
         this.beginOfText = specialTokens.get("<s>");
         this.endOfText = specialTokens.get("</s>");
@@ -43,8 +43,25 @@ public MistralChatFormat(MistralTokenizer tokenizer) {
     @Override
     public int getBeginOfText() { return beginOfText; }

-    public Set<Integer> getStopTokens() {
-        return Set.of(endOfText);
+    @Override
+    public Set<Integer> getStopTokens() { return Set.of(endOfText); }
+
+    @Override
+    public List<Integer> encodeHeader(Message message) {
+        List<Integer> tokens = new ArrayList<>();
+        tokens.add(beginOfInstruction);
+        tokens.addAll(tokenizer.encodeAsList(message.role().name()));
+        tokens.add(endOfInstruction);
+        return tokens;
+    }
+
+    @Override
+    public List<Integer> encodeMessage(Message message) {
+        List<Integer> tokens = encodeHeader(message);
+        //tokens.add(beginOfInstruction);
+        tokens.addAll(tokenizer.encodeAsList(message.content().strip()));
+        tokens.add(endOfInstruction);
+        return tokens;
     }

     public List<Integer> encodeMessage(String userMessage, boolean addHeader, boolean addFooter) {

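Assuming `beginOfInstruction` and `endOfInstruction` map to Mistral's `[INST]` and `[/INST]` special tokens (neither is shown in these hunks), the new overrides produce roughly the following. Note that this wraps the role name in its own instruction pair, whereas the canonical Mistral instruct template wraps only the message content as `[INST] content [/INST]`; the commented-out line in the diff suggests this layout was still being tuned:

    List<Integer> tokens = mistralChatFormat.encodeMessage(
            new ChatFormat.Message(ChatFormat.Role.USER, "Hi"));
    // encodeHeader contributes:           [INST]user[/INST]
    // encodeMessage then appends content
    // and a closing token, giving roughly: [INST]user[/INST]Hi[/INST]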
src/main/java/com/example/model/Model.java

Lines changed: 108 additions & 1 deletion

@@ -1,5 +1,8 @@
 package com.example.model;

+import com.example.auxiliary.LastRunMetrics;
+import com.example.auxiliary.format.ChatFormat;
+import com.example.inference.InferenceEngine;
 import com.example.inference.sampler.Sampler;
 import com.example.Options;
 import com.example.loader.weights.ModelLoader.ModelType;
@@ -8,10 +11,15 @@
 import com.example.tokenizer.impl.Tokenizer;
 import com.example.tornadovm.TornadoVMMasterPlan;

+import java.util.ArrayList;
 import java.util.List;
+import java.util.Scanner;
 import java.util.Set;
 import java.util.function.IntConsumer;

+import static com.example.LlamaApp.SHOW_PERF_INTERACTIVE;
+import static com.example.LlamaApp.USE_TORNADOVM;
+
 public interface Model {
     Configuration configuration();
     Tokenizer tokenizer();
@@ -22,6 +30,105 @@ public interface Model {
     State createNewState();
     State createNewState(int batchsize);

-    void runInteractive(Sampler sampler, Options options);
+    /**
+     * Model-agnostic default implementation for interactive mode.
+     * @param sampler the sampler used to select the next token
+     * @param options run options (system prompt, streaming, max tokens, echo)
+     */
+    default void runInteractive(Sampler sampler, Options options) {
+        State state = null;
+        List<Integer> conversationTokens = new ArrayList<>();
+
+        ChatFormat chatFormat = ChatFormat.create(tokenizer());
+        conversationTokens.add(chatFormat.getBeginOfText());
+
+        if (options.systemPrompt() != null) {
+            conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
+        }
+
+        int startPosition = 0;
+        Scanner in = new Scanner(System.in);
+
+        // Initialize TornadoVM plan once at the beginning if GPU path is enabled
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
+        try {
+            while (true) {
+                System.out.print("> ");
+                System.out.flush();
+                String userText = in.nextLine();
+                if (List.of("quit", "exit").contains(userText)) {
+                    break;
+                }
+                if (state == null) {
+                    // State allocation can take some time for large context sizes;
+                    // allocate the model state only after printing the user '>' prompt.
+                    state = createNewState();
+                }
+
+                if (USE_TORNADOVM && tornadoVMPlan == null) {
+                    tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+                }
+
+                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
+                conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+                Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+                List<Integer> responseTokens;
+                IntConsumer tokenConsumer = token -> {
+                    if (options.stream()) {
+                        if (tokenizer().shouldDisplayToken(token)) {
+                            System.out.print(tokenizer().decode(List.of(token)));
+                        }
+                    }
+                };
+
+                // Choose between GPU and CPU path based on configuration
+                if (USE_TORNADOVM) {
+                    // GPU path using TornadoVM
+                    responseTokens = InferenceEngine.generateTokensGPU(this, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+                            options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+                } else {
+                    // CPU path
+                    responseTokens = InferenceEngine.generateTokens(this, state, startPosition,
+                            conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens,
+                            options.maxTokens(), sampler, options.echo(), tokenConsumer);
+                }
+
+                // Include the stop token in the prompt history, but not in the response displayed to the user.
+                conversationTokens.addAll(responseTokens);
+                startPosition = conversationTokens.size();
+                Integer stopToken = null;
+                if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+                    stopToken = responseTokens.getLast();
+                    responseTokens.removeLast();
+                }
+                if (!options.stream()) {
+                    String responseText = tokenizer().decode(responseTokens);
+                    System.out.println(responseText);
+                }
+                if (stopToken == null) {
+                    System.err.println("\n Ran out of context length...\n Increase context length by passing --max-tokens XXX to llama-tornado");
+                    break;
+                }
+                System.out.print("\n");
+
+                // Optionally print performance metrics after each response
+                if (SHOW_PERF_INTERACTIVE) {
+                    LastRunMetrics.printMetrics();
+                }
+            }
+        } finally {
+            // Clean up TornadoVM resources when exiting the chat loop
+            if (USE_TORNADOVM && tornadoVMPlan != null) {
+                try {
+                    tornadoVMPlan.freeTornadoExecutionPlan();
+                } catch (Exception e) {
+                    System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
+                }
+            }
+        }
+    }
     void runInstructOnce(Sampler sampler, Options options);
 }

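The loop above keeps the whole conversation in `conversationTokens` but hands the engine only the slice after `startPosition`, so earlier turns are not re-processed on every iteration. A self-contained sketch of that bookkeeping, with made-up token ids for illustration:

    import java.util.ArrayList;
    import java.util.List;

    public class StartPositionDemo {
        public static void main(String[] args) {
            List<Integer> conversation = new ArrayList<>(List.of(1, 2, 3)); // tokens already processed
            int startPosition = conversation.size();
            conversation.addAll(List.of(4, 5));                             // newly encoded user turn
            // Only the unprocessed suffix would be handed to the inference engine:
            System.out.println(conversation.subList(startPosition, conversation.size())); // [4, 5]
        }
    }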