Skip to content

Commit 1d4b5bc

Browse files
committed
Replace static USE_TORNADOVM flag with dynamic Options property, removing redundant code and ensuring consistent TornadoVM usage configuration.
1 parent 37e1097 commit 1d4b5bc

File tree

6 files changed

+29
-29
lines changed

6 files changed

+29
-29
lines changed

src/main/java/org/beehive/gpullama3/LlamaApp.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ public class LlamaApp {
1919
// Configuration flags for hardware acceleration and optimizations
2020
public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration
2121
public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
22-
public static final boolean USE_TORNADOVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
2322
public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "true")); // Show performance metrics in interactive mode
2423

2524
/**

src/main/java/org/beehive/gpullama3/Options.java

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
import java.nio.file.Paths;
66

77
public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive,
8-
float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) {
8+
float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo, boolean useTornadovm) {
99

1010
public static final int DEFAULT_MAX_TOKENS = 1024;
1111

1212
public Options {
13-
1413
require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
1514
require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
1615
require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
@@ -25,6 +24,11 @@ static void require(boolean condition, String messageFormat, Object... args) {
2524
}
2625
}
2726

27+
private static boolean getDefaultTornadoVM() {
28+
return Boolean.parseBoolean(System.getProperty("use.tornadovm", "false"));
29+
}
30+
31+
2832
static void printUsage(PrintStream out) {
2933
out.println("Usage: jbang Llama3.java [options]");
3034
out.println();
@@ -58,8 +62,9 @@ public static Options getDefaultOptions() {
5862
boolean interactive = false;
5963
boolean stream = true;
6064
boolean echo = false;
65+
boolean useTornadoVM = getDefaultTornadoVM();
6166

62-
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo);
67+
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM);
6368
}
6469

6570
public static Options parseOptions(String[] args) {
@@ -75,6 +80,7 @@ public static Options parseOptions(String[] args) {
7580
boolean interactive = false;
7681
boolean stream = false;
7782
boolean echo = false;
83+
Boolean useTornadovm = null; // null means not specified via command line
7884

7985
for (int i = 0; i < args.length; i++) {
8086
String optionName = args[i];
@@ -116,7 +122,10 @@ public static Options parseOptions(String[] args) {
116122

117123
require(modelPath != null, "Missing argument: --model <path> is required");
118124

125+
if (useTornadovm == null) {
126+
useTornadovm = getDefaultTornadoVM();
127+
}
119128

120-
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo);
129+
return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm);
121130
}
122131
}

src/main/java/org/beehive/gpullama3/model/Model.java

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import java.util.function.IntConsumer;
1818

1919
import static org.beehive.gpullama3.LlamaApp.SHOW_PERF_INTERACTIVE;
20-
import static org.beehive.gpullama3.LlamaApp.USE_TORNADOVM;
2120

2221
public interface Model {
2322

@@ -81,7 +80,7 @@ default void runInteractive(Sampler sampler, Options options) {
8180
Scanner in = new Scanner(System.in);
8281

8382
// Initialize TornadoVM plan once at the beginning if GPU path is enabled
84-
if (USE_TORNADOVM && tornadoVMPlan == null) {
83+
if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan == null) {
8584
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
8685
}
8786

@@ -108,7 +107,7 @@ default void runInteractive(Sampler sampler, Options options) {
108107
};
109108

110109
// Choose between GPU and CPU path based on configuration
111-
if (USE_TORNADOVM) {
110+
if (Options.getDefaultOptions().useTornadovm()) {
112111
// GPU path using TornadoVM
113112
responseTokens = generateTokensGPU(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
114113
options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
@@ -143,7 +142,7 @@ default void runInteractive(Sampler sampler, Options options) {
143142
}
144143
} finally {
145144
// Clean up TornadoVM resources when exiting the chat loop
146-
if (USE_TORNADOVM && tornadoVMPlan != null) {
145+
if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan != null) {
147146
try {
148147
tornadoVMPlan.freeTornadoExecutionPlan();
149148
} catch (Exception e) {
@@ -176,7 +175,7 @@ default String runInstructOnce(Sampler sampler, Options options) {
176175
}
177176

178177
// Initialize TornadoVM plan once at the beginning if GPU path is enabled
179-
if (USE_TORNADOVM && tornadoVMPlan == null) {
178+
if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan == null) {
180179
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
181180
}
182181

@@ -195,9 +194,8 @@ default String runInstructOnce(Sampler sampler, Options options) {
195194

196195
Set<Integer> stopTokens = chatFormat.getStopTokens();
197196

198-
if (USE_TORNADOVM) {
199-
// GPU path using TornadoVM
200-
// Call generateTokensGPU without the token consumer parameter
197+
if (Options.getDefaultOptions().useTornadovm()) {
198+
// GPU path using TornadoVM - Call generateTokensGPU without the token consumer parameter
201199
responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
202200
} else {
203201
// CPU path
@@ -208,7 +206,7 @@ default String runInstructOnce(Sampler sampler, Options options) {
208206
responseTokens.removeLast();
209207
}
210208

211-
String responseText = null;
209+
String responseText = "";
212210
if (!options.stream()) {
213211
responseText = tokenizer().decode(responseTokens);
214212
}
@@ -242,7 +240,7 @@ default String runInstructOnceLangChain4J(Sampler sampler, Options options, Cons
242240
}
243241

244242
// Initialize TornadoVM plan once at the beginning if GPU path is enabled
245-
if (USE_TORNADOVM && tornadoVMPlan == null) {
243+
if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan == null) {
246244
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
247245
}
248246

@@ -262,9 +260,8 @@ default String runInstructOnceLangChain4J(Sampler sampler, Options options, Cons
262260

263261
Set<Integer> stopTokens = chatFormat.getStopTokens();
264262

265-
if (USE_TORNADOVM) {
266-
// GPU path using TornadoVM
267-
// Call generateTokensGPU without the token consumer parameter
263+
if (Options.getDefaultOptions().useTornadovm()) {
264+
// GPU path using TornadoVM Call generateTokensGPU without the token consumer parameter
268265
responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
269266
} else {
270267
// CPU path

src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package org.beehive.gpullama3.model.loader;
22

3-
import org.beehive.gpullama3.LlamaApp;
3+
import org.beehive.gpullama3.Options;
44
import org.beehive.gpullama3.core.model.GGMLType;
55
import org.beehive.gpullama3.core.model.GGUF;
66
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
@@ -18,7 +18,6 @@
1818
import org.beehive.gpullama3.model.Configuration;
1919
import org.beehive.gpullama3.model.Model;
2020
import org.beehive.gpullama3.model.ModelType;
21-
import org.beehive.gpullama3.model.llama.Llama;
2221
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
2322
import uk.ac.manchester.tornado.api.types.HalfFloat;
2423
import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
@@ -34,11 +33,7 @@
3433
import java.util.Map;
3534
import java.util.function.IntFunction;
3635

37-
import static org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME;
38-
3936
public abstract class ModelLoader {
40-
private static final String TOKENIZER_LLAMA_3_MODEL = "gpt2";
41-
private static final String TOKENIZER_MISTRAL_MODEL = "llama";
4237

4338
protected FileChannel fileChannel;
4439
protected GGUF gguf;
@@ -223,7 +218,7 @@ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configura
223218
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
224219
GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
225220

226-
if (LlamaApp.USE_TORNADOVM) {
221+
if (Options.getDefaultOptions().useTornadovm()) {
227222
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
228223
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
229224
}

src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package org.beehive.gpullama3.model.loader;
22

33
import org.beehive.gpullama3.LlamaApp;
4+
import org.beehive.gpullama3.Options;
45
import org.beehive.gpullama3.auxiliary.Timer;
56
import org.beehive.gpullama3.core.model.GGMLType;
67
import org.beehive.gpullama3.core.model.GGUF;
@@ -97,7 +98,7 @@ private Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configur
9798
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
9899
GGMLTensorEntry outputWeight = tensorEntries.get("output.weight"); // Phi3 always has separate output weight
99100

100-
if (LlamaApp.USE_TORNADOVM) {
101+
if (Options.getDefaultOptions().useTornadovm()) {
101102
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
102103
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
103104
}
@@ -155,6 +156,5 @@ public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
155156
outputWeight.ggmlType() // weightType
156157
);
157158
}
158-
159159
// @formatter:on
160160
}

src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
package org.beehive.gpullama3.model.loader;
22

3-
import org.beehive.gpullama3.LlamaApp;
3+
import org.beehive.gpullama3.Options;
44
import org.beehive.gpullama3.auxiliary.Timer;
55
import org.beehive.gpullama3.core.model.GGMLType;
66
import org.beehive.gpullama3.core.model.GGUF;
@@ -101,7 +101,7 @@ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configura
101101
GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
102102
GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
103103

104-
if (LlamaApp.USE_TORNADOVM) {
104+
if (Options.getDefaultOptions().useTornadovm()) {
105105
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
106106
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
107107
}

0 commit comments

Comments
 (0)