Skip to content

Commit 03591c5

Browse files
committed
Add runInstructOnceLangChain4J method for LangChain4J integration; optimize TornadoVM initialization and enhance token streaming.
1 parent dd63309 commit 03591c5

File tree

1 file changed

+75
-0
lines changed
  • src/main/java/org/beehive/gpullama3/model

1 file changed

+75
-0
lines changed

src/main/java/org/beehive/gpullama3/model/Model.java

Lines changed: 75 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -13,6 +13,7 @@
1313
import java.util.List;
1414
import java.util.Scanner;
1515
import java.util.Set;
16+
import java.util.function.Consumer;
1617
import java.util.function.IntConsumer;
1718

1819
import static org.beehive.gpullama3.LlamaApp.SHOW_PERF_INTERACTIVE;
@@ -257,4 +258,78 @@ default String runInstructOnce(Sampler sampler, Options options) {
257258

258259
return responseText;
259260
}
261+
262+
default String runInstructOnceLangChain4J(Sampler sampler, Options options, Consumer<String> tokenCallback) {
263+
State state = createNewState();
264+
ChatFormat chatFormat = chatFormat();
265+
TornadoVMMasterPlan tornadoVMPlan = null;
266+
267+
List<Integer> promptTokens = new ArrayList<>();
268+
269+
if (shouldAddBeginOfText()) {
270+
promptTokens.add(chatFormat.getBeginOfText());
271+
}
272+
273+
if (shouldAddSystemPrompt() && options.systemPrompt() != null) {
274+
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
275+
}
276+
277+
// Initialize TornadoVM plan once at the beginning if GPU path is enabled
278+
if (Options.getDefaultOptions().useTornadovm() && tornadoVMPlan == null) {
279+
tornadoVMPlan = TornadoVMMasterPlan.initializeTornadoVMPlan(state, this);
280+
}
281+
282+
promptTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, options.prompt())));
283+
promptTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
284+
285+
if (shouldIncludeReasoning()) {
286+
List<Integer> thinkStartTokens = tokenizer().encode("<think>\n", tokenizer().getSpecialTokens().keySet());
287+
promptTokens.addAll(thinkStartTokens);
288+
289+
// If streaming, immediately output the think start
290+
if (options.stream()) {
291+
System.out.print("<think>\n");
292+
}
293+
}
294+
295+
List<Integer> responseTokens;
296+
297+
IntConsumer tokenConsumer = token -> {
298+
if (tokenizer().shouldDisplayToken(token)) {
299+
String piece = tokenizer().decode(List.of(token));
300+
if (options.stream() && tokenCallback != null) {
301+
tokenCallback.accept(piece); // ✅ send to LangChain4j handler
302+
}
303+
}
304+
};
305+
306+
Set<Integer> stopTokens = chatFormat.getStopTokens();
307+
308+
if (Options.getDefaultOptions().useTornadovm()) {
309+
// GPU path using TornadoVM Call generateTokensGPU without the token consumer parameter
310+
responseTokens = generateTokensGPU(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), options.stream() ? tokenConsumer : null, tornadoVMPlan);
311+
} else {
312+
// CPU path
313+
responseTokens = generateTokens(state, 0, promptTokens, stopTokens, options.maxTokens(), sampler, options.echo(), tokenConsumer);
314+
}
315+
316+
if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
317+
responseTokens.removeLast();
318+
}
319+
320+
String responseText = tokenizer().decode(responseTokens);
321+
322+
if (!options.stream()) {
323+
responseText = tokenizer().decode(responseTokens);
324+
if (shouldIncludeReasoning()) {
325+
responseText = "<think>\n" + responseText;
326+
}
327+
}
328+
329+
if (tornadoVMPlan != null) {
330+
tornadoVMPlan.freeTornadoExecutionPlan();
331+
}
332+
333+
return responseText;
334+
}
260335
}

0 commit comments

Comments
 (0)