Commit d7b237d

Merge pull request #45 from mikepapadim/feat/api

Refactor to work with Langchain4J as API

2 parents efbe261 + 03591c5

File tree

10 files changed (+191 -66 lines)

src/main/java/org/beehive/gpullama3/LlamaApp.java

Lines changed: 46 additions & 33 deletions
@@ -1,12 +1,13 @@
 package org.beehive.gpullama3;
 
 import org.beehive.gpullama3.aot.AOT;
+import org.beehive.gpullama3.auxiliary.LastRunMetrics;
 import org.beehive.gpullama3.core.model.tensor.FloatTensor;
 import org.beehive.gpullama3.inference.sampler.CategoricalSampler;
 import org.beehive.gpullama3.inference.sampler.Sampler;
 import org.beehive.gpullama3.inference.sampler.ToppSampler;
-import org.beehive.gpullama3.model.loader.ModelLoader;
 import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.model.loader.ModelLoader;
 import org.beehive.gpullama3.tornadovm.FloatArrayUtils;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
@@ -18,7 +19,6 @@ public class LlamaApp {
     // Configuration flags for hardware acceleration and optimizations
     public static final boolean USE_VECTOR_API = Boolean.parseBoolean(System.getProperty("llama.VectorAPI", "true")); // Enable Java Vector API for CPU acceleration
     public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
-    public static final boolean USE_TORNADOVM = Boolean.parseBoolean(System.getProperty("use.tornadovm", "false")); // Use TornadoVM for GPU acceleration
     public static final boolean SHOW_PERF_INTERACTIVE = Boolean.parseBoolean(System.getProperty("llama.ShowPerfInteractive", "true")); // Show performance metrics in interactive mode
 
     /**
@@ -36,27 +36,29 @@ public class LlamaApp {
      * <p>The method handles both {@link FloatTensor} and {@link FloatArray} logits types
      * to support both CPU and GPU execution paths.</p>
      *
-     * @param vocabularySize The size of the model's vocabulary
-     * @param temperature A value controlling randomness in sampling:
-     *                    <ul>
-     *                    <li>0.0f: No randomness (greedy sampling)</li>
-     *                    <li>1.0f: Standard sampling from unmodified distribution</li>
-     *                    <li>&lt;1.0f: More deterministic (sharper distribution)</li>
-     *                    <li>&gt;1.0f: More random (flatter distribution)</li>
-     *                    </ul>
-     * @param topp The cumulative probability threshold for nucleus sampling (0.0-1.0).
-     *             <ul>
-     *             <li>Values ≤0 or ≥1: Disables top-p sampling</li>
-     *             <li>Values in (0,1): Restricts sampling to tokens comprising the top p probability mass</li>
-     *             </ul>
-     * @param rngSeed Seed value for the random number generator to ensure reproducibility
-     *
-     * @return A configured {@link Sampler} that implements the selected sampling strategy
-     *         and handles both tensor and array-based logits
-     *
-     * @throws IllegalArgumentException if logits are of an unsupported type
+     * @param vocabularySize
+     *            The size of the model's vocabulary
+     * @param temperature
+     *            A value controlling randomness in sampling:
+     *            <ul>
+     *            <li>0.0f: No randomness (greedy sampling)</li>
+     *            <li>1.0f: Standard sampling from unmodified distribution</li>
+     *            <li>&lt;1.0f: More deterministic (sharper distribution)</li>
+     *            <li>&gt;1.0f: More random (flatter distribution)</li>
+     *            </ul>
+     * @param topp
+     *            The cumulative probability threshold for nucleus sampling (0.0-1.0).
+     *            <ul>
+     *            <li>Values ≤0 or ≥1: Disables top-p sampling</li>
+     *            <li>Values in (0,1): Restricts sampling to tokens comprising the top p probability mass</li>
+     *            </ul>
+     * @param rngSeed
+     *            Seed value for the random number generator to ensure reproducibility
+     * @return A configured {@link Sampler} that implements the selected sampling strategy and handles both tensor and array-based logits
+     * @throws IllegalArgumentException
+     *             if logits are of an unsupported type
      */
-    static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) {
+    public static Sampler selectSampler(int vocabularySize, float temperature, float topp, long rngSeed) {
         Sampler sampler;
         if (temperature == 0.0f) {
             // greedy argmax sampling: take the token with the highest probability
@@ -109,14 +111,16 @@ static Sampler selectSampler(int vocabularySize, float temperature, float topp,
     /**
      * Loads the language model based on the given options.
      * <p>
-     * If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model.
-     * Otherwise, loads the model from the specified path using the model loader.
+     * If Ahead-of-Time (AOT) mode is enabled, attempts to use a pre-loaded compiled model. Otherwise, loads the model from the specified path using the model loader.
      * </p>
      *
-     * @param options the parsed CLI options containing model path and max token limit
+     * @param options
+     *            the parsed CLI options containing model path and max token limit
     * @return the loaded {@link Model} instance
-     * @throws IOException if the model fails to load
-     * @throws IllegalStateException if AOT loading is enabled but the preloaded model is unavailable
+     * @throws IOException
+     *             if the model fails to load
+     * @throws IllegalStateException
+     *             if AOT loading is enabled but the preloaded model is unavailable
      */
     private static Model loadModel(Options options) throws IOException {
         if (USE_AOT) {
@@ -133,25 +137,34 @@ private static Sampler createSampler(Model model, Options options) {
         return selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
     }
 
+    private static void runSingleInstruction(Model model, Sampler sampler, Options options) {
+        String response = model.runInstructOnce(sampler, options);
+        System.out.println(response);
+        if (SHOW_PERF_INTERACTIVE) {
+            LastRunMetrics.printMetrics();
+        }
+    }
+
     /**
      * Entry point for running the LLaMA-based model with provided command-line arguments.
      *
     * <p>Initializes model options, loads the appropriate model (either AOT or on-demand),
-     * configures the sampler, and runs either in interactive or single-instruction mode
-     * based on the input options.</p>
+     * configures the sampler, and runs either in interactive or single-instruction mode based on the input options.</p>
      *
-     * @param args command-line arguments used to configure model path, temperature, seed, etc.
-     * @throws IOException if model loading or file operations fail.
+     * @param args
+     *            command-line arguments used to configure model path, temperature, seed, etc.
+     * @throws IOException
+     *             if model loading or file operations fail.
      */
-    public static void main(String[] args) throws IOException {
+    static void main(String[] args) throws IOException {
         Options options = Options.parseOptions(args);
         Model model = loadModel(options);
         Sampler sampler = createSampler(model, options);
 
         if (options.interactive()) {
             model.runInteractive(sampler, options);
         } else {
-            model.runInstructOnce(sampler, options);
+            runSingleInstruction(model, sampler, options);
         }
     }
 }
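
Since selectSampler is now public, embedders can build a sampler directly instead of going through main. A minimal sketch of the three sampling regimes described in the javadoc above (the vocabulary size and seed are illustrative, not project defaults):

import org.beehive.gpullama3.LlamaApp;
import org.beehive.gpullama3.inference.sampler.Sampler;

class SamplerSelectionSketch {
    public static void main(String[] args) {
        int vocabularySize = 32_000; // illustrative; real callers use model.configuration().vocabularySize()
        long seed = 42L;

        // temperature == 0.0f -> greedy argmax sampling (topp is irrelevant)
        Sampler greedy = LlamaApp.selectSampler(vocabularySize, 0.0f, 0.95f, seed);

        // topp in (0, 1) with non-zero temperature -> nucleus (top-p) sampling
        Sampler nucleus = LlamaApp.selectSampler(vocabularySize, 0.7f, 0.9f, seed);

        // topp <= 0 or >= 1 disables top-p -> categorical sampling over the full distribution
        Sampler categorical = LlamaApp.selectSampler(vocabularySize, 1.0f, 1.0f, seed);
    }
}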

src/main/java/org/beehive/gpullama3/Options.java

Lines changed: 34 additions & 6 deletions
@@ -4,13 +4,12 @@
 import java.nio.file.Path;
 import java.nio.file.Paths;
 
-public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive,
-                      float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo) {
+public record Options(Path modelPath, String prompt, String systemPrompt, String suffix, boolean interactive, float temperature, float topp, long seed, int maxTokens, boolean stream, boolean echo,
+                      boolean useTornadovm) {
 
     public static final int DEFAULT_MAX_TOKENS = 1024;
 
     public Options {
-        require(modelPath != null, "Missing argument: --model <path> is required");
         require(interactive || prompt != null, "Missing argument: --prompt is required in --instruct mode e.g. --prompt \"Why is the sky blue?\"");
         require(0 <= temperature, "Invalid argument: --temperature must be non-negative");
         require(0 <= topp && topp <= 1, "Invalid argument: --top-p must be within [0, 1]");
@@ -25,6 +24,10 @@ static void require(boolean condition, String messageFormat, Object... args) {
         }
     }
 
+    private static boolean getDefaultTornadoVM() {
+        return Boolean.parseBoolean(System.getProperty("use.tornadovm", "false"));
+    }
+
     static void printUsage(PrintStream out) {
         out.println("Usage: jbang Llama3.java [options]");
         out.println();
@@ -44,19 +47,36 @@ static void printUsage(PrintStream out) {
         out.println();
     }
 
-    public static Options parseOptions(String[] args) {
+    public static Options getDefaultOptions() {
         String prompt = "Tell me a story with Java"; // Hardcoded for testing
         String systemPrompt = null;
         String suffix = null;
         float temperature = 0.1f;
         float topp = 0.95f;
         Path modelPath = null;
         long seed = System.nanoTime();
-        // Keep max context length small for low-memory devices.
         int maxTokens = DEFAULT_MAX_TOKENS;
         boolean interactive = false;
         boolean stream = true;
         boolean echo = false;
+        boolean useTornadoVM = getDefaultTornadoVM();
+
+        return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadoVM);
+    }
+
+    public static Options parseOptions(String[] args) {
+        String prompt = "Tell me a story with Java"; // Hardcoded for testing
+        String systemPrompt = null;
+        String suffix = null;
+        float temperature = 0.1f;
+        float topp = 0.95f;
+        Path modelPath = null;
+        long seed = System.nanoTime();
+        int maxTokens = DEFAULT_MAX_TOKENS;
+        boolean interactive = false;
+        boolean stream = false;
+        boolean echo = false;
+        Boolean useTornadovm = null; // null means not specified via command line
 
         for (int i = 0; i < args.length; i++) {
             String optionName = args[i];
@@ -90,11 +110,19 @@ public static Options parseOptions(String[] args) {
                         case "--max-tokens", "-n" -> maxTokens = Integer.parseInt(nextArg);
                         case "--stream" -> stream = Boolean.parseBoolean(nextArg);
                         case "--echo" -> echo = Boolean.parseBoolean(nextArg);
+                        case "--use-tornadovm" -> useTornadovm = Boolean.parseBoolean(nextArg);
                         default -> require(false, "Unknown option: %s", optionName);
                     }
                 }
             }
         }
-        return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo);
+
+        require(modelPath != null, "Missing argument: --model <path> is required");
+
+        if (useTornadovm == null) {
+            useTornadovm = getDefaultTornadoVM();
+        }
+
+        return new Options(modelPath, prompt, systemPrompt, suffix, interactive, temperature, topp, seed, maxTokens, stream, echo, useTornadovm);
     }
 }
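
The net effect on configuration: an explicit --use-tornadovm flag now wins, and only when it is absent does parseOptions fall back to the use.tornadovm system property (default false), which previously drove the global USE_TORNADOVM constant. A sketch of both resolution paths (the model path is hypothetical):

import org.beehive.gpullama3.Options;

class TornadoFlagSketch {
    public static void main(String[] args) {
        // Explicit CLI flag wins, regardless of the system property.
        Options cli = Options.parseOptions(new String[] {
                "--model", "models/demo.gguf", // hypothetical path
                "--prompt", "Why is the sky blue?",
                "--use-tornadovm", "true"
        });
        System.out.println(cli.useTornadovm()); // true, from the flag

        // Without the flag, the use.tornadovm system property (default false) decides.
        System.setProperty("use.tornadovm", "true");
        Options prop = Options.parseOptions(new String[] {
                "--model", "models/demo.gguf",
                "--prompt", "Why is the sky blue?"
        });
        System.out.println(prop.useTornadovm()); // true, from the property
    }
}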

src/main/java/org/beehive/gpullama3/core/model/GGUF.java

Lines changed: 2 additions & 1 deletion
@@ -45,7 +45,8 @@ public static GGUF loadModel(Path modelPath) throws IOException {
     }
 
     // second check to make sure that nothing goes wrong during model loading
-    try (FileChannel fileChannel = FileChannel.open(modelPath); var ignored = Timer.log("Parse " + modelPath)) {
+    try (FileChannel fileChannel = FileChannel.open(modelPath);
+    ) {
         GGUF gguf = new GGUF();
         gguf.loadModelImpl(fileChannel);
         return gguf;

src/main/java/org/beehive/gpullama3/model/Model.java

Lines changed: 86 additions & 12 deletions
@@ -13,10 +13,10 @@
 import java.util.List;
 import java.util.Scanner;
 import java.util.Set;
+import java.util.function.Consumer;
 import java.util.function.IntConsumer;
 
 import static org.beehive.gpullama3.LlamaApp.SHOW_PERF_INTERACTIVE;
-import static org.beehive.gpullama3.LlamaApp.USE_TORNADOVM;
 
 public interface Model {
 
@@ -92,7 +92,7 @@ default void runInteractive(Sampler sampler, Options options) {
         Scanner in = new Scanner(System.in);
 
         // Initialize TornadoVM plan once at the beginning if GPU path is enabled
-        if (USE_TORNADOVM && tornadoVMPlan == null) {
+        if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan == null) {
             tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
         }
 
@@ -131,7 +131,7 @@ default void runInteractive(Sampler sampler, Options options) {
         };
 
         // Choose between GPU and CPU path based on configuration
-        if (USE_TORNADOVM) {
+        if (Options.getDefaultOptions().useTornadovm()) {
             // GPU path using TornadoVM
             responseTokens = generateTokensGPU(state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler,
                     options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
@@ -170,7 +170,7 @@ default void runInteractive(Sampler sampler, Options options) {
             }
         } finally {
             // Clean up TornadoVM resources when exiting the chat loop
-            if (USE_TORNADOVM && tornadoVMPlan != null) {
+            if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan != null) {
                 try {
                     tornadoVMPlan.freeTornadoExecutionPlan();
                 } catch (Exception e) {
@@ -185,7 +185,7 @@ default void runInteractive(Sampler sampler, Options options) {
      * @param sampler
      * @param options
      */
-    default void runInstructOnce(Sampler sampler, Options options) {
+    default String runInstructOnce(Sampler sampler, Options options) {
         State state = createNewState();
         ChatFormat chatFormat = chatFormat();
         TornadoVMMasterPlan tornadoVMPlan = null;
@@ -201,7 +201,7 @@ default void runInstructOnce(Sampler sampler, Options options) {
         }
 
         // Initialize TornadoVM plan once at the beginning if GPU path is enabled
-        if (USE_TORNADOVM && tornadoVMPlan == null) {
+        if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan == null) {
             tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
         }
 
@@ -231,9 +231,8 @@ default void runInstructOnce(Sampler sampler, Options options) {
 
         Set<Integer> stopTokens = chatFormat.getStopTokens();
 
-        if (USE_TORNADOVM) {
-            // GPU path using TornadoVM
-            // Call generateTokensGPU without the token consumer parameter
+        if (Options.getDefaultOptions().useTornadovm()) {
+            // GPU path using TornadoVM - Call generateTokensGPU without the token consumer parameter
             responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
         } else {
             // CPU path
@@ -243,19 +242,94 @@ default void runInstructOnce(Sampler sampler, Options options) {
         if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
             responseTokens.removeLast();
         }
+
+        String responseText = "";
         if (!options.stream()) {
-            String responseText = tokenizer().decode(responseTokens);
+            responseText = tokenizer().decode(responseTokens);
             // Add the forced <think>\n prefix for non-streaming output
             if (shouldIncludeReasoning()) {
                 responseText = "<think>\n" + responseText;
             }
-            System.out.println(responseText);
         }
 
-        LastRunMetrics.printMetrics();
+        if (tornadoVMPlan != null) {
+            tornadoVMPlan.freeTornadoExecutionPlan();
+        }
+
+        return responseText;
+    }
+
+    default String runInstructOnceLangChain4J(Sampler sampler, Options options, Consumer<String> tokenCallback) {
+        State state = createNewState();
+        ChatFormat chatFormat = chatFormat();
+        TornadoVMMasterPlan tornadoVMPlan = null;
+
+        List<Integer> promptTokens = new ArrayList<>();
+
+        if (shouldAddBeginOfText()) {
+            promptTokens.add(chatFormat.getBeginOfText());
+        }
+
+        if (shouldAddSystemPrompt() && options.systemPrompt() != null) {
+            promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
+        }
+
+        // Initialize TornadoVM plan once at the beginning if GPU path is enabled
+        if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan == null) {
+            tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
+        }
+
+        promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
+        promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+
+        if (shouldIncludeReasoning()) {
+            List<Integer> thinkStartTokens = tokenizer().encode("<think>\n", tokenizer().getSpecialTokens().keySet());
+            promptTokens.addAll(thinkStartTokens);
+
+            // If streaming, immediately output the think start
+            if (options.stream()) {
+                System.out.print("<think>\n");
+            }
+        }
+
+        List<Integer> responseTokens;
+
+        IntConsumer tokenConsumer = token -> {
+            if (tokenizer().shouldDisplayToken(token)) {
+                String piece = tokenizer().decode(List.of(token));
+                if (options.stream() && tokenCallback != null) {
+                    tokenCallback.accept(piece); // ✅ send to LangChain4j handler
+                }
+            }
+        };
+
+        Set<Integer> stopTokens = chatFormat.getStopTokens();
+
+        if (Options.getDefaultOptions().useTornadovm()) {
+            // GPU path using TornadoVM - Call generateTokensGPU without the token consumer parameter
+            responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
+        } else {
+            // CPU path
+            responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
+        }
+
+        if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+            responseTokens.removeLast();
+        }
+
+        String responseText = tokenizer().decode(responseTokens);
+
+        if (!options.stream()) {
+            responseText = tokenizer().decode(responseTokens);
+            if (shouldIncludeReasoning()) {
+                responseText = "<think>\n" + responseText;
+            }
+        }
 
         if (tornadoVMPlan != null) {
             tornadoVMPlan.freeTornadoExecutionPlan();
         }
+
+        return responseText;
     }
 }
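
runInstructOnce now returns the response text instead of printing it, and the new runInstructOnceLangChain4J variant additionally forwards each decoded token piece to a Consumer<String> while streaming — the hook a LangChain4j adapter would attach to. A minimal bridge sketch, assuming a Model has already been loaded (the LangChain4j-side handler types are outside this diff; the Consumer<String> stands in for one):

import java.util.function.Consumer;

import org.beehive.gpullama3.LlamaApp;
import org.beehive.gpullama3.Options;
import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.model.Model;

class LangChain4jBridgeSketch {

    // Streams one instruct-mode completion into the handler and returns the full text.
    static String streamOnce(Model model, Options options, Consumer<String> handler) {
        // selectSampler is public after this PR, so callers outside the package can build one.
        Sampler sampler = LlamaApp.selectSampler(model.configuration().vocabularySize(), options.temperature(), options.topp(), options.seed());
        // With --stream true, each decoded piece is pushed into 'handler' as it is generated;
        // the complete response text is returned either way.
        return model.runInstructOnceLangChain4J(sampler, options, handler);
    }
}

A real adapter would pass whatever LangChain4j streaming handler it wraps as the Consumer<String>; for a quick smoke test, streamOnce(model, options, System.out::print) prints tokens as they arrive while still returning the full response, so blocking and streaming adapters can share this one entry point.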
