
Commit 65e4888

Merge pull request #17 from orionpapadakis/mistral
[model] Add support for Mistral models
2 parents 0d17405 + 2371a7c commit 65e4888

32 files changed · +1934 −1239 lines changed

README.md

Lines changed: 14 additions & 7 deletions
@@ -164,11 +164,14 @@ Check models below.
 
 ## Download Model Files
 
-Download `FP16` quantized .gguf files from:
+Download `FP16` quantized `Llama-3` .gguf files from:
 - https://huggingface.co/beehive-lab/Llama-3.2-1B-Instruct-GGUF-FP16
 - https://huggingface.co/beehive-lab/Llama-3.2-3B-Instruct-GGUF-FP16
 - https://huggingface.co/beehive-lab/Llama-3.2-8B-Instruct-GGUF-FP16
 
+Download `FP16` quantized `Mistral` .gguf files from:
+- https://huggingface.co/collections/beehive-lab/mistral-gpullama3java-684afabb206136d2e9cd47e0
+
 Please be gentle with [huggingface.co](https://huggingface.co) servers:
 
 **Note** FP16 models are first-class citizens for the current version.
@@ -181,6 +184,9 @@ wget https://huggingface.co/beehive-lab/Llama-3.2-3B-Instruct-GGUF-FP16/resolve/
 
 # Llama 3 (8B) - FP16
 wget https://huggingface.co/beehive-lab/Llama-3.2-8B-Instruct-GGUF-FP16/resolve/main/beehive-llama-3.2-8b-instruct-fp16.gguf
+
+# Mistral (7B) - FP16
+wget https://huggingface.co/MaziyarPanahi/Mistral-7B-Instruct-v0.3-GGUF/resolve/main/Mistral-7B-Instruct-v0.3.fp16.gguf
 ```
 
 **[Experimental]** you can download the Q8 and Q4 used in the original implementation of Llama3.java, but for now are going to be dequanted to FP16 for TornadoVM support:
@@ -201,7 +207,7 @@ curl -L -O https://huggingface.co/mukel/Meta-Llama-3.1-8B-Instruct-GGUF/resolve/
 
 ## Running `llama-tornado`
 
-To execute Llama3 models with TornadoVM on GPUs use the `llama-tornado` script with the `--gpu` flag.
+To execute Llama3 or Mistral models with TornadoVM on GPUs, use the `llama-tornado` script with the `--gpu` flag.
 
 ### Usage Examples
 
@@ -246,11 +252,11 @@ First, check your GPU specifications. If your GPU has high memory capacity, you
 
 ### GPU Memory Requirements by Model Size
 
-| Model Size | Recommended GPU Memory |
-|------------|------------------------|
-| 1B models  | 7GB (default)          |
-| 3B models  | 15GB                   |
-| 8B models  | 20GB+                  |
+| Model Size  | Recommended GPU Memory |
+|-------------|------------------------|
+| 1B models   | 7GB (default)          |
+| 3-7B models | 15GB                   |
+| 8B models   | 20GB+                  |
 
 **Note**: If you still encounter memory issues, try:
 
@@ -288,6 +294,7 @@ LLaMA Configuration:
                         Maximum number of tokens to generate (default: 512)
   --stream STREAM       Enable streaming output (default: True)
   --echo ECHO           Echo the input prompt (default: False)
+  --suffix SUFFIX       Suffix for fill-in-the-middle request (Codestral) (default: None)
 
 Mode Selection:
   -i, --interactive     Run in interactive/chat mode (default: False)

llama-tornado

Lines changed: 16 additions & 15 deletions
@@ -1,7 +1,7 @@
 #!/usr/bin/env python3
 """
-llama-tornado: GPU-accelerated LLaMA.java runner with TornadoVM
-Run LLaMA models using either OpenCL or PTX backends.
+llama-tornado: GPU-accelerated Java LLM runner with TornadoVM
+Run LLM models using either OpenCL or PTX backends.
 """
 
 import argparse
@@ -19,7 +19,7 @@ class Backend(Enum):
     PTX = "ptx"
 
 class LlamaRunner:
-    """Main class for managing LLaMA model execution with GPU acceleration."""
+    """Main class for managing LLM execution with GPU acceleration."""
 
     def __init__(self):
         self.java_home = os.environ.get('JAVA_HOME')
@@ -266,30 +266,31 @@ def create_parser() -> argparse.ArgumentParser:
     """Create and configure the argument parser."""
     parser = argparse.ArgumentParser(
         prog="llama-tornado",
-        description="GPU-accelerated LLaMA.java model runner using TornadoVM",
+        description="GPU-accelerated LLM runner using TornadoVM",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter
     )
 
     # Required arguments
     parser.add_argument("--model", dest="model_path", required=True,
-                        help="Path to the LLaMA model file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)")
+                        help="Path to the LLM gguf file (e.g., Llama-3.2-1B-Instruct-Q8_0.gguf)")
 
-    # LLaMA arguments
-    llama_group = parser.add_argument_group("LLaMA Configuration")
-    llama_group.add_argument("--prompt", help="Input prompt for the model")
-    llama_group.add_argument("-sp", "--system-prompt", help="System prompt for the model")
-    llama_group.add_argument("--temperature", type=float, default=0.1,
+    # LLM arguments
+    llm_group = parser.add_argument_group("LLaMA Configuration")
+    llm_group.add_argument("--prompt", help="Input prompt for the model")
+    llm_group.add_argument("-sp", "--system-prompt", help="System prompt for the model")
+    llm_group.add_argument("--temperature", type=float, default=0.1,
                         help="Sampling temperature (0.0 to 2.0)")
-    llama_group.add_argument("--top-p", type=float, default=0.95,
+    llm_group.add_argument("--top-p", type=float, default=0.95,
                         help="Top-p sampling parameter")
-    llama_group.add_argument("--seed", type=int, default=None,
+    llm_group.add_argument("--seed", type=int, default=None,
                         help="Random seed (default: current timestamp)")
-    llama_group.add_argument("-n", "--max-tokens", type=int, default=512,
+    llm_group.add_argument("-n", "--max-tokens", type=int, default=512,
                         help="Maximum number of tokens to generate")
-    llama_group.add_argument("--stream", type=bool, default=True,
+    llm_group.add_argument("--stream", type=bool, default=True,
                         help="Enable streaming output")
-    llama_group.add_argument("--echo", type=bool, default=False,
+    llm_group.add_argument("--echo", type=bool, default=False,
                         help="Echo the input prompt")
+    llm_group.add_argument("--suffix", help="Suffix for fill-in-the-middle request (Codestral)")
 
     # Mode selection
     mode_group = parser.add_argument_group("Mode Selection")

src/main/java/com/example/LlamaApp.java

Lines changed: 9 additions & 154 deletions
@@ -1,25 +1,16 @@
 package com.example;
 
 import com.example.aot.AOT;
-import com.example.auxiliary.ChatFormat;
 import com.example.core.model.tensor.FloatTensor;
-import com.example.inference.CategoricalSampler;
-import com.example.inference.Sampler;
-import com.example.inference.ToppSampler;
-import com.example.inference.engine.impl.Llama;
-import com.example.inference.engine.impl.Options;
+import com.example.inference.sampler.CategoricalSampler;
+import com.example.inference.sampler.Sampler;
+import com.example.inference.sampler.ToppSampler;
+import com.example.model.Model;
 import com.example.loader.weights.ModelLoader;
-import com.example.loader.weights.State;
 import com.example.tornadovm.FloatArrayUtils;
-import com.example.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
 import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Scanner;
-import java.util.Set;
-import java.util.function.IntConsumer;
 import java.util.random.RandomGenerator;
 import java.util.random.RandomGeneratorFactory;
 
@@ -115,156 +106,20 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
         return sampler;
     }
 
-    static void runInteractive(Llama model, Sampler sampler, Options options) {
-        State state = null;
-        List<Integer> conversationTokens = new ArrayList<>();
-        ChatFormat chatFormat = new ChatFormat(model.tokenizer());
-        conversationTokens.add(chatFormat.beginOfText);
-        if (options.systemPrompt() != null) {
-            conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
-        }
-        int startPosition = 0;
-        Scanner in = new Scanner(System.in);
-
-        // Initialize TornadoVM plan once at the beginning if GPU path is enabled
-        TornadoVMMasterPlan tornadoVMPlan = null;
-
-        try {
-            while (true) {
-                System.out.print("> ");
-                System.out.flush();
-                String userText = in.nextLine();
-                if (List.of("quit", "exit").contains(userText)) {
-                    break;
-                }
-                if (state == null) {
-                    state = model.createNewState();
-                }
-
-                if (USE_TORNADOVM && tornadoVMPlan == null) {
-                    tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
-                }
-
-                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
-                conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
-                Set<Integer> stopTokens = chatFormat.getStopTokens();
-
-                List<Integer> responseTokens;
-                IntConsumer tokenConsumer = token -> {
-                    if (options.stream()) {
-                        if (!model.tokenizer().isSpecialToken(token)) {
-                            System.out.print(model.tokenizer().decode(List.of(token)));
-                        }
-                    }
-                };
-
-                // Choose between GPU and CPU path based on configuration
-                if (USE_TORNADOVM) {
-                    // GPU path using TornadoVM
-                    responseTokens = Llama.generateTokensGPU(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(),
-                            sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
-                } else {
-                    // CPU path
-                    responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
-                            options.echo(), tokenConsumer);
-                }
-
-                // Include stop token in the prompt history, but not in the response displayed to the user.
-                conversationTokens.addAll(responseTokens);
-                startPosition = conversationTokens.size();
-                Integer stopToken = null;
-                if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-                    stopToken = responseTokens.getLast();
-                    responseTokens.removeLast();
-                }
-                if (!options.stream()) {
-                    String responseText = model.tokenizer().decode(responseTokens);
-                    System.out.println(responseText);
-                }
-                if (stopToken == null) {
-                    System.err.println("\n Ran out of context length...\n Increase context length with by passing to llama-tornado --max-tokens XXX");
-                    break;
-                }
-                System.out.print("\n");
-
-                // Optionally print performance metrics after each response
-                if (SHOW_PERF_INTERACTIVE) {
-                    Llama.LastRunMetrics.printMetrics();
-                }
-            }
-        } finally {
-            // Clean up TornadoVM resources when exiting the chat loop
-            if (USE_TORNADOVM && tornadoVMPlan != null) {
-                try {
-                    tornadoVMPlan.freeTornadoExecutionPlan();
-                } catch (Exception e) {
-                    System.err.println("Error while cleaning up TornadoVM resources: " + e.getMessage());
-                }
-            }
-        }
-    }
-
-    static void runInstructOnce(Llama model, Sampler sampler, Options options) {
-        State state = model.createNewState();
-        ChatFormat chatFormat = new ChatFormat(model.tokenizer());
-        TornadoVMMasterPlan tornadoVMPlan = null;
-
-        List<Integer> promptTokens = new ArrayList<>();
-        promptTokens.add(chatFormat.beginOfText);
-        if (options.systemPrompt() != null) {
-            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
-        }
-        promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
-        promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
-        List<Integer> responseTokens;
-
-        // Define the token consumer
-        IntConsumer tokenConsumer = token -> {
-            if (options.stream()) {
-                if (!model.tokenizer().isSpecialToken(token)) {
-                    System.out.print(model.tokenizer().decode(List.of(token)));
-                }
-            }
-        };
-
-        Set<Integer> stopTokens = chatFormat.getStopTokens();
-        if (USE_TORNADOVM) {
-            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, model);
-            // Call generateTokensGPU without the token consumer parameter
-            responseTokens = Llama.generateTokensGPU(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
-        } else {
-            // CPU path still uses the token consumer
-            responseTokens = Llama.generateTokens(model, state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
-        }
-
-        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-            responseTokens.removeLast();
-        }
-        if (!options.stream()) {
-            String responseText = model.tokenizer().decode(responseTokens);
-            System.out.println(responseText);
-        }
-
-        Llama.LastRunMetrics.printMetrics();
-
-        if (tornadoVMPlan != null) {
-            tornadoVMPlan.freeTornadoExecutionPlan();
-        }
-    }
-
     public static void main(String[] args) throws IOException {
         Options options = Options.parseOptions(args);
-        Llama model;
+        Model model;
         if (USE_AOT) {
             model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
         } else {
             model = ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true);
         }
-        Sampler sampler = selectSampler(model.configuration().vocabularySize, options.temperature(), options.topp(), options.seed());
+        assert model != null;
+        Sampler sampler = selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
         if (options.interactive()) {
-            runInteractive(model, sampler, options);
+            model.runInteractive(sampler, options);
         } else {
-            runInstructOnce(model, sampler, options);
+            model.runInstructOnce(sampler, options);
         }
     }
 }

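The net effect of the `LlamaApp` change is easier to see outside the diff: the interactive and one-shot code paths that used to live as static helpers now hang off the new `Model` abstraction, so `main` only loads a model, builds a sampler, and dispatches. The sketch below is a minimal, self-contained illustration of that dispatch pattern; `Sampler`, `Configuration`, and `RunOptions` here are simplified stand-ins rather than the project's actual types, and the `Model` interface is assumed to expose only the members visible in this diff.

```java
// Minimal sketch of the dispatch pattern introduced in LlamaApp.main.
// All types below are simplified placeholders, not the real project classes.
interface Sampler { int sampleToken(float[] logits); }

interface Configuration { int vocabularySize(); }

record RunOptions(boolean interactive, String prompt) { }

interface Model {
    Configuration configuration();
    void runInteractive(Sampler sampler, RunOptions options);   // chat loop now lives with the model
    void runInstructOnce(Sampler sampler, RunOptions options);  // single-prompt mode now lives with the model
}

final class Demo {
    public static void main(String[] args) {
        Model model = loadModel();                    // stand-in for ModelLoader.loadModel(...)
        Sampler sampler = logits -> 0;                // stand-in for selectSampler(...)
        RunOptions options = new RunOptions(false, "Why is the sky blue?");
        if (options.interactive()) {
            model.runInteractive(sampler, options);
        } else {
            model.runInstructOnce(sampler, options);
        }
    }

    static Model loadModel() {
        // A toy Model implementation so the example runs end to end.
        return new Model() {
            public Configuration configuration() { return () -> 128_000; }
            public void runInteractive(Sampler s, RunOptions o) { System.out.println("chat mode"); }
            public void runInstructOnce(Sampler s, RunOptions o) { System.out.println("one-shot: " + o.prompt()); }
        };
    }
}
```

The design choice visible in the diff is that each model family (Llama, Mistral) can supply its own chat format and generation loop behind the same two entry points, instead of `LlamaApp` hard-coding Llama-specific logic.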
src/main/java/com/example/inference/engine/impl/Options.java renamed to src/main/java/com/example/Options.java

Lines changed: 8 additions & 5 deletions
@@ -1,17 +1,17 @@
-package com.example.inference.engine.impl;
+package com.example;
 
 import java.io.PrintStream;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 
-public record Options(Path modelPath, String prompt, String systemPrompt, boolean interactive,
+public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive,
                       float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) {
 
     public static final int DEFAULT_MAX_TOKENS = 1024;
 
     public Options {
         require(modelPath != null, "Missing argument: --model <path> is required");
-        require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"" );
+        require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
         require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
         require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
     }
@@ -33,7 +33,8 @@ static void printUsage(PrintStream out) {
         out.println(" --interactive, --chat, -i run in chat mode");
         out.println(" --instruct run in instruct (once) mode, default mode");
         out.println(" --prompt, -p <string> input prompt");
-        out.println(" --system-prompt, -sp <string> (optional) system prompt");
+        out.println(" --system-prompt, -sp <string> (optional) system prompt (Llama models)");
+        out.println(" --suffix <string> suffix for fill-in-the-middle request (Codestral)");
         out.println(" --temperature, -temp <float> temperature in [0,inf], default 0.1");
         out.println(" --top-p <float> p value in top-p (nucleus) sampling in [0,1] default 0.95");
         out.println(" --seed <long> random seed, default System.nanoTime()");
@@ -46,6 +47,7 @@ static void printUsage(PrintStream out) {
     public static Options parseOptions(String[] args) {
         String prompt = "Tell me a story with Java"; // Hardcoded for testing
         String systemPrompt = null;
+        String suffix = null;
         float temperature = 0.1f;
         float topp = 0.95f;
         Path modelPath = null;
@@ -80,6 +82,7 @@ public static Options parseOptions(String[] args) {
                 switch (optionName) {
                     case "--prompt", "-p" -> prompt = nextArg;
                     case "--system-prompt", "-sp" -> systemPrompt = nextArg;
+                    case "--suffix" -> suffix = nextArg;
                     case "--temperature", "--temp" -> temperature = Float.parseFloat(nextArg);
                     case "--top-p" -> topp = Float.parseFloat(nextArg);
                     case "--model", "-m" -> modelPath = Paths.get(nextArg);
@@ -92,6 +95,6 @@ public static Options parseOptions(String[] args) {
                 }
             }
         }
-        return new Options(modelPath, prompt, systemPrompt, interactive, temperature, topp, seed, maxTokens, stream, echo);
+        return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo);
     }
 }

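To make the flow of the new flag concrete, the sketch below follows `--suffix` from the command line into a record, in the same style as the `parseOptions` switch shown above. `CliOptions`, its `parse` helper, and the file name `codestral.gguf` are illustrative stand-ins, not the project's actual `Options` class.

```java
// Minimal sketch of how the new suffix option threads through the options record.
// A simplified stand-in for Options.parseOptions, showing only the fields touched
// by this commit.
import java.nio.file.Path;
import java.nio.file.Paths;

record CliOptions(Path modelPath, String prompt, String systemPrompt, String suffix) {

    static CliOptions parse(String[] args) {
        Path modelPath = null;
        String prompt = null;
        String systemPrompt = null;
        String suffix = null;   // new in this commit: fill-in-the-middle suffix (Codestral)
        for (int i = 0; i < args.length; i++) {
            String name = args[i];
            String next = i + 1 < args.length ? args[++i] : null;
            switch (name) {
                case "--model", "-m" -> modelPath = Paths.get(next);
                case "--prompt", "-p" -> prompt = next;
                case "--system-prompt", "-sp" -> systemPrompt = next;
                case "--suffix" -> suffix = next;   // forwarded to the model for FIM requests
                default -> throw new IllegalArgumentException("Unknown option: " + name);
            }
        }
        return new CliOptions(modelPath, prompt, systemPrompt, suffix);
    }

    public static void main(String[] args) {
        CliOptions opts = parse(new String[] {
                "--model", "codestral.gguf",
                "--prompt", "def fib(n):",
                "--suffix", "return result"
        });
        System.out.println("suffix = " + opts.suffix());
    }
}
```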
src/main/java/com/example/aot/AOT.java

Lines changed: 5 additions & 4 deletions
@@ -3,8 +3,9 @@
 import com.example.auxiliary.Timer;
 import com.example.core.model.GGUF;
 import com.example.core.model.tensor.GGMLTensorEntry;
-import com.example.inference.engine.impl.Llama;
-import com.example.inference.engine.impl.Options;
+import com.example.model.Model;
+import com.example.Options;
+import com.example.model.llama.Llama;
 import com.example.loader.weights.ModelLoader;
 import com.example.loader.weights.Weights;
 
@@ -45,7 +46,7 @@ private static PartialModel preLoadGGUF(String modelPath) {
         try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
             return new PartialModel(
                     path.getFileName().toString(),
-                    ModelLoader.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false),
+                    Llama.loadModel(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false), // TODO: needs proper handling for AOT
                     gguf.getTensorDataOffset(),
                     gguf.getTensorInfos()
             );
@@ -60,7 +61,7 @@ private static PartialModel preLoadGGUF(String modelPath) {
      * The file name (base name) must match with the preloaded file name.
      * No checksum/hash is checked for performance reasons.
      */
-    public static com.example.inference.engine.impl.Llama tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
+    public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
         AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
         if (preLoaded == null) {
             return null; // no pre-loaded model stored

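The javadoc context in the last hunk states when the AOT path applies: the requested file's base name must match the preloaded one, and no checksum is computed. A minimal sketch of that lookup, using simplified placeholder types rather than the real `AOT`/`PartialModel` classes:

```java
// Rough sketch of the preload check described in the javadoc above: reuse the
// AOT-baked model only when the base file name matches; no checksum/hash is
// computed. PreloadLookup and PartialModel here are illustrative placeholders.
import java.nio.file.Path;

final class PreloadLookup {

    record PartialModel(String modelFileName) { }

    // Would be baked in at build time by the AOT image; may be null at runtime.
    static final PartialModel PRELOADED_GGUF =
            new PartialModel("beehive-llama-3.2-8b-instruct-fp16.gguf");

    static PartialModel tryUsePreLoaded(Path modelPath) {
        PartialModel preLoaded = PRELOADED_GGUF;
        if (preLoaded == null) {
            return null; // no pre-loaded model stored
        }
        String requested = modelPath.getFileName().toString();
        // Base-name match only, for speed; mismatches fall back to a normal load.
        if (!requested.equals(preLoaded.modelFileName())) {
            return null;
        }
        return preLoaded;
    }

    public static void main(String[] args) {
        System.out.println(tryUsePreLoaded(Path.of("models/beehive-llama-3.2-8b-instruct-fp16.gguf")));
        System.out.println(tryUsePreLoaded(Path.of("models/other-model.gguf")));
    }
}
```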