@@ -20,14 +20,13 @@
 import java.util.Set;
 import java.util.function.IntConsumer;
 
+import static com.example.LlamaApp.USE_TORNADOVM;
+
 /**
  * Llama class in mistral.java
  */
 public record Mistral(MistralConfiguration configuration, Tokenizer tokenizer, Weights weights) implements Model {
 
-    /* For explicit use */
-    private MistralTokenizer getAsMistralTokenizer() { return (MistralTokenizer) tokenizer; }
-
     static void rmsnorm(FloatTensor out, FloatTensor x, FloatBuffer weight, int size, float rmsNormEps) {
         // calculate sum of squares
         float ss = x.reduce(0, size, 0f, (acc, xi) -> acc + xi * xi);
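The rmsnorm hunk above shows only the first step of the computation (the sum of squares). The remainder is the standard RMS-norm: each element is scaled by 1 / sqrt(mean(x^2) + rmsNormEps) and by the learned weight. A minimal self-contained sketch, assuming the standard formula and substituting plain float[] for the repository's FloatTensor/FloatBuffer types:

    // RMS-norm sketch over plain arrays; mirrors the rmsnorm signature
    // above but is not the repository's FloatTensor implementation.
    static void rmsnormSketch(float[] out, float[] x, float[] weight, int size, float rmsNormEps) {
        // sum of squares, as in the reduce(...) call shown in the hunk
        float ss = 0f;
        for (int i = 0; i < size; i++) {
            ss += x[i] * x[i];
        }
        // normalize by 1 / sqrt(mean(x^2) + eps), then scale by the learned weight
        float inv = (float) (1.0 / Math.sqrt(ss / size + rmsNormEps));
        for (int i = 0; i < size; i++) {
            out[i] = weight[i] * (inv * x[i]);
        }
    }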
@@ -163,15 +162,20 @@ static FloatTensor forward(Mistral model, State state, int token, int position)
         return state.logits;
     }
 
+    /* For explicit use */
+    private MistralTokenizer getAsMistralTokenizer() {
+        return (MistralTokenizer) tokenizer;
+    }
+
     @Override
-    public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens,
-            int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
+    public List<Integer> generateTokensGPU(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
+            IntConsumer onTokenGenerated, TornadoVMMasterPlan tornadoVMPlan) {
         throw new UnsupportedOperationException("Mistral.generateTokensGPU is not implemented yet");
     }
 
     @Override
-    public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens,
-            int maxTokens, Sampler sampler, boolean echo, IntConsumer onTokenGenerated) {
+    public List<Integer> generateTokens(State state, int startPosition, List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler, boolean echo,
+            IntConsumer onTokenGenerated) {
         long startNanos = System.nanoTime();
         if (maxTokens < 0 || configuration.contextLength() < maxTokens) {
             maxTokens = configuration.contextLength();
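Since Mistral.generateTokensGPU still throws UnsupportedOperationException, any caller that tries the GPU path needs a CPU fallback. A hypothetical wrapper, not part of this change, built only from the two signatures shown above:

    // Hypothetical helper: try the GPU path first, fall back to the CPU
    // implementation while generateTokensGPU remains unsupported.
    static List<Integer> generateWithFallback(Model model, State state, int startPosition,
            List<Integer> promptTokens, Set<Integer> stopTokens, int maxTokens, Sampler sampler,
            boolean echo, IntConsumer onTokenGenerated, TornadoVMMasterPlan plan) {
        try {
            return model.generateTokensGPU(state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated, plan);
        } catch (UnsupportedOperationException e) {
            return model.generateTokens(state, startPosition, promptTokens, stopTokens, maxTokens, sampler, echo, onTokenGenerated);
        }
    }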
@@ -248,14 +252,15 @@ public void runInteractive(Sampler sampler, Options options) {
             }
             conversationTokens.addAll(chatFormat.encodeMessage(userText, true, true));
             Set<Integer> stopTokens = chatFormat.getStopTokens();
-            List<Integer> responseTokens = generateTokens(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
-                if (options.stream()) {
-                    int tokenType = mistralTokenizer.getTokenType(token);
-                    if (tokenType == 1 || tokenType == 6) {
-                        System.out.print(mistralTokenizer.decode(List.of(token)));
-                    }
-                }
-            });
+            List<Integer> responseTokens = generateTokens(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
+                    options.echo(), token -> {
+                        if (options.stream()) {
+                            int tokenType = mistralTokenizer.getTokenType(token);
+                            if (tokenType == 1 || tokenType == 6) {
+                                System.out.print(mistralTokenizer.decode(List.of(token)));
+                            }
+                        }
+                    });
             // Include stop token in the prompt history, but not in the response displayed to the user.
             conversationTokens.addAll(responseTokens);
             startPosition = conversationTokens.size();
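The reflowed callback above is the same pattern the last hunk extracts into a named IntConsumer: decode and print a token only when its type id passes a filter. The ids 1 and 6 are taken from this diff; what they denote is defined by MistralTokenizer, not here. A self-contained toy version with a stand-in tokenizer:

    import java.util.List;
    import java.util.function.IntConsumer;

    public class StreamingCallbackSketch {
        // Stand-in for MistralTokenizer: just enough surface for the sketch.
        interface TokenizerLike {
            int getTokenType(int token);
            String decode(List<Integer> tokens);
        }

        public static void main(String[] args) {
            TokenizerLike tok = new TokenizerLike() {
                @Override public int getTokenType(int token) { return token % 7; } // dummy type ids
                @Override public String decode(List<Integer> tokens) { return "<" + tokens.getFirst() + ">"; }
            };
            // Same shape as the callback in the diff: only type ids 1 and 6 are echoed.
            IntConsumer printer = token -> {
                int tokenType = tok.getTokenType(token);
                if (tokenType == 1 || tokenType == 6) {
                    System.out.print(tok.decode(List.of(token)));
                }
            };
            for (int t = 0; t < 20; t++) {
                printer.accept(t);
            }
            System.out.println(); // prints <1><6><8><13><15>
        }
    }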
@@ -288,15 +293,26 @@ public void runInstructOnce(Sampler sampler, Options options) {
             promptTokens.addAll(chatFormat.encodeMessage(options.prompt(), true, true));
         }
 
+        List<Integer> responseTokens;
         Set<Integer> stopTokens = chatFormat.getStopTokens();
-        List<Integer> responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
+        IntConsumer tokenConsumer = token -> {
             if (options.stream()) {
                 int tokenType = mistralTokenizer.getTokenType(token);
                 if (tokenType == 1 || tokenType == 6) {
                     System.out.print(mistralTokenizer.decode(List.of(token)));
                 }
             }
-        });
+        };
+
+        TornadoVMMasterPlan tornadoVMPlan = null;
+        if (USE_TORNADOVM) {
+            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+            // Pass the token consumer to generateTokensGPU only when streaming is enabled
+            responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+        } else {
+            responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
+        }
+
         if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
             responseTokens.removeLast();
         }
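Extracting tokenConsumer once lets the CPU and GPU branches share identical streaming behaviour. Note, though, that the GPU branch passes options.stream() ? tokenConsumer : null, so a future generateTokensGPU implementation must treat the consumer as optional. A hypothetical null-safe emission step such an implementation might use:

    // Hypothetical helper for a future generateTokensGPU implementation:
    // the consumer is null when streaming is disabled, so guard the call.
    private static void emit(IntConsumer onTokenGenerated, int token) {
        if (onTokenGenerated != null) {
            onTokenGenerated.accept(token);
        }
    }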