
Commit 613062c

Generalize instruct mode implementation for Llama and Mistral
1 parent 420a119 commit 613062c

File tree: 6 files changed (+61, -134 lines)


src/main/java/com/example/inference/InferenceCore.java

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 package com.example.inference;
 
-import com.example.aux.Parallel;
+import com.example.auxiliary.Parallel;
 import com.example.core.model.tensor.FloatTensor;
 import com.example.loader.weights.State;
 import com.example.loader.weights.Weights;

src/main/java/com/example/inference/InferenceEngine.java

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 package com.example.inference;
 
-import com.example.aux.LastRunMetrics;
+import com.example.auxiliary.LastRunMetrics;
 import com.example.inference.sampler.Sampler;
 import com.example.loader.weights.State;
 import com.example.model.Configuration;

src/main/java/com/example/loader/weights/ModelLoader.java

Lines changed: 1 addition & 1 deletion
@@ -168,7 +168,7 @@ public static Mistral loadMistralModel(FileChannel fileChannel, GGUF gguf, int c
 
         Weights weights = null;
         if (loadWeights) {
-            Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensorsWithMapping(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
+            Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
             weights = loadWeights(tensorEntries, config);
         }
         return new Mistral(config, tokenizer, weights);
src/main/java/com/example/model/Model.java

Lines changed: 58 additions & 3 deletions
@@ -1,7 +1,7 @@
 package com.example.model;
 
-import com.example.aux.LastRunMetrics;
-import com.example.aux.format.ChatFormat;
+import com.example.auxiliary.LastRunMetrics;
+import com.example.auxiliary.format.ChatFormat;
 import com.example.inference.InferenceEngine;
 import com.example.inference.sampler.Sampler;
 import com.example.Options;
@@ -130,5 +130,60 @@ default void runInteractive(Sampler sampler, Options options) {
             }
         }
     }
-    void runInstructOnce(Sampler sampler, Options options);
+
+    /**
+     * Model-agnostic default implementation of instruct mode.
+     * @param sampler the sampler used to pick each next token
+     * @param options run options (prompt, system prompt, streaming, max tokens, echo)
+     */
+    default void runInstructOnce(Sampler sampler, Options options) {
+        State state = createNewState();
+        ChatFormat chatFormat = ChatFormat.create(tokenizer());
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
+        List<Integer> promptTokens = new ArrayList<>();
+        promptTokens.add(chatFormat.getBeginOfText());
+
+        if (options.systemPrompt() != null) {
+            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
+        }
+        promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
+        promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+        List<Integer> responseTokens;
+
+        IntConsumer tokenConsumer = token -> {
+            if (options.stream()) {
+                if (tokenizer().shouldDisplayToken(token)) {
+                    System.out.print(tokenizer().decode(List.of(token)));
+                }
+            }
+        };
+
+        Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+        if (USE_TORNADOVM) {
+            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+            // GPU path: the token consumer is passed only when streaming
+            responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens,
+                    options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+        } else {
+            responseTokens = InferenceEngine.generateTokens(this, state, 0, promptTokens, stopTokens,
+                    options.maxTokens(), sampler, options.echo(), tokenConsumer);
+        }
+
+        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+            responseTokens.removeLast();
+        }
+        if (!options.stream()) {
+            String responseText = tokenizer().decode(responseTokens);
+            System.out.println(responseText);
+        }
+
+        LastRunMetrics.printMetrics();
+
+        if (tornadoVMPlan != null) {
+            tornadoVMPlan.freeTornadoExecutionPlan();
+        }
+    }
 }
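
The generalized default above hinges on two abstractions: ChatFormat.create(tokenizer()) picks the per-model prompt format, and tokenizer().shouldDisplayToken(token) filters streamed output. Neither implementation is part of this diff. A minimal sketch of what the factory could look like, assuming it dispatches on the concrete tokenizer types behind the per-model formats seen in the deletions below (hypothetical code, not the commit's actual implementation):

    // Hypothetical sketch -- the real ChatFormat.create is not shown in this commit.
    static ChatFormat create(Tokenizer tokenizer) {
        if (tokenizer instanceof LlamaTokenizer llamaTokenizer) {
            return new LlamaChatFormat(llamaTokenizer);       // Llama-style headers and stop tokens
        } else if (tokenizer instanceof MistralTokenizer mistralTokenizer) {
            return new MistralChatFormat(mistralTokenizer);   // Mistral [INST] framing
        }
        throw new IllegalArgumentException("Unsupported tokenizer: " + tokenizer.getClass().getSimpleName());
    }

This sketch also assumes LlamaChatFormat and MistralChatFormat now implement a shared ChatFormat interface exposing getBeginOfText(), encodeMessage(...), encodeHeader(...), and getStopTokens(), since the default method calls all four through that type. With those two hooks in place, the Llama and Mistral classes below can drop their overrides entirely.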
src/main/java/com/example/model/llama/Llama.java

Lines changed: 0 additions & 65 deletions
@@ -1,24 +1,11 @@
 package com.example.model.llama;
 
-import com.example.auxiliary.LastRunMetrics;
-import com.example.auxiliary.format.LlamaChatFormat;
-import com.example.inference.InferenceEngine;
-import com.example.inference.sampler.Sampler;
 import com.example.model.Model;
-import com.example.Options;
 import com.example.loader.weights.ModelLoader;
 import com.example.loader.weights.State;
 import com.example.loader.weights.Weights;
 import com.example.tokenizer.impl.LlamaTokenizer;
 import com.example.tokenizer.impl.Tokenizer;
-import com.example.tornadovm.TornadoVMMasterPlan;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Set;
-import java.util.function.IntConsumer;
-
-import static com.example.LlamaApp.USE_TORNADOVM;
 
 public record Llama(LlamaConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model {
     private static final int BATCH_SIZE = Integer.getInteger("llama.BatchSize", 16);
@@ -45,57 +32,5 @@ public State createNewState(int batchsize) {
         return state;
     }
 
-    @Override
-    public void runInstructOnce(Sampler sampler, Options options) {
-        State state = createNewState();
-        LlamaChatFormat chatFormat = new LlamaChatFormat(getAsLlamaTokenizer());
-        TornadoVMMasterPlan tornadoVMPlan = null;
-
-        List<Integer> promptTokens = new ArrayList<>();
-        promptTokens.add(chatFormat.getBeginOfText());
-
-        if (options.systemPrompt() != null) {
-            promptTokens.addAll(chatFormat.encodeMessage(new LlamaChatFormat.Message(LlamaChatFormat.Role.SYSTEM, options.systemPrompt())));
-        }
-        promptTokens.addAll(chatFormat.encodeMessage(new LlamaChatFormat.Message(LlamaChatFormat.Role.USER, options.prompt())));
-        promptTokens.addAll(chatFormat.encodeHeader(new LlamaChatFormat.Message(LlamaChatFormat.Role.ASSISTANT, "")));
-        List<Integer> responseTokens;
-
-        // Define the token consumer
-        IntConsumer tokenConsumer = token -> {
-            if (options.stream()) {
-                if (!tokenizer.isSpecialToken(token)) {
-                    System.out.print(tokenizer.decode(List.of(token)));
-                }
-            }
-        };
-
-        Set<Integer> stopTokens = chatFormat.getStopTokens();
-        if (USE_TORNADOVM) {
-            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
-            // Call generateTokensGPU without the token consumer parameter
-            responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens,
-                    options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
-        } else {
-            // CPU path still uses the token consumer
-            responseTokens = InferenceEngine.generateTokens(this, state, 0, promptTokens, stopTokens,
-                    options.maxTokens(), sampler, options.echo(), tokenConsumer);
-        }
-
-        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-            responseTokens.removeLast();
-        }
-        if (!options.stream()) {
-            String responseText = tokenizer.decode(responseTokens);
-            System.out.println(responseText);
-        }
-
-        LastRunMetrics.printMetrics();
-
-        if (tornadoVMPlan != null) {
-            tornadoVMPlan.freeTornadoExecutionPlan();
-        }
-    }
-
 }
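
The deleted Llama override filtered streamed tokens inline with !tokenizer.isSpecialToken(token); the shared default instead calls tokenizer().shouldDisplayToken(token). A plausible sketch of that hook on LlamaTokenizer, assuming it simply wraps the same check (the method itself is not shown in this commit):

    // Hypothetical sketch -- not part of this diff.
    @Override
    public boolean shouldDisplayToken(int token) {
        // Same filter the deleted token consumer applied inline
        return !isSpecialToken(token);
    }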

src/main/java/com/example/model/mistral/Mistral.java

Lines changed: 0 additions & 63 deletions
@@ -1,24 +1,11 @@
 package com.example.model.mistral;
 
-import com.example.auxiliary.LastRunMetrics;
-import com.example.auxiliary.format.MistralChatFormat;
-import com.example.inference.InferenceEngine;
-import com.example.inference.sampler.Sampler;
 import com.example.model.Model;
-import com.example.Options;
 import com.example.loader.weights.ModelLoader;
 import com.example.loader.weights.State;
 import com.example.loader.weights.Weights;
 import com.example.tokenizer.impl.MistralTokenizer;
 import com.example.tokenizer.impl.Tokenizer;
-import com.example.tornadovm.TornadoVMMasterPlan;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Set;
-import java.util.function.IntConsumer;
-
-import static com.example.LlamaApp.USE_TORNADOVM;
 
 /**
  * Llama class in mistral.java
@@ -45,54 +32,4 @@ public State createNewState(int batchsize) {
         return state;
     }
 
-    @Override
-    public void runInstructOnce(Sampler sampler, Options options) {
-        State state = createNewState();
-        MistralChatFormat chatFormat = new MistralChatFormat(getAsMistralTokenizer());
-        TornadoVMMasterPlan tornadoVMPlan = null;
-
-        List<Integer> promptTokens = new ArrayList<>();
-        promptTokens.add(chatFormat.getBeginOfText());
-
-        if (options.suffix() != null) {
-            promptTokens.addAll(chatFormat.encodeFillInTheMiddle(options.prompt(), options.suffix()));
-        } else {
-            promptTokens.addAll(chatFormat.encodeMessage(options.prompt(), true, true));
-        }
-
-        List<Integer> responseTokens;
-        Set<Integer> stopTokens = chatFormat.getStopTokens();
-        IntConsumer tokenConsumer = token -> {
-            if (options.stream()) {
-                int tokenType = getAsMistralTokenizer().getTokenType(token);
-                if (tokenType == 1 || tokenType == 6) {
-                    System.out.print(tokenizer.decode(List.of(token)));
-                }
-            }
-        };
-
-        if (USE_TORNADOVM) {
-            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
-            // Call generateTokensGPU without the token consumer parameter
-            responseTokens = InferenceEngine.generateTokensGPU(this, state, 0, promptTokens, stopTokens,
-                    options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
-        } else {
-            responseTokens = InferenceEngine.generateTokens(this, state, 0, promptTokens, stopTokens,
-                    options.maxTokens(), sampler, options.echo(), tokenConsumer);
-        }
-
-        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-            responseTokens.removeLast();
-        }
-        if (!options.stream()) {
-            String responseText = tokenizer.decode(responseTokens);
-            System.out.println(responseText);
-        }
-
-        LastRunMetrics.printMetrics();
-
-        if (tornadoVMPlan != null) {
-            tornadoVMPlan.freeTornadoExecutionPlan();
-        }
-    }
 }
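
Likewise, the deleted Mistral override printed only tokens whose type was 1 or 6; under the generalized default, that check would move behind MistralTokenizer.shouldDisplayToken. A plausible sketch, assuming it reuses the getTokenType call from the deleted consumer (the method itself is not shown in this commit):

    // Hypothetical sketch -- not part of this diff.
    @Override
    public boolean shouldDisplayToken(int token) {
        int tokenType = getTokenType(token);
        return tokenType == 1 || tokenType == 6;   // the types the deleted consumer displayed
    }

Note that the deleted override also routed options.suffix() through encodeFillInTheMiddle(...), a branch the generalized default does not carry.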
